1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
you-n-g be4646b4b7 Adjust rolling api (#1594)
* Intermediate version

* Fix yaml template & Successfully run rolling

* Be compatible with benchmark

* Get same results with previous linear model

* Black formatting

* Update black

* Update the placeholder mechanism

* Update CI

* Update CI

* Upgrade Black

* Fix CI and simplify code

* Fix CI

* Move the data processing caching mechanism into utils.

* Adjusting DDG-DA

* Organize import
2023-07-14 12:16:12 +08:00

223 lines
7.7 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Default data formatting functions for experiments.
For new datasets, inherit from GenericDataFormatter and implement
all abstract functions.
These dataset-specific methods:
1) Define the column and input types for tabular dataframes used by model
2) Perform the necessary input feature engineering & normalisation steps
3) Revert the normalisation for predictions
4) Are responsible for train, validation and test splits
"""
import abc
import enum
# Type definitions
class DataTypes(enum.IntEnum):
    """Numerical kind of the values stored in a column.

    Each column-definition tuple carries one of these members as its second
    element (``tup[1]``).
    """

    REAL_VALUED = 0  # real-valued (numeric) column
    CATEGORICAL = 1  # categorical column
    DATE = 2  # date column
class InputTypes(enum.IntEnum):
    """Role a column plays as a model input.

    Each column-definition tuple carries one of these members as its third
    element (``tup[2]``).
    """

    TARGET = 0  # value(s) the model predicts
    OBSERVED_INPUT = 1
    KNOWN_INPUT = 2
    STATIC_INPUT = 3
    ID = 4  # Single column used as an entity identifier
    TIME = 5  # Single column exclusively used as a time index
class GenericDataFormatter(abc.ABC):
    """Abstract base class for all data formatters.

    Users implement the abstract methods below to perform dataset-specific
    manipulations. The concrete helpers further down derive, from the
    subclass's column definition and fixed parameters, the column ordering,
    input indices and parameter dictionary consumed by the TFT model.
    """

    @abc.abstractmethod
    def set_scalers(self, df):
        """Calibrates scalers using the data supplied."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transform_inputs(self, df):
        """Performs feature transformation."""
        raise NotImplementedError()

    @abc.abstractmethod
    def format_predictions(self, df):
        """Reverts any normalisation to give predictions in original scale."""
        raise NotImplementedError()

    @abc.abstractmethod
    def split_data(self, df):
        """Performs the default train, validation and test splits."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def _column_definition(self):
        """Defines order, input type and data type of each column.

        Expected to be a re-iterable sequence of ``(name, DataTypes, InputTypes)``
        tuples: the helpers below read ``tup[0]`` as the column name,
        ``tup[1]`` as the data type and ``tup[2]`` as the input type, and
        iterate the sequence several times.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_fixed_params(self):
        """Defines the fixed parameters used by the model for training.

        Requires the following keys:
          'total_time_steps': Defines the total number of time steps used by TFT
          'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
          'num_epochs': Maximum number of epochs for training
          'early_stopping_patience': Early stopping param for keras
          'multiprocessing_workers': # of cpus for data processing

        Returns:
          A dictionary of fixed parameters, e.g.:

          fixed_params = {
              'total_time_steps': 252 + 5,
              'num_encoder_steps': 252,
              'num_epochs': 100,
              'early_stopping_patience': 5,
              'multiprocessing_workers': 5,
          }
        """
        raise NotImplementedError()

    # Shared functions across data-formatters
    @property
    def num_classes_per_cat_input(self):
        """Returns number of categories per relevant input.

        This is subsequently required for keras embedding layers.

        The base class never assigns ``_num_classes_per_cat_input``; a subclass
        must set it (presumably while calibrating scalers — confirm per
        subclass), otherwise reading this property raises ``AttributeError``.
        """
        return self._num_classes_per_cat_input

    def get_num_samples_for_calibration(self):
        """Gets the default number of training and validation samples.

        Use to sub-sample the data for network calibration and a value of -1 uses
        all available samples.

        Returns:
          Tuple of (training samples, validation samples)
        """
        return -1, -1

    def get_column_definition(self):
        """Returns formatted column definition in order expected by the TFT.

        The result is ordered as: the single ID column, the single TIME column,
        then every real-valued input, then every categorical input (each group
        keeps its original relative order). Non-ID/TIME columns whose data type
        is neither REAL_VALUED nor CATEGORICAL (e.g. ``DataTypes.DATE``) are
        dropped.

        Raises:
          ValueError: if there is not exactly one ID column and exactly one
            TIME column.
        """
        column_definition = self._column_definition

        # Sanity checks first.
        # Ensure only one ID and time column exist
        def _check_single_column(input_type):
            length = len([tup for tup in column_definition if tup[2] == input_type])
            if length != 1:
                raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))

        _check_single_column(InputTypes.ID)
        _check_single_column(InputTypes.TIME)

        identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
        time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
        real_inputs = [
            tup
            for tup in column_definition
            if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]
        categorical_inputs = [
            tup
            for tup in column_definition
            if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]

        return identifier + time + real_inputs + categorical_inputs

    def _get_input_columns(self):
        """Returns names of all input columns (everything except ID and TIME)."""
        return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]

    def _get_tft_input_indices(self):
        """Returns the relevant indexes and input sizes required by TFT.

        Location lists are positions within the formatted column definition with
        the ID and TIME columns removed. The exceptions are
        ``known_regular_inputs`` and ``known_categorical_inputs``, which are
        positions within the real-valued and categorical sub-lists respectively.
        """

        # Functions
        def _extract_tuples_from_data_type(data_type, defn):
            return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]

        def _get_locations(input_types, defn):
            return [i for i, tup in enumerate(defn) if tup[2] in input_types]

        # Start extraction
        column_definition = [
            tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
        ]

        categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
        real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)

        # Target columns serve both as the outputs and as the observed inputs.
        target_locations = _get_locations({InputTypes.TARGET}, column_definition)
        known_types = {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}

        locations = {
            "input_size": len(self._get_input_columns()),
            "output_size": len(target_locations),
            "category_counts": self.num_classes_per_cat_input,
            "input_obs_loc": list(target_locations),
            "static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
            "known_regular_inputs": _get_locations(known_types, real_inputs),
            "known_categorical_inputs": _get_locations(known_types, categorical_inputs),
        }

        return locations

    def get_experiment_params(self):
        """Returns fixed model parameters for experiments.

        Combines ``get_fixed_params()`` with ``"column_definition"`` and the
        entries produced by ``_get_tft_input_indices()``.

        Returns:
          A new dict. The dict returned by ``get_fixed_params()`` is copied
          first, so it is never modified by this call.

        Raises:
          ValueError: if a required fixed-parameter key is missing.
        """
        required_keys = [
            "total_time_steps",
            "num_encoder_steps",
            "num_epochs",
            "early_stopping_patience",
            "multiprocessing_workers",
        ]

        # Shallow copy: the additions below must not leak into a dict the
        # subclass may own (e.g. a cached or class-level parameter dict).
        fixed_params = dict(self.get_fixed_params())

        for k in required_keys:
            if k not in fixed_params:
                raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")

        fixed_params["column_definition"] = self.get_column_definition()
        fixed_params.update(self._get_tft_input_indices())

        return fixed_params