mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Draft version of refactoring handler
This commit is contained in:
@@ -156,17 +156,17 @@ Data Handler
|
||||
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management <estimator.html>`_ for more details.
|
||||
|
||||
Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.data.dataset.handler.BaseDataHandler``, which provides some interfaces as follows.
|
||||
Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.data.dataset.handler.DataHandlerLP``, which provides some interfaces as follows.
|
||||
|
||||
Base Class & Interface
|
||||
----------------------
|
||||
|
||||
Qlib provides a base class `qlib.data.dataset.BaseDataHandler <../reference/api.html#qlib.data.dataset.handler.BaseDataHandler>`_, which provides the following interfaces:
|
||||
Qlib provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_, which provides the following interfaces:
|
||||
|
||||
- `setup_feature`
|
||||
- `load_feature`
|
||||
Implement the interface to load the data features.
|
||||
|
||||
- `setup_label`
|
||||
- `load_label`
|
||||
Implement the interface to load the data labels and calculate the users' labels.
|
||||
|
||||
- `setup_processed_data`
|
||||
@@ -174,11 +174,7 @@ Qlib provides a base class `qlib.data.dataset.BaseDataHandler <../reference/api.
|
||||
|
||||
Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
|
||||
|
||||
- `_init_kwargs`
|
||||
Users can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df.
|
||||
Kwargs are the other attributes in data.args, like dropna_label, dropna_feature
|
||||
|
||||
- `_init_raw_df`
|
||||
- `_init_raw_data`
|
||||
Users can init the raw df, feature names, and label names of data handler in this function.
|
||||
If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
|
||||
|
||||
|
||||
@@ -284,7 +284,7 @@ To know more about ``Interday Model``, please refer to `Interday Model: Training
|
||||
Data Section
|
||||
-----------------
|
||||
|
||||
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data (standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.data.dataset.handler.BaseDataHandler`.
|
||||
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data (standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.data.dataset.handler.DataHandlerLP`.
|
||||
|
||||
Users can use the specified data handler by config as follows.
|
||||
|
||||
@@ -315,7 +315,7 @@ Users can use the specified data handler by config as follows.
|
||||
fend_time: 2018-12-11
|
||||
|
||||
- `class`
|
||||
Data handler class, str type, which should be a subclass of `qlib.data.dataset.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in ``Qlib``, `QlibDataHandler` is suggested.
|
||||
Data handler class, str type, which should be a subclass of `qlib.data.dataset.handler.DataHandlerLP`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in ``Qlib``, `QlibDataHandler` is suggested.
|
||||
|
||||
- `module_path`
|
||||
The module path, str type, absolute url is also supported, indicates the path of the `class` implementation of the data processor class. The default value is `qlib.data.dataset.handler`.
|
||||
@@ -363,7 +363,7 @@ Users can use the specified data handler by config as follows.
|
||||
Custom Data Handler
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Qlib support custom data handler, but it must be a subclass of the ``qlib.contrib.estimator.handler.BaseDataHandler``, the config for custom data handler may be as follows.
|
||||
Qlib support custom data handler, but it must be a subclass of the ``qlib.data.dataset.handler.DataHandlerLP``, the config for custom data handler may be as follows.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
|
||||
131
examples/workflow_by_code.py
Normal file
131
examples/workflow_by_code.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(target_dir=provider_uri)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "CSI300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_date": "2008-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"fit_start_time":"2008-01-01",
|
||||
"fit_end_time":"2014-12-31",
|
||||
"market": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2008-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
}
|
||||
|
||||
# use default DataHandler
|
||||
# custom DataHandler, refer to: TODO: DataHandler API url
|
||||
handler = Alpha158(**DATA_HANDLER_CONFIG)
|
||||
|
||||
data = handler.fetch(slice('2008-01-01', '2014-12-31'), key=handler.DK_I)
|
||||
print(data)
|
||||
|
||||
sys.exit(0) # I have tested the code above ---------------------------------------------
|
||||
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
# use default model
|
||||
# custom Model, refer to: TODO: Model API url
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
|
||||
|
||||
# backtest requires pred_score
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
0
qlib/contrib/data/__init__.py
Normal file
0
qlib/contrib/data/__init__.py
Normal file
@@ -2,7 +2,9 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from ...data.dataset.handler import ConfigQLibDataHandler
|
||||
from ...data.dataset.processor import Processor, MinMaxNorm, ZscoreNorm, get_cls_kwargs
|
||||
from ...log import TimeInspector
|
||||
import copy
|
||||
|
||||
|
||||
class ALPHA360(ConfigQLibDataHandler):
|
||||
@@ -22,19 +24,36 @@ class QLibDataHandlerV1(ConfigQLibDataHandler):
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def __init__(self, start_date, end_date, processors=None, **kwargs):
|
||||
if processors is None:
|
||||
processors = ["PanelProcessor"] # V1 default processor
|
||||
super().__init__(start_date, end_date, processors, **kwargs)
|
||||
def __init__(self, start_date, end_date, infer_processors=[], learn_processors=["DropnaLabel"], fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p)
|
||||
if isinstance(klass, (MinMaxNorm, ZscoreNorm)):
|
||||
assert(fit_start_time is not None and fit_end_time is not None)
|
||||
pkwargs.update({
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
})
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
def setup_label(self):
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
super().__init__(start_date, end_date, infer_processors=infer_processors, learn_processors=learn_processors, **kwargs)
|
||||
|
||||
def load_label(self):
|
||||
"""
|
||||
load the labels df
|
||||
:return: df_labels
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
df_labels = super().setup_label()
|
||||
df_labels = super().load_label()
|
||||
|
||||
## calculate new labels
|
||||
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std())
|
||||
@@ -56,8 +75,6 @@ class Alpha158(QLibDataHandlerV1):
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
super(Alpha158, self)._init_kwargs(**kwargs)
|
||||
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -1,249 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...log import TimeInspector
|
||||
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
class Processor(abc.ABC):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
self.feature_names = feature_names
|
||||
self.label_names = label_names
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
pass
|
||||
|
||||
|
||||
class PanelProcessor(Processor):
|
||||
"""Panel Preprocessor"""
|
||||
|
||||
STD_NORM = "Std"
|
||||
MINMAX_NORM = "MinMax"
|
||||
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
# Options.
|
||||
self.dropna_label = kwargs.get("dropna_label", True)
|
||||
self.dropna_feature = kwargs.get("dropna_feature", False)
|
||||
self.normalize_method = kwargs.get("normalize_method", None)
|
||||
self.replace_inf = kwargs.get("replace_inf_feature", False)
|
||||
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Preprocess the data
|
||||
:param df: the dataframe to process data.
|
||||
"""
|
||||
# Drop null labels.
|
||||
if self.dropna_label:
|
||||
df_train, df_valid, df_test = self._process_drop_null_label(df_train, df_valid, df_test)
|
||||
|
||||
# Dropna if need.
|
||||
if self.dropna_feature:
|
||||
df_train, df_valid, df_test = self._process_drop_null_feature(df_train, df_valid, df_test)
|
||||
|
||||
# replace the 'inf' with the mean the corresponding dimension
|
||||
if self.replace_inf:
|
||||
df_train, df_valid, df_test = self._process_replace_inf_feature(df_train, df_valid, df_test)
|
||||
|
||||
# normalize data in given method.
|
||||
if self.normalize_method is not None:
|
||||
df_train, df_valid, df_test = self._process_normalize_feature(df_train, df_valid, df_test)
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_label(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Drop null labels.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.label_names)
|
||||
df_valid = df_valid.dropna(subset=self.label_names)
|
||||
# The test data's label is Unkown. They can not be seen when preprocessing
|
||||
TimeInspector.log_cost_time("Finished dropping null labels.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Drop data which contain null features if needed.
|
||||
"""
|
||||
# TODO - `Pandas.dropna` is a low performance method.
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.feature_names)
|
||||
df_valid = df_valid.dropna(subset=self.feature_names)
|
||||
df_test = df_test.dropna(subset=self.feature_names)
|
||||
TimeInspector.log_cost_time("Finished dropping nan.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_replace_inf_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
replace the 'inf' in feature with the mean of this dimension.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
def replace_inf(data):
|
||||
def process_inf(df):
|
||||
for col in df.columns:
|
||||
df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
|
||||
return df
|
||||
|
||||
data = data.groupby("datetime").apply(process_inf)
|
||||
data.sort_index(inplace=True)
|
||||
return data
|
||||
|
||||
df_train = replace_inf(df_train)
|
||||
df_valid = replace_inf(df_valid)
|
||||
df_test = replace_inf(df_test)
|
||||
TimeInspector.log_cost_time("Finished replace inf.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_normalize_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Normalize data if needed, we provide two method now: min-max normalization and standard normalization.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if self.normalize_method == self.MINMAX_NORM:
|
||||
min_train = np.nanmin(df_train[self.feature_names].values, axis=0)
|
||||
max_train = np.nanmax(df_train[self.feature_names].values, axis=0)
|
||||
ignore = min_train == max_train
|
||||
|
||||
def normalize(x, min_train=min_train, max_train=max_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_train) / (max_train - min_train)
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_train) / (max_train - min_train)
|
||||
return x
|
||||
|
||||
elif self.normalize_method == self.STD_NORM:
|
||||
mean_train = np.nanmean(df_train[self.feature_names].values, axis=0)
|
||||
std_train = np.nanstd(df_train[self.feature_names].values, axis=0)
|
||||
ignore = std_train == 0
|
||||
|
||||
def normalize(x, mean_train=mean_train, std_train=std_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
|
||||
else:
|
||||
raise ValueError("Normalize method {} is not allowed".format(self.normalize_method))
|
||||
|
||||
df_train.loc(axis=1)[self.feature_names] = normalize(df_train[self.feature_names].values)
|
||||
df_valid.loc(axis=1)[self.feature_names] = normalize(df_valid[self.feature_names].values)
|
||||
df_test.loc(axis=1)[self.feature_names] = normalize(df_test[self.feature_names].values)
|
||||
|
||||
TimeInspector.log_cost_time("Finished normalizing data.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
|
||||
class ConfigSectionProcessor(Processor):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
def __call__(self, *args):
|
||||
return [self._transform(x) for x in args]
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
x = x - x.mean() # copy
|
||||
x /= x.std()
|
||||
if self.clip_label_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.fillna_label:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
def _feature_norm(x):
|
||||
x = x - x.median() # copy
|
||||
x /= x.abs().median() * 1.4826
|
||||
if self.clip_feature_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.shrink_feature_outlier:
|
||||
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
|
||||
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
|
||||
if self.fillna_feature:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy
|
||||
df_new = df.copy()
|
||||
|
||||
# Label
|
||||
cols = df.columns[df.columns.str.contains("^LABEL")]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
"KSFT",
|
||||
"OPEN",
|
||||
"HIGH",
|
||||
"LOW",
|
||||
"CLOSE",
|
||||
"VWAP",
|
||||
"ROC",
|
||||
"MA",
|
||||
"BETA",
|
||||
"RESI",
|
||||
"QTLU",
|
||||
"QTLD",
|
||||
"RSV",
|
||||
"SUMP",
|
||||
"SUMN",
|
||||
"SUMD",
|
||||
"VSUMP",
|
||||
"VSUMN",
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^RSQR")]
|
||||
df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
|
||||
df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^WVMA")]
|
||||
df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df_new
|
||||
@@ -10,14 +10,14 @@ import numpy as np
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from .handler import BaseDataHandler
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .launcher import CONFIG_MANAGER
|
||||
from .fetcher import create_fetcher_with_config
|
||||
from ...utils import drop_nan_by_y_index, transform_end_date
|
||||
|
||||
|
||||
class BaseTrainer(object):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler: DataHandlerLP, sacred_ex, **kwargs):
|
||||
# 1. Model.
|
||||
self.model_class = model_class
|
||||
self.model_save_path = model_save_path
|
||||
|
||||
@@ -11,6 +11,7 @@ from ...utils import get_pre_trading_date
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
# TODO: The base strategies will be moved out of contrib to core code
|
||||
class BaseStrategy:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -5,270 +5,342 @@
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date
|
||||
from ...utils.serial import Serializable
|
||||
from pathlib import Path
|
||||
|
||||
from . import processor as processor_module
|
||||
|
||||
|
||||
class BaseDataHandler(abc.ABC):
|
||||
def __init__(self, processors=[], **kwargs):
|
||||
"""
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param kwargs:
|
||||
"""
|
||||
# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed.
|
||||
class DataHandler(Serializable):
|
||||
'''
|
||||
The steps to using a handler
|
||||
1. initialized data handler (call by `init`).
|
||||
2. use the data
|
||||
|
||||
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325
|
||||
SH600006 22.672380 7095624.0 22.508326 22.573947 0.557785
|
||||
|
||||
'''
|
||||
def __init__(self, init_data=True):
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# init data using kwargs
|
||||
self._init_kwargs(**kwargs)
|
||||
|
||||
# Setup data.
|
||||
self.raw_df, self.feature_names, self.label_names = self._init_raw_df()
|
||||
self._data = {}
|
||||
if init_data:
|
||||
self.init()
|
||||
super().__init__()
|
||||
|
||||
# Setup preprocessor
|
||||
self.processors = []
|
||||
for klass in processors:
|
||||
if isinstance(klass, str):
|
||||
try:
|
||||
klass = getattr(processor_module, klass)
|
||||
except:
|
||||
raise ValueError("unknown Processor %s" % klass)
|
||||
self.processors.append(klass(self.feature_names, self.label_names, **kwargs))
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
def init(self, force_reload: bool=True):
|
||||
"""
|
||||
init the kwargs of DataHandler
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
force_reload : bool
|
||||
force to reload the data even if the data have been initialized
|
||||
"""
|
||||
pass
|
||||
# if force_reload or hasattr(self, '_initialized', False):
|
||||
|
||||
def _init_raw_df(self):
|
||||
def get_level_index(self, df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
"""
|
||||
|
||||
get the level index of `df` given `level`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
data
|
||||
level : Union[str, int]
|
||||
index level
|
||||
|
||||
Returns
|
||||
-------
|
||||
int:
|
||||
The level index in the multiple index
|
||||
"""
|
||||
if isinstance(level, str):
|
||||
try:
|
||||
return df.index.names.index(level)
|
||||
except (AttributeError, ValueError):
|
||||
# NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
|
||||
return ('datetime', 'instrument').index(level)
|
||||
elif isinstance(level, int):
|
||||
return level
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _fetch_df(self, df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int]):
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
the data frame to be selected
|
||||
selector : Union[pd.Timestamp, slice, str, list]
|
||||
selector
|
||||
level : Union[pd.Timestamp, slice, str]
|
||||
the level to use the selector
|
||||
"""
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if self.get_level_index(df, level) == 1:
|
||||
idx_slc = idx_slc[1], idx_slc[0]
|
||||
return df.loc(axis=0)[idx_slc]
|
||||
|
||||
def fetch(self, selector: Union[pd.Timestamp, slice, str], level='datetime', key=None) -> Union[pd.DataFrame, dict]:
|
||||
if key is None:
|
||||
res = {}
|
||||
for k, df in self._data.items():
|
||||
res[k] = self._fetch_df(df, selector, level)
|
||||
else:
|
||||
res = self._fetch_df(self._data[key], selector, level)
|
||||
return res
|
||||
|
||||
|
||||
class DataHandlerLP(DataHandler):
|
||||
'''
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
'''
|
||||
# data key
|
||||
DK_R = 'raw'
|
||||
DK_I = 'infer'
|
||||
DK_L = 'learn'
|
||||
|
||||
# process type
|
||||
PTYPE_I = 'independent'
|
||||
# - _proc_infer_df will processed by infer_processors
|
||||
# - _proc_learn_df will be processed by learn_processors
|
||||
PTYPE_A = 'append'
|
||||
# - _proc_infer_df will processed by infer_processors
|
||||
# - _proc_learn_df will be processed by infer_processors + learn_processors
|
||||
# - (e.g. _proc_infer_df processed by learn_processors )
|
||||
|
||||
def __init__(self, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
infer_processors : list
|
||||
list of <description info> of processors to generate data for inference
|
||||
example of <description info>:
|
||||
1) classname & kwargs:
|
||||
{
|
||||
"class": "MinMaxNorm",
|
||||
"kwargs": {
|
||||
"fit_start_time": "20080101",
|
||||
"fit_end_time": "20121231"
|
||||
}
|
||||
}
|
||||
2) Only classname:
|
||||
"DropnaFeature"
|
||||
3) object instance of Processor
|
||||
|
||||
learn_processors : list
|
||||
similar to infer_processors, but for generating data for learning models
|
||||
|
||||
process_type: str
|
||||
PTYPE_I = 'independent'
|
||||
- _proc_infer_df will processed by infer_processors
|
||||
- _proc_learn_df will be processed by learn_processors
|
||||
PTYPE_A = 'append'
|
||||
- _proc_infer_df will processed by infer_processors
|
||||
- _proc_learn_df will be processed by infer_processors + learn_processors
|
||||
- (e.g. _proc_infer_df processed by learn_processors )
|
||||
"""
|
||||
|
||||
# Setup preprocessor
|
||||
self.infer_processors = [] # for lint
|
||||
self.learn_processors = [] # for lint
|
||||
for pname in 'infer_processors', 'learn_processors':
|
||||
for proc in locals()[pname]:
|
||||
getattr(self, pname).append(processor_module.init_proc_obj(proc))
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_all_processors(self):
|
||||
return self.infer_processors + self.learn_processors
|
||||
|
||||
def _init_raw_data(self):
|
||||
"""
|
||||
initialize the raw data
|
||||
the raw data will be saved in to `self._data['raw']`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_init_raw_data` method")
|
||||
|
||||
def fit(self):
|
||||
for proc in self.get_all_processors():
|
||||
proc.fit(self)
|
||||
|
||||
def fit_process_data(self):
|
||||
"""
|
||||
fit and process data
|
||||
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
self.process_data(with_fit=True)
|
||||
|
||||
|
||||
def process_data(self, with_fit: bool=False):
|
||||
"""
|
||||
process_data data. Fun `processor.fit` if necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
with_fit : bool
|
||||
The input of the `fit` will be the output of the previous processor
|
||||
"""
|
||||
# data for inference
|
||||
_infer_df = self._data[DataHandlerLP.DK_R]
|
||||
for proc in self.infer_processors:
|
||||
if not proc.is_for_infer():
|
||||
raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
|
||||
if with_fit:
|
||||
proc.fit(self, _infer_df)
|
||||
_infer_df = proc(_infer_df)
|
||||
|
||||
# data for learning
|
||||
if self.process_type == DataHandlerLP.PTYPE_I:
|
||||
_learn_df = self._data[DataHandlerLP.DK_R]
|
||||
elif self.process_type == DataHandlerLP.PTYPE_A:
|
||||
# based on `infer_df` and append the processor
|
||||
_learn_df = _infer_df
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
for proc in self.learn_processors:
|
||||
if with_fit:
|
||||
proc.fit(self, _learn_df)
|
||||
_learn_df = proc(_learn_df)
|
||||
|
||||
self._data.update({
|
||||
DataHandlerLP.DK_I: _infer_df,
|
||||
DataHandlerLP.DK_L: _learn_df,
|
||||
})
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = 'fit_seq' # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = 'fit_ind' # the input of `fit` will be the original df
|
||||
IT_LS = 'load_state' # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str=IT_FIT_SEQ, path: Path=None):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
|
||||
Parameters
|
||||
----------
|
||||
init_type : str
|
||||
'fit' or 'load_state'
|
||||
path : path
|
||||
if `init_type` == 'load_state': `path` will be used to load_state
|
||||
"""
|
||||
self._init_raw_data()
|
||||
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
self.fit()
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_LS:
|
||||
self.process_data()
|
||||
elif init_type == DataHandlerLP.IT_FIT_SEQ:
|
||||
self.fit_process_data()
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
|
||||
class DataHandlerLPWL(DataHandlerLP):
|
||||
'''
|
||||
DataHandler with (L)earnable (P)rocessor with (L)abel
|
||||
'''
|
||||
|
||||
def _init_raw_data(self):
|
||||
"""
|
||||
init raw_df, feature_names, label_names of DataHandler
|
||||
if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge).
|
||||
|
||||
"""
|
||||
df_features = self.setup_feature()
|
||||
df_features = self.load_feature()
|
||||
feature_names = df_features.columns
|
||||
|
||||
df_labels = self.setup_label()
|
||||
df_labels = self.load_label()
|
||||
label_names = df_labels.columns
|
||||
|
||||
raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
|
||||
self.feature_names = feature_names
|
||||
self.label_names = label_names
|
||||
self._data['raw'] = raw_df
|
||||
|
||||
return raw_df, feature_names, label_names
|
||||
|
||||
def reset_label(self, df_labels):
|
||||
for col in self.label_names:
|
||||
del self.raw_df[col]
|
||||
self.label_names = df_labels.columns
|
||||
self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left")
|
||||
|
||||
def split_rolling_periods(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
"""
|
||||
Calculating the Rolling split periods, the period rolling on market calendar.
|
||||
:param train_start_date:
|
||||
:param train_end_date:
|
||||
:param validate_start_date:
|
||||
:param validate_end_date:
|
||||
:param test_start_date:
|
||||
:param test_end_date:
|
||||
:param rolling_period: The market period of rolling
|
||||
:param calendar_freq: The frequence of the market calendar
|
||||
:yield: Rolling split periods
|
||||
"""
|
||||
|
||||
def get_start_index(calendar, start_date):
|
||||
start_index = bisect.bisect_left(calendar, start_date)
|
||||
return start_index
|
||||
|
||||
def get_end_index(calendar, end_date):
|
||||
end_index = bisect.bisect_right(calendar, end_date)
|
||||
return end_index - 1
|
||||
|
||||
calendar = self.raw_df.index.get_level_values("datetime").unique()
|
||||
|
||||
train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
|
||||
train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
|
||||
valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
|
||||
valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
|
||||
test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
|
||||
test_end_index = test_start_index + rolling_period - 1
|
||||
|
||||
need_stop_split = False
|
||||
|
||||
bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
|
||||
|
||||
while not need_stop_split:
|
||||
|
||||
if test_end_index > bound_test_end_index:
|
||||
test_end_index = bound_test_end_index
|
||||
need_stop_split = True
|
||||
|
||||
yield (
|
||||
calendar[train_start_index],
|
||||
calendar[train_end_index],
|
||||
calendar[valid_start_index],
|
||||
calendar[valid_end_index],
|
||||
calendar[test_start_index],
|
||||
calendar[test_end_index],
|
||||
)
|
||||
|
||||
train_start_index += rolling_period
|
||||
train_end_index += rolling_period
|
||||
valid_start_index += rolling_period
|
||||
valid_end_index += rolling_period
|
||||
test_start_index += rolling_period
|
||||
test_end_index += rolling_period
|
||||
|
||||
def get_rolling_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
# Set generator.
|
||||
for period in self.split_rolling_periods(
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq,
|
||||
):
|
||||
(
|
||||
x_train,
|
||||
y_train,
|
||||
x_validate,
|
||||
y_validate,
|
||||
x_test,
|
||||
y_test,
|
||||
) = self.get_split_data(*period)
|
||||
yield x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
def get_split_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
):
|
||||
"""
|
||||
all return types are DataFrame
|
||||
"""
|
||||
## TODO: loc can be slow, expecially when we put it at the second level index.
|
||||
if self.raw_df.index.names[0] == "instrument":
|
||||
df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
|
||||
else:
|
||||
df_train = self.raw_df.loc[train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc[test_start_date:test_end_date]
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test)
|
||||
TimeInspector.log_cost_time("Finished setup processed data.")
|
||||
|
||||
x_train = df_train[self.feature_names]
|
||||
y_train = df_train[self.label_names]
|
||||
|
||||
x_validate = df_validate[self.feature_names]
|
||||
y_validate = df_validate[self.label_names]
|
||||
|
||||
x_test = df_test[self.feature_names]
|
||||
y_test = df_test[self.label_names]
|
||||
|
||||
return x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
def setup_process_data(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
process the train, valid and test data
|
||||
:return: the processed train, valid and test data.
|
||||
"""
|
||||
for processor in self.processors:
|
||||
df_train, df_valid, df_test = processor(df_train, df_valid, df_test)
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"):
|
||||
"""Get origin test label
|
||||
|
||||
:param test_start_date: test start date
|
||||
:param test_end_date: test end date
|
||||
:param freq: freq
|
||||
:return: pd.DataFrame
|
||||
"""
|
||||
test_end_date = transform_end_date(test_end_date, freq=freq)
|
||||
return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names]
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_feature(self):
|
||||
def load_feature(self):
|
||||
"""
|
||||
Implement this method to load raw feature.
|
||||
the format of the feature is below
|
||||
return: df_features
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError(f"Please implement `load_feature`")
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_label(self):
|
||||
def load_label(self):
|
||||
"""
|
||||
Implement this method to load and calculate label.
|
||||
the format of the label is below
|
||||
|
||||
return: df_label
|
||||
"""
|
||||
pass
|
||||
raise NotImplementedError(f"Please implement `load_label`")
|
||||
|
||||
def get_feature_names(self):
|
||||
return self.feature_names
|
||||
|
||||
def get_label_names(self):
|
||||
return self.label_names
|
||||
|
||||
|
||||
class QLibDataHandler(BaseDataHandler):
|
||||
class QLibDataHandler(DataHandlerLPWL):
|
||||
def __init__(self, start_date, end_date, *args, **kwargs):
|
||||
# Dates.
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
|
||||
# Instruments
|
||||
instruments = kwargs.get("instruments", None)
|
||||
instruments = kwargs.pop("instruments", None)
|
||||
if instruments is None:
|
||||
market = kwargs.get("market", "csi500").lower()
|
||||
data_filter_list = kwargs.get("data_filter_list", list())
|
||||
market = kwargs.pop("market", "csi500").lower()
|
||||
data_filter_list = kwargs.pop("data_filter_list", list())
|
||||
self.instruments = D.instruments(market, filter_pipe=data_filter_list)
|
||||
else:
|
||||
self.instruments = instruments
|
||||
|
||||
# Config of features and labels
|
||||
self._fields = kwargs.get("fields", [])
|
||||
self._names = kwargs.get("names", [])
|
||||
self._labels = kwargs.get("labels", [])
|
||||
self._label_names = kwargs.get("label_names", [])
|
||||
self._fields = kwargs.pop("fields", [])
|
||||
self._names = kwargs.pop("names", [])
|
||||
self._labels = kwargs.pop("labels", [])
|
||||
self._label_names = kwargs.pop("label_names", [])
|
||||
|
||||
# Check arguments
|
||||
assert len(self._fields) > 0, "features list is empty"
|
||||
@@ -278,7 +350,9 @@ class QLibDataHandler(BaseDataHandler):
|
||||
# If test_end_date is -1 or greater than the last date, the last date is used
|
||||
self.end_date = transform_end_date(self.end_date)
|
||||
|
||||
def setup_feature(self):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def load_feature(self):
|
||||
"""
|
||||
Load the raw data.
|
||||
return: df_features
|
||||
@@ -297,7 +371,7 @@ class QLibDataHandler(BaseDataHandler):
|
||||
|
||||
return df_features
|
||||
|
||||
def setup_label(self):
|
||||
def load_label(self):
|
||||
"""
|
||||
Build up labels in df through users' method
|
||||
:return: df_labels
|
||||
@@ -498,12 +572,7 @@ def parse_config_to_fields(config):
|
||||
class ConfigQLibDataHandler(QLibDataHandler):
|
||||
config_template = {} # template
|
||||
|
||||
def __init__(self, start_date, end_date, processors=None, **kwargs):
|
||||
if processors is None:
|
||||
processors = ["ConfigSectionProcessor"] # default processor
|
||||
super().__init__(start_date, end_date, processors, **kwargs)
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
def __init__(self, start_date, end_date, infer_processors=["ConfigSectionProcessor"], learn_processors=[], **kwargs):
|
||||
config = self.config_template.copy()
|
||||
if "config_update" in kwargs:
|
||||
config.update(kwargs["config_update"])
|
||||
@@ -512,4 +581,5 @@ class ConfigQLibDataHandler(QLibDataHandler):
|
||||
kwargs["names"] = names
|
||||
if "labels" not in kwargs:
|
||||
kwargs["labels"] = ["Ref($vwap, -2)/Ref($vwap, -1) - 1"]
|
||||
super()._init_kwargs(**kwargs)
|
||||
|
||||
super().__init__(start_date, end_date, infer_processors=infer_processors, learn_processors=learn_processors, **kwargs)
|
||||
|
||||
@@ -4,154 +4,209 @@
|
||||
import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
class Processor(abc.ABC):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
self.feature_names = feature_names
|
||||
self.label_names = label_names
|
||||
class Processor(Serializable):
|
||||
|
||||
def fit(self, handler, df: pd.DataFrame=None):
|
||||
"""
|
||||
learn data processing parameters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler : DataHandlerLP
|
||||
The data handler to processing data
|
||||
df : pd.DataFrame
|
||||
When we fit and process data with processor one by one. The fit function reiles on the output of previous
|
||||
processor, i.e. `df`.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
def __call__(self, df: pd.DataFrame):
|
||||
"""
|
||||
process the data
|
||||
|
||||
NOTE: The processor should not change the content of `df`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The raw_df of handler or result from previous processor
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PanelProcessor(Processor):
|
||||
"""Panel Preprocessor"""
|
||||
def get_cls_kwargs(processor: [dict, str]) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from processor info
|
||||
|
||||
STD_NORM = "Std"
|
||||
MINMAX_NORM = "MinMax"
|
||||
Parameters
|
||||
----------
|
||||
processor : [dict, str]
|
||||
similar to processor
|
||||
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
# Options.
|
||||
self.dropna_label = kwargs.get("dropna_label", True)
|
||||
self.dropna_feature = kwargs.get("dropna_feature", False)
|
||||
self.normalize_method = kwargs.get("normalize_method", None)
|
||||
self.replace_inf = kwargs.get("replace_inf_feature", False)
|
||||
Returns
|
||||
-------
|
||||
(type, dict):
|
||||
the class object and it's arguments.
|
||||
"""
|
||||
if isinstance(processor, dict):
|
||||
# raise AttributeError
|
||||
klass = globals()[processor['class']]
|
||||
kwargs = processor['kwargs']
|
||||
elif isinstance(processor, str):
|
||||
klass = globals()[processor]
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return klass, kwargs
|
||||
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
|
||||
# Place the function here to be able to reference the Processor
|
||||
def init_proc_obj(processor: [dict, str, Processor]) -> Processor:
|
||||
"""
|
||||
Initialize Processor Object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
processor : [dict, str, Processor]
|
||||
The info to initialize processor
|
||||
|
||||
Returns
|
||||
-------
|
||||
Processor:
|
||||
initialized Processor
|
||||
"""
|
||||
if not isinstance(processor, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(processor)
|
||||
processor = klass(**pkwargs)
|
||||
return processor
|
||||
|
||||
|
||||
class InferProcessor(Processor):
|
||||
'''This processor is usable for inference'''
|
||||
def is_for_infer(self) -> bool:
|
||||
"""
|
||||
Preprocess the data
|
||||
:param df: the dataframe to process data.
|
||||
Is this processor usable for inference
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if it is usable for infenrece
|
||||
"""
|
||||
# Drop null labels.
|
||||
if self.dropna_label:
|
||||
df_train, df_valid, df_test = self._process_drop_null_label(df_train, df_valid, df_test)
|
||||
return True
|
||||
|
||||
# Dropna if need.
|
||||
if self.dropna_feature:
|
||||
df_train, df_valid, df_test = self._process_drop_null_feature(df_train, df_valid, df_test)
|
||||
|
||||
# replace the 'inf' with the mean the corresponding dimension
|
||||
if self.replace_inf:
|
||||
df_train, df_valid, df_test = self._process_replace_inf_feature(df_train, df_valid, df_test)
|
||||
|
||||
# normalize data in given method.
|
||||
if self.normalize_method is not None:
|
||||
df_train, df_valid, df_test = self._process_normalize_feature(df_train, df_valid, df_test)
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_label(self, df_train, df_valid, df_test):
|
||||
class NInferProcessor(Processor):
|
||||
'''This processor is not usable for inference'''
|
||||
def is_for_infer(self) -> bool:
|
||||
"""
|
||||
Drop null labels.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.label_names)
|
||||
df_valid = df_valid.dropna(subset=self.label_names)
|
||||
# The test data's label is Unkown. They can not be seen when preprocessing
|
||||
TimeInspector.log_cost_time("Finished dropping null labels.")
|
||||
Is this processor usable for inference
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_feature(self, df_train, df_valid, df_test):
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if it is usable for infenrece
|
||||
"""
|
||||
Drop data which contain null features if needed.
|
||||
"""
|
||||
# TODO - `Pandas.dropna` is a low performance method.
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.feature_names)
|
||||
df_valid = df_valid.dropna(subset=self.feature_names)
|
||||
df_test = df_test.dropna(subset=self.feature_names)
|
||||
TimeInspector.log_cost_time("Finished dropping nan.")
|
||||
return False
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_replace_inf_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
replace the 'inf' in feature with the mean of this dimension.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
class DropnaFeature(InferProcessor):
|
||||
def fit(self, handler, df=None):
|
||||
self.feature_names = copy.deepcopy(handler.get_feature_names())
|
||||
|
||||
def __call__(self, df):
|
||||
return df.dropna(subset=self.feature_names)
|
||||
|
||||
|
||||
class DropnaLabel(InferProcessor):
|
||||
def fit(self, handler, df=None):
|
||||
self.label_names = copy.deepcopy(handler.get_label_names())
|
||||
|
||||
def __call__(self, df):
|
||||
return df.dropna(subset=self.label_names)
|
||||
|
||||
|
||||
class ProcessInf(InferProcessor):
|
||||
'''Process infinity '''
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
def process_inf(df):
|
||||
for col in df.columns:
|
||||
# FIXME: Such behavior is very weird
|
||||
df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
|
||||
return df
|
||||
|
||||
data = data.groupby("datetime").apply(process_inf)
|
||||
data.sort_index(inplace=True)
|
||||
return data
|
||||
|
||||
df_train = replace_inf(df_train)
|
||||
df_valid = replace_inf(df_valid)
|
||||
df_test = replace_inf(df_test)
|
||||
TimeInspector.log_cost_time("Finished replace inf.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_normalize_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Normalize data if needed, we provide two method now: min-max normalization and standard normalization.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if self.normalize_method == self.MINMAX_NORM:
|
||||
min_train = np.nanmin(df_train[self.feature_names].values, axis=0)
|
||||
max_train = np.nanmax(df_train[self.feature_names].values, axis=0)
|
||||
ignore = min_train == max_train
|
||||
|
||||
def normalize(x, min_train=min_train, max_train=max_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_train) / (max_train - min_train)
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_train) / (max_train - min_train)
|
||||
return x
|
||||
|
||||
elif self.normalize_method == self.STD_NORM:
|
||||
mean_train = np.nanmean(df_train[self.feature_names].values, axis=0)
|
||||
std_train = np.nanstd(df_train[self.feature_names].values, axis=0)
|
||||
ignore = std_train == 0
|
||||
|
||||
def normalize(x, mean_train=mean_train, std_train=std_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
|
||||
else:
|
||||
raise ValueError("Normalize method {} is not allowed".format(self.normalize_method))
|
||||
|
||||
df_train.loc(axis=1)[self.feature_names] = normalize(df_train[self.feature_names].values)
|
||||
df_valid.loc(axis=1)[self.feature_names] = normalize(df_valid[self.feature_names].values)
|
||||
df_test.loc(axis=1)[self.feature_names] = normalize(df_test[self.feature_names].values)
|
||||
|
||||
TimeInspector.log_cost_time("Finished normalizing data.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
return replace_inf(df)
|
||||
|
||||
|
||||
class ConfigSectionProcessor(Processor):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
class MinMaxNorm(InferProcessor):
|
||||
def __init__(self, fit_start_time, fit_end_time):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
|
||||
def fit(self, handler, df):
|
||||
# TODO: 看看这里怎么取数据
|
||||
self.min_val = np.nanmin(df[handler.get_feature_names()].values, axis=0)
|
||||
self.max_val = np.nanmax(df[handler.get_feature_names()].values, axis=0)
|
||||
self.ignore = self.min_val == self.max_val
|
||||
self.feature_names = copy.deepcopy(handler.get_feature_names())
|
||||
|
||||
def __call__(self, df):
|
||||
# FIXME: The df will be changed inplace. It's very dangerous
|
||||
# The code below is ugly
|
||||
df = df.copy() # currently copy is used
|
||||
def normalize(x, min_val=self.min_val, max_val=self.max_val, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - min_val) / (max_val - min_val)
|
||||
return x
|
||||
df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values)
|
||||
return df
|
||||
|
||||
|
||||
class ZscoreNorm(InferProcessor):
|
||||
def __init__(self, fit_start_time, fit_end_time):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
|
||||
def fit(self, handler, df):
|
||||
self.mean_train = np.nanmean(df[handler.get_feature_names()].values, axis=0)
|
||||
self.std_train = np.nanstd(df[handler.get_feature_names()].values, axis=0)
|
||||
self.ignore = self.std_train == 0
|
||||
self.feature_names = handler.get_feature_names()
|
||||
|
||||
def __call__(self, df):
|
||||
# FIXME: The df will be changed inplace. It's very dangerous
|
||||
# The code below is ugly
|
||||
df = df.copy() # currently copy is used
|
||||
def normalize(x, mean_train=self.mean_train, std_train=self.std_train, ignore=self.ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[i] = (x[i] - mean_train) / std_train
|
||||
return x
|
||||
df.loc(axis=1)[self.feature_names] = normalize(df[self.feature_names].values)
|
||||
return df
|
||||
|
||||
|
||||
class ConfigSectionProcessor(InferProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
@@ -159,8 +214,12 @@ class ConfigSectionProcessor(Processor):
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
def __call__(self, *args):
|
||||
return [self._transform(x) for x in args]
|
||||
def fit(self, handler, df=None):
|
||||
self.feature_names = handler.get_feature_names()
|
||||
self.label_names = handler.get_label_names()
|
||||
|
||||
def __call__(self, df):
|
||||
return self._transform(df)
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
|
||||
142
qlib/model/task.py
Normal file
142
qlib/model/task.py
Normal file
@@ -0,0 +1,142 @@
|
||||
'''
|
||||
Please implement similar function here
|
||||
|
||||
# Rolling relealted
|
||||
|
||||
def split_rolling_periods(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
"""
|
||||
Calculating the Rolling split periods, the period rolling on market calendar.
|
||||
:param train_start_date:
|
||||
:param train_end_date:
|
||||
:param validate_start_date:
|
||||
:param validate_end_date:
|
||||
:param test_start_date:
|
||||
:param test_end_date:
|
||||
:param rolling_period: The market period of rolling
|
||||
:param calendar_freq: The frequence of the market calendar
|
||||
:yield: Rolling split periods
|
||||
"""
|
||||
|
||||
def get_start_index(calendar, start_date):
|
||||
start_index = bisect.bisect_left(calendar, start_date)
|
||||
return start_index
|
||||
|
||||
def get_end_index(calendar, end_date):
|
||||
end_index = bisect.bisect_right(calendar, end_date)
|
||||
return end_index - 1
|
||||
|
||||
calendar = self.raw_df.index.get_level_values("datetime").unique()
|
||||
|
||||
train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
|
||||
train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
|
||||
valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
|
||||
valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
|
||||
test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
|
||||
test_end_index = test_start_index + rolling_period - 1
|
||||
|
||||
need_stop_split = False
|
||||
|
||||
bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
|
||||
|
||||
while not need_stop_split:
|
||||
|
||||
if test_end_index > bound_test_end_index:
|
||||
test_end_index = bound_test_end_index
|
||||
need_stop_split = True
|
||||
|
||||
yield (
|
||||
calendar[train_start_index],
|
||||
calendar[train_end_index],
|
||||
calendar[valid_start_index],
|
||||
calendar[valid_end_index],
|
||||
calendar[test_start_index],
|
||||
calendar[test_end_index],
|
||||
)
|
||||
|
||||
train_start_index += rolling_period
|
||||
train_end_index += rolling_period
|
||||
valid_start_index += rolling_period
|
||||
valid_end_index += rolling_period
|
||||
test_start_index += rolling_period
|
||||
test_end_index += rolling_period
|
||||
|
||||
def get_rolling_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
# Set generator.
|
||||
for period in self.split_rolling_periods(
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq,
|
||||
):
|
||||
(
|
||||
x_train,
|
||||
y_train,
|
||||
x_validate,
|
||||
y_validate,
|
||||
x_test,
|
||||
y_test,
|
||||
) = self.get_split_data(*period)
|
||||
yield x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
def get_split_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
):
|
||||
"""
|
||||
all return types are DataFrame
|
||||
"""
|
||||
## TODO: loc can be slow, expecially when we put it at the second level index.
|
||||
if self.raw_df.index.names[0] == "instrument":
|
||||
df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
|
||||
else:
|
||||
df_train = self.raw_df.loc[train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc[test_start_date:test_end_date]
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
df_train, df_validate, df_test = self.process_data(df_train, df_validate, df_test)
|
||||
TimeInspector.log_cost_time("Finished setup processed data.")
|
||||
|
||||
x_train = df_train[self.feature_names]
|
||||
y_train = df_train[self.label_names]
|
||||
|
||||
x_validate = df_validate[self.feature_names]
|
||||
y_validate = df_validate[self.label_names]
|
||||
|
||||
x_test = df_test[self.feature_names]
|
||||
y_test = df_test[self.label_names]
|
||||
|
||||
return x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
'''
|
||||
@@ -24,8 +24,8 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
|
||||
log = get_module_logger("utils")
|
||||
|
||||
@@ -377,7 +377,7 @@ def is_tradable_date(cur_date):
|
||||
date : pandas.Timestamp
|
||||
current date
|
||||
"""
|
||||
from .data import D
|
||||
from ..data import D
|
||||
|
||||
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
|
||||
|
||||
@@ -390,7 +390,7 @@ def get_date_range(trading_date, shift, future=False):
|
||||
:param future: bool
|
||||
:return:
|
||||
"""
|
||||
from .data import D
|
||||
from ..data import D
|
||||
|
||||
calendar = D.calendar(future=future)
|
||||
if pd.to_datetime(trading_date) not in list(calendar):
|
||||
@@ -445,7 +445,7 @@ def transform_end_date(end_date=None, freq="day"):
|
||||
date : pandas.Timestamp
|
||||
current date
|
||||
"""
|
||||
from .data import D
|
||||
from ..data import D
|
||||
|
||||
last_date = D.calendar(freq=freq)[-1]
|
||||
if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):
|
||||
130
qlib/utils/objm.py
Normal file
130
qlib/utils/objm.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class ObjManager:
|
||||
def save_obj(self, obj: object, name: str):
|
||||
"""
|
||||
save obj as name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
object to be saved
|
||||
name : str
|
||||
name of the object
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement `save_obj`")
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
"""
|
||||
save objects
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj_name_l : list of <obj, name>
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_objs` method")
|
||||
|
||||
def load_obj(self, name: str) -> object:
|
||||
"""
|
||||
load object by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name of the object
|
||||
|
||||
Returns
|
||||
-------
|
||||
object:
|
||||
loaded object
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load_obj` method")
|
||||
|
||||
def exists(self, name: str) -> bool:
|
||||
"""
|
||||
if the object named `name` exists
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
name of the objecT
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
If the object exists
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `exists` method")
|
||||
|
||||
def list(self) -> list:
|
||||
"""
|
||||
list the objects
|
||||
|
||||
Returns
|
||||
-------
|
||||
list:
|
||||
the list of returned objects
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `list` method")
|
||||
|
||||
def remove(self, fname=None):
|
||||
"""remove.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname :
|
||||
if file name is provided. specific file is removed
|
||||
otherwise, The all the objects will be removed.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `remove` method")
|
||||
|
||||
|
||||
class FileManager(ObjManager):
|
||||
'''
|
||||
Use file system to manage objects
|
||||
'''
|
||||
def __init__(self, path=None):
|
||||
if path is None:
|
||||
self.path = Path(self.create_path())
|
||||
else:
|
||||
self.path = Path(path).resolve()
|
||||
|
||||
def create_path(self) -> str:
|
||||
try:
|
||||
return tempfile.mkdtemp(prefix=str(C['file_manager_path']) + os.sep)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
|
||||
|
||||
def save_obj(self, obj, name):
|
||||
with (self.path / name).open('wb') as f:
|
||||
pickle.dump(obj, f)
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
for obj, name in obj_name_l:
|
||||
self.save_obj(obj, name)
|
||||
|
||||
def load_obj(self, name):
|
||||
with (self.path / name).open('rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def exists(self, name):
|
||||
return (self.path / name).exists()
|
||||
|
||||
def list(self):
|
||||
return list(self.path.iterdir())
|
||||
|
||||
def remove(self, fname=None):
|
||||
if fname is None:
|
||||
for fp in self.path.glob('*'):
|
||||
fp.unlink()
|
||||
self.path.rmdir()
|
||||
else:
|
||||
(self.path / fname).unlink()
|
||||
22
qlib/utils/serial.py
Normal file
22
qlib/utils/serial.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
|
||||
class Serializable:
|
||||
'''
|
||||
Serializable behaves like pickle.
|
||||
But it only save the state whose name starts with `_`
|
||||
'''
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if k.startswith('_') }
|
||||
|
||||
def __setstate__(self, state: dict):
|
||||
self.__dict__.update(state)
|
||||
|
||||
def to_pickle(self, path: [Path, str]):
|
||||
with Path(path).open('wb') as f:
|
||||
pickle.dump(self, f)
|
||||
Reference in New Issue
Block a user