1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

update comments

This commit is contained in:
bxdd
2021-03-30 00:38:15 +08:00
parent 1074284666
commit 136830bc2b
7 changed files with 38 additions and 54 deletions

View File

@@ -70,10 +70,3 @@ class HighFreqNorm(Processor):
columns=["FEATURE_%d" % i for i in range(12 * 240)],
).sort_index()
return df_new_features
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time
if fit_end_time:
self.fit_end_time = fit_end_time
super().config(**kwargs)

View File

@@ -27,7 +27,7 @@ from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
class HighfreqWorkflow(object):
class HighfreqWorkflow:
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
@@ -177,8 +177,8 @@ class HighfreqWorkflow(object):
dataset_backtest.setup_data(handler_kwargs={})
##=============get data=============
(xtest,) = dataset.prepare(["test"])
(backtest_test,) = dataset_backtest.prepare(["test"])
xtest = dataset.prepare("test")
backtest_test = dataset_backtest.prepare("test")
print(xtest, backtest_test)
return

View File

@@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow(object):
class RollingDataWorkflow:
MARKET = "csi300"
start_time = "2010-01-01"

View File

@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
from inspect import getfullargspec
import pandas as pd
import numpy as np
@@ -16,7 +17,7 @@ class Dataset(Serializable):
Preparing data for model training and inferencing.
"""
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
"""
init is designed to finish following steps:
@@ -28,16 +29,16 @@ class Dataset(Serializable):
The data could specify the info to caculate the essential data for preparation
"""
self.setup_data(*args, **kwargs)
self.setup_data(**kwargs)
super().__init__()
def config(self, *arg, **kwargs):
def config(self, **kwargs):
"""
config is designed to configure and parameters that cannot be learned from the data
"""
super().config(*arg, **kwargs)
super().config(**kwargs)
def setup_data(self, *args, **kwargs):
def setup_data(self, **kwargs):
"""
Setup the data.
@@ -53,7 +54,7 @@ class Dataset(Serializable):
"""
pass
def prepare(self, *args, **kwargs) -> object:
def prepare(self, **kwargs) -> object:
"""
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
The parameters should specify the scope for the prepared data
@@ -115,7 +116,7 @@ class DatasetH(Dataset):
self.segments = segments.copy()
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs):
def config(self, handler_kwargs: dict = None, **kwargs):
"""
Initialize the DatasetH
@@ -133,11 +134,11 @@ class DatasetH(Dataset):
Config of segments which is same as 'segments' in self.__init__
"""
super().config(**kwargs)
if handler_kwargs is not None:
self.handler.config(**handler_kwargs)
if segments is not None:
self.segments = segments.copy()
if "segments" in kwargs:
self.segments = deepcopy(kwargs.pop("segments"))
super().config(**kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
"""
@@ -449,10 +450,10 @@ class TSDatasetH(DatasetH):
self.step_len = step_len
super().__init__(**kwargs)
def config(self, step_len=None, **kwargs):
def config(self, **kwargs):
if "step_len" in kwargs:
self.step_len = kwargs.pop("step_len")
super().config(**kwargs)
if step_len:
self.step_len = step_len
def setup_data(self, **kwargs):
super().setup_data(**kwargs)

View File

@@ -103,7 +103,7 @@ class DataHandler(Serializable):
self.setup_data()
super().__init__()
def config(self, instruments=None, start_time=None, end_time=None, **kwargs):
def config(self, **kwargs):
"""
configuration of data.
# what data to be loaded from data source
@@ -112,13 +112,16 @@ class DataHandler(Serializable):
The data will be initialized with different time range.
"""
attr_list = {"instruments", "start_time", "end_time"}
for k, v in kwargs.items():
if k in attr_list:
setattr(self, k, v)
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
if instruments:
self.instruments = instruments
if start_time:
self.start_time = start_time
if end_time:
self.end_time = end_time
def setup_data(self, enable_cache: bool = False):
"""

View File

@@ -261,7 +261,7 @@ class DataLoaderDH(DataLoader):
self.is_group = is_group
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs}
self.fetch_kwargs.update(fetch_kwargs)
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:

View File

@@ -73,7 +73,15 @@ class Processor(Serializable):
return True
def config(self, **kwargs):
super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None))
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
if k in attr_list and getattr(self, k, None) is not None:
setattr(self, k, v)
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
class DropnaProcessor(Processor):
@@ -195,13 +203,6 @@ class MinMaxNorm(Processor):
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
return df
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time
if fit_end_time:
self.fit_end_time = fit_end_time
super().config(**kwargs)
class ZScoreNorm(Processor):
"""ZScore Normalization"""
@@ -231,13 +232,6 @@ class ZScoreNorm(Processor):
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
return df
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time
if fit_end_time:
self.fit_end_time = fit_end_time
super().config(**kwargs)
class RobustZScoreNorm(Processor):
"""Robust ZScore Normalization
@@ -274,13 +268,6 @@ class RobustZScoreNorm(Processor):
df.clip(-3, 3, inplace=True)
return df
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
if fit_start_time:
self.fit_start_time = fit_start_time
if fit_end_time:
self.fit_end_time = fit_end_time
super().config(**kwargs)
class CSZScoreNorm(Processor):
"""Cross Sectional ZScore Normalization"""