1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00
Files
qlib/examples/benchmarks/TFT/data_formatters/base.py
2020-11-25 20:36:28 +08:00

236 lines
7.3 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Default data formatting functions for experiments.
For new datasets, inherit from GenericDataFormatter and implement
all abstract functions.
These dataset-specific methods:
1) Define the column and input types for tabular dataframes used by model
2) Perform the necessary input feature engineering & normalisation steps
3) Reverts the normalisation for predictions
4) Are responsible for train, validation and test splits
"""
import abc
import enum
# Type definitions
class DataTypes(enum.IntEnum):
    """Enumerates the kind of values a column holds.

    Members are integers, so they compare equal to their raw values.
    """

    REAL_VALUED = 0
    CATEGORICAL = 1
    DATE = 2
class InputTypes(enum.IntEnum):
    """Enumerates the role a column plays as a model input.

    Members are integers, so they compare equal to their raw values.
    """

    TARGET = 0
    OBSERVED_INPUT = 1
    KNOWN_INPUT = 2
    STATIC_INPUT = 3
    # Single column used as an entity identifier.
    ID = 4
    # Single column exclusively used as a time index.
    TIME = 5
class GenericDataFormatter(abc.ABC):
    """Abstract base class for all data formatters.

    Subclasses implement the abstract methods below to perform
    dataset-specific manipulations. Column definitions handled by this class
    are sequences of tuples indexed as (name, data type, input type), where
    the data type is compared against `DataTypes` and the input type against
    `InputTypes`.
    """

    @abc.abstractmethod
    def set_scalers(self, df):
        """Calibrates scalers using the data supplied."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transform_inputs(self, df):
        """Performs feature transformation."""
        raise NotImplementedError()

    @abc.abstractmethod
    def format_predictions(self, df):
        """Reverts any normalisation to give predictions in original scale."""
        raise NotImplementedError()

    @abc.abstractmethod
    def split_data(self, df):
        """Performs the default train, validation and test splits."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def _column_definition(self):
        """Defines order, input type and data type of each column."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_fixed_params(self):
        """Defines the fixed parameters used by the model for training.

        Requires the following keys:
          'total_time_steps': Defines the total number of time steps used by TFT
          'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
          'num_epochs': Maximum number of epochs for training
          'early_stopping_patience': Early stopping param for keras
          'multiprocessing_workers': # of cpus for data processing

        Returns:
          A dictionary of fixed parameters, e.g.:
          fixed_params = {
              'total_time_steps': 252 + 5,
              'num_encoder_steps': 252,
              'num_epochs': 100,
              'early_stopping_patience': 5,
              'multiprocessing_workers': 5,
          }
        """
        raise NotImplementedError()

    # Shared functions across data-formatters
    @property
    def num_classes_per_cat_input(self):
        """Returns number of categories per relevant input.

        This is subsequently required for keras embedding layers.

        The value is read from `self._num_classes_per_cat_input`, which this
        base class never assigns; a subclass has to set it before this
        property is accessed, otherwise an AttributeError is raised.
        """
        return self._num_classes_per_cat_input

    def get_num_samples_for_calibration(self):
        """Gets the default number of training and validation samples.

        Use to sub-sample the data for network calibration and a value of -1
        uses all available samples.

        Returns:
          Tuple of (training samples, validation samples)
        """
        return -1, -1

    def get_column_definition(self):
        """Returns formatted column definition in order expected by the TFT.

        The result is ordered as: the ID column, the TIME column, all
        real-valued inputs, then all categorical inputs. Columns that are
        neither ID nor TIME and whose data type is neither REAL_VALUED nor
        CATEGORICAL (e.g. DATE) are not part of the result.

        Returns:
          List of column tuples taken from `self._column_definition`.

        Raises:
          ValueError: If there is not exactly one ID column or not exactly
            one TIME column.
        """
        column_definition = self._column_definition
        index_types = {InputTypes.ID, InputTypes.TIME}

        # Sanity checks first.
        # Ensure only one ID and time column exist
        def _check_single_column(input_type):
            length = sum(1 for tup in column_definition if tup[2] == input_type)
            if length != 1:
                raise ValueError(
                    'Illegal number of inputs ({}) of type {}'.format(
                        length, input_type))

        _check_single_column(InputTypes.ID)
        _check_single_column(InputTypes.TIME)

        identifier = [
            tup for tup in column_definition if tup[2] == InputTypes.ID
        ]
        time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
        real_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.REAL_VALUED and tup[2] not in index_types
        ]
        categorical_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.CATEGORICAL and tup[2] not in index_types
        ]

        return identifier + time + real_inputs + categorical_inputs

    def _get_input_columns(self):
        """Returns names of all input columns (everything but ID and TIME)."""
        index_types = {InputTypes.ID, InputTypes.TIME}
        return [
            tup[0]
            for tup in self.get_column_definition()
            if tup[2] not in index_types
        ]

    def _get_tft_input_indices(self):
        """Returns the relevant indexes and input sizes required by TFT.

        Positions under 'input_obs_loc' and 'static_input_loc' index into the
        formatted column definition with the ID and TIME columns removed.
        Positions under 'known_regular_inputs' and 'known_categorical_inputs'
        index into the real-valued-only and categorical-only sub-lists of
        that definition respectively.

        Returns:
          Dictionary with keys 'input_size', 'output_size', 'category_counts',
          'input_obs_loc', 'static_input_loc', 'known_regular_inputs' and
          'known_categorical_inputs'.
        """

        def _get_locations(input_types, defn):
            return [i for i, tup in enumerate(defn) if tup[2] in input_types]

        index_types = {InputTypes.ID, InputTypes.TIME}
        known_types = {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}

        # ID and TIME are dropped here, so the per-data-type lists below need
        # no further filtering on input type.
        column_definition = [
            tup for tup in self.get_column_definition()
            if tup[2] not in index_types
        ]
        categorical_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.CATEGORICAL
        ]
        real_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.REAL_VALUED
        ]
        target_locations = _get_locations({InputTypes.TARGET},
                                          column_definition)

        return {
            'input_size': len(self._get_input_columns()),
            'output_size': len(target_locations),
            'category_counts': self.num_classes_per_cat_input,
            'input_obs_loc': target_locations,
            'static_input_loc': _get_locations({InputTypes.STATIC_INPUT},
                                               column_definition),
            'known_regular_inputs': _get_locations(known_types, real_inputs),
            'known_categorical_inputs': _get_locations(known_types,
                                                       categorical_inputs),
        }

    def get_experiment_params(self):
        """Returns fixed model parameters for experiments.

        The parameters from `get_fixed_params` are extended with a
        'column_definition' entry and with every entry produced by
        `_get_tft_input_indices`.

        Returns:
          A new dictionary; the mapping returned by `get_fixed_params` is
          left unmodified (the copy is shallow).

        Raises:
          ValueError: If a required key is missing from the fixed parameters.
        """
        required_keys = [
            'total_time_steps', 'num_encoder_steps', 'num_epochs',
            'early_stopping_patience', 'multiprocessing_workers'
        ]

        # Work on a copy: a subclass may hand back a shared dict (for instance
        # a class-level constant), which must not pick up the extra entries
        # added below.
        fixed_params = dict(self.get_fixed_params())

        for k in required_keys:
            if k not in fixed_params:
                raise ValueError('Field {}'.format(k) +
                                 ' missing from fixed parameter definitions!')

        fixed_params['column_definition'] = self.get_column_definition()
        fixed_params.update(self._get_tft_input_indices())

        return fixed_params