From 136830bc2bf8281838d96c22fb0cdd45e93ae16b Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 30 Mar 2021 00:38:15 +0800 Subject: [PATCH] update comments --- examples/highfreq/highfreq_processor.py | 7 ----- examples/highfreq/workflow.py | 6 ++--- examples/rolling_process_data/workflow.py | 2 +- qlib/data/dataset/__init__.py | 27 ++++++++++---------- qlib/data/dataset/handler.py | 17 ++++++++----- qlib/data/dataset/loader.py | 2 +- qlib/data/dataset/processor.py | 31 +++++++---------------- 7 files changed, 38 insertions(+), 54 deletions(-) diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index d843c6ac0..f0ab0dec2 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -70,10 +70,3 @@ class HighFreqNorm(Processor): columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features - - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 94c9b689f..5660ab2e9 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -27,7 +27,7 @@ from qlib.tests.data import GetData from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut -class HighfreqWorkflow(object): +class HighfreqWorkflow: SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} @@ -177,8 +177,8 @@ class HighfreqWorkflow(object): dataset_backtest.setup_data(handler_kwargs={}) ##=============get data============= - (xtest,) = dataset.prepare(["test"]) - (backtest_test,) = dataset_backtest.prepare(["test"]) + xtest = dataset.prepare("test") + backtest_test = dataset_backtest.prepare("test") print(xtest, backtest_test) return diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 02f43889d..5757aaa87 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -14,7 +14,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config from qlib.tests.data import GetData -class RollingDataWorkflow(object): +class RollingDataWorkflow: MARKET = "csi300" start_time = "2010-01-01" diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 668ea833b..b3eaac7a3 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional from ...utils import init_instance_by_config, np_ffill from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP +from copy import deepcopy from inspect import getfullargspec import pandas as pd import numpy as np @@ -16,7 +17,7 @@ class Dataset(Serializable): Preparing data for model training and inferencing. """ - def __init__(self, *args, **kwargs): + def __init__(self, **kwargs): """ init is designed to finish following steps: @@ -28,16 +29,16 @@ class Dataset(Serializable): The data could specify the info to caculate the essential data for preparation """ - self.setup_data(*args, **kwargs) + self.setup_data(**kwargs) super().__init__() - def config(self, *arg, **kwargs): + def config(self, **kwargs): """ config is designed to configure and parameters that cannot be learned from the data """ - super().config(*arg, **kwargs) + super().config(**kwargs) - def setup_data(self, *args, **kwargs): + def setup_data(self, **kwargs): """ Setup the data. @@ -53,7 +54,7 @@ class Dataset(Serializable): """ pass - def prepare(self, *args, **kwargs) -> object: + def prepare(self, **kwargs) -> object: """ The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.) The parameters should specify the scope for the prepared data @@ -115,7 +116,7 @@ class DatasetH(Dataset): self.segments = segments.copy() super().__init__(**kwargs) - def config(self, handler_kwargs: dict = None, segments: dict = None, **kwargs): + def config(self, handler_kwargs: dict = None, **kwargs): """ Initialize the DatasetH @@ -133,11 +134,11 @@ class DatasetH(Dataset): Config of segments which is same as 'segments' in self.__init__ """ - super().config(**kwargs) if handler_kwargs is not None: self.handler.config(**handler_kwargs) - if segments is not None: - self.segments = segments.copy() + if "segments" in kwargs: + self.segments = deepcopy(kwargs.pop("segments")) + super().config(**kwargs) def setup_data(self, handler_kwargs: dict = None, **kwargs): """ @@ -449,10 +450,10 @@ class TSDatasetH(DatasetH): self.step_len = step_len super().__init__(**kwargs) - def config(self, step_len=None, **kwargs): + def config(self, **kwargs): + if "step_len" in kwargs: + self.step_len = kwargs.pop("step_len") super().config(**kwargs) - if step_len: - self.step_len = step_len def setup_data(self, **kwargs): super().setup_data(**kwargs) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2190deeb1..7fb7090d2 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -103,7 +103,7 @@ class DataHandler(Serializable): self.setup_data() super().__init__() - def config(self, instruments=None, start_time=None, end_time=None, **kwargs): + def config(self, **kwargs): """ configuration of data. # what data to be loaded from data source @@ -112,13 +112,16 @@ class DataHandler(Serializable): The data will be initialized with different time range. """ + attr_list = {"instruments", "start_time", "end_time"} + for k, v in kwargs.items(): + if k in attr_list: + setattr(self, k, v) + + for attr in attr_list: + if attr in kwargs: + kwargs.pop(attr) + super().config(**kwargs) - if instruments: - self.instruments = instruments - if start_time: - self.start_time = start_time - if end_time: - self.end_time = end_time def setup_data(self, enable_cache: bool = False): """ diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 1cda5c025..58aca1d4f 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -261,7 +261,7 @@ class DataLoaderDH(DataLoader): self.is_group = is_group self.fetch_kwargs = {"col_set": DataHandler.CS_RAW} - self.fetch_kwargs = {**self.fetch_kwargs, **fetch_kwargs} + self.fetch_kwargs.update(fetch_kwargs) def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index d25d36c88..8f69a5dff 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -73,7 +73,15 @@ class Processor(Serializable): return True def config(self, **kwargs): - super().config(kwargs.get("dump_all", None), kwargs.get("exclude", None)) + attr_list = {"fit_start_time", "fit_end_time"} + for k, v in kwargs.items(): + if k in attr_list and getattr(self, k, None) is not None: + setattr(self, k, v) + + for attr in attr_list: + if attr in kwargs: + kwargs.pop(attr) + super().config(**kwargs) class DropnaProcessor(Processor): @@ -195,13 +203,6 @@ class MinMaxNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class ZScoreNorm(Processor): """ZScore Normalization""" @@ -231,13 +232,6 @@ class ZScoreNorm(Processor): df.loc(axis=1)[self.cols] = normalize(df[self.cols].values) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class RobustZScoreNorm(Processor): """Robust ZScore Normalization @@ -274,13 +268,6 @@ class RobustZScoreNorm(Processor): df.clip(-3, 3, inplace=True) return df - def config(self, fit_start_time=None, fit_end_time=None, **kwargs): - if fit_start_time: - self.fit_start_time = fit_start_time - if fit_end_time: - self.fit_end_time = fit_end_time - super().config(**kwargs) - class CSZScoreNorm(Processor): """Cross Sectional ZScore Normalization"""