mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
update comments
This commit is contained in:
@@ -70,10 +70,3 @@ class HighFreqNorm(Processor):
|
||||
columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
).sort_index()
|
||||
return df_new_features
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
if fit_end_time:
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
@@ -27,7 +27,7 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
class HighfreqWorkflow:
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
@@ -177,8 +177,8 @@ class HighfreqWorkflow(object):
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
(xtest,) = dataset.prepare(["test"])
|
||||
(backtest_test,) = dataset_backtest.prepare(["test"])
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
@@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow(object):
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,7 +17,7 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish the following steps:
|
||||
|
||||
@@ -28,16 +29,16 @@ class Dataset(Serializable):
|
||||
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def config(self, *arg, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure the parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(*arg, **kwargs)
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -53,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -115,7 +116,7 @@ class DatasetH(Dataset):
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs):
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
@@ -133,11 +134,11 @@ class DatasetH(Dataset):
|
||||
Config of segments which is the same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if segments is not None:
|
||||
self.segments = segments.copy()
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
@@ -449,10 +450,10 @@ class TSDatasetH(DatasetH):
|
||||
self.step_len = step_len
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, step_len=None, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
if step_len:
|
||||
self.step_len = step_len
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
@@ -103,7 +103,7 @@ class DataHandler(Serializable):
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def config(self, instruments=None, start_time=None, end_time=None, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -112,13 +112,16 @@ class DataHandler(Serializable):
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
attr_list = {"instruments", "start_time", "end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
if instruments:
|
||||
self.instruments = instruments
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
|
||||
@@ -261,7 +261,7 @@ class DataLoaderDH(DataLoader):
|
||||
|
||||
self.is_group = is_group
|
||||
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
|
||||
self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs}
|
||||
self.fetch_kwargs.update(fetch_kwargs)
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
|
||||
@@ -73,7 +73,15 @@ class Processor(Serializable):
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
|
||||
super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None))
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list and getattr(self, k, None) is not None:
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
@@ -195,13 +203,6 @@ class MinMaxNorm(Processor):
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
if fit_end_time:
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class ZScoreNorm(Processor):
|
||||
"""ZScore Normalization"""
|
||||
@@ -231,13 +232,6 @@ class ZScoreNorm(Processor):
|
||||
df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
|
||||
return df
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
if fit_end_time:
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class RobustZScoreNorm(Processor):
|
||||
"""Robust ZScore Normalization
|
||||
@@ -274,13 +268,6 @@ class RobustZScoreNorm(Processor):
|
||||
df.clip(-3, 3, inplace=True)
|
||||
return df
|
||||
|
||||
def config(self, fit_start_time=None, fit_end_time=None, **kwargs):
|
||||
if fit_start_time:
|
||||
self.fit_start_time = fit_start_time
|
||||
if fit_end_time:
|
||||
self.fit_end_time = fit_end_time
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class CSZScoreNorm(Processor):
|
||||
"""Cross Sectional ZScore Normalization"""
|
||||
|
||||
Reference in New Issue
Block a user