diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index b3eaac7a3..ef30c634e 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -112,7 +112,7 @@ class DatasetH(Dataset): 'outsample': ("2017-01-01", "2020-08-01",), } """ - self.handler = init_instance_by_config(handler, accept_types=DataHandler) + self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() super().__init__(**kwargs) @@ -243,7 +243,7 @@ class TSDataSampler: """ - def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"): + def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None): """ Build a dataset which looks like torch.data.utils.Dataset. @@ -272,9 +272,18 @@ class TSDataSampler: self.fillna_type = fillna_type assert get_level_index(data, "datetime") == 0 self.data = lazy_sort_index(data) - self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! - # NOTE: append last line with full NaN for better performance in `__getitem__` - self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0) + + kwargs = {"object": self.data} + if dtype is not None: + kwargs["dtype"] = dtype + + self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values! + # NOTE: + # - append last line with full NaN for better performance in `__getitem__` + # - Keep the same dtype will result in a better performance + self.data_arr = np.append( + self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0 + ) self.nan_idx = -1 # The last line is all NaN # the data type will be changed @@ -282,13 +291,16 @@ class TSDataSampler: self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_df, self.idx_map = self.build_index(self.data) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + self.data_idx = deepcopy(self.data.index) + + del self.data # save memory def get_index(self): """ Get the pandas index of the data, it will be useful in following scenarios - Special sampler will be used (e.g. user want to sample day by day) """ - return self.data.index[self.start_idx : self.end_idx] + return self.data_idx[self.start_idx : self.end_idx] def config(self, **kwargs): # Config the attributes @@ -461,7 +473,7 @@ class TSDatasetH(DatasetH): cal = sorted(cal) self.cal = cal - def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame: # Dataset decide how to slice data(Get more data for timeseries). start, end = slc.start, slc.stop start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start)) @@ -470,6 +482,14 @@ class TSDatasetH(DatasetH): # TSDatasetH will retrieve more data for complete data = super()._prepare_seg(slice(pad_start, end), **kwargs) + return data - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len) + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + """ + split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data + """ + dtype = kwargs.pop("dtype") + start, end = slc.start, slc.stop + data = self._prepare_raw_seg(slc=slc, **kwargs) + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype) return tsds diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 3c321ed9e..f1fa39c3b 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -7,7 +7,7 @@ import bisect import logging import warnings from inspect import getfullargspec -from typing import Union, Tuple, List, Iterator, Optional +from typing import Callable, Union, Tuple, List, Iterator, Optional import pandas as pd import numpy as np @@ -166,6 +166,7 @@ class DataHandler(Serializable): level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = CS_ALL, squeeze: bool = False, + proc_func: Callable = None, ) -> pd.DataFrame: """ fetch data from underlying data source @@ -188,6 +189,14 @@ class DataHandler(Serializable): - if isinstance(col_set, List[str]): select several sets of meaningful columns, the returned data has multiple levels + proc_func: Callable + - Give a hook for processing data before fetching + - An example to explain the necessity of the hook: + - A Dataset learned some processors to process data which is related to data segmentation + - It will apply them every time when preparing data. + - The learned processor require the dataframe remains the same format when fitting and applying + - However the data format will change according to the parameters. + - So the processors should be applied to the underlayer data. squeeze : bool whether squeeze columns and index @@ -196,8 +205,15 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ + if proc_func is None: + df = self._data + else: + # FIXME: fetching by time first will be more friendly to `proc_func` + # Copy in case of `proc_func` changing the data inplace.... + df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy()) + # Fetch column first will be more friendly to SepDataFrame - df = self._fetch_df_by_col(self._data, col_set) + df = self._fetch_df_by_col(df, col_set) df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) if squeeze: # squeeze columns @@ -481,6 +497,7 @@ class DataHandlerLP(DataHandler): level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, data_key: str = DK_I, + proc_func: Callable = None, ) -> pd.DataFrame: """ fetch data from underlying data source @@ -495,12 +512,18 @@ class DataHandlerLP(DataHandler): select a set of meaningful columns.(e.g. features, columns). data_key : str the data to fetch: DK_*. + proc_func: Callable + please refer to the doc of DataHandler.fetch Returns ------- pd.DataFrame: """ df = self._get_df_by_key(data_key) + if proc_func is not None: + # FIXME: fetch by time first will be more friendly to proc_func + # Copy incase of `proc_func` changing the data inplace.... + df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy()) # Fetch column first will be more friendly to SepDataFrame df = self._fetch_df_by_col(df, col_set) return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 58aca1d4f..2ad110b89 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -13,6 +13,7 @@ from qlib.data import D from qlib.data import filter as filter_module from qlib.data.filter import BaseDFilter from qlib.utils import load_dataset, init_instance_by_config +from qlib.log import get_module_logger class DataLoader(abc.ABC): @@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader): DataLoader based on (D)ata (H)andler It is designed to load multiple data from data handler - If you just want to load data from single datahandler, you can write them in single data handler + + TODO: What make this module not that easy to use. + - For online scenario + - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook. """ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False): @@ -265,7 +270,7 @@ class DataLoaderDH(DataLoader): def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: - LOG.warning(f"instruments[{instruments}] is ignored") + get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored") if self.is_group: df = pd.concat( diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 6e1ae8b9d..7ffca20ee 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -6,6 +6,8 @@ from qlib.workflow import R from qlib.workflow.recorder import Recorder from qlib.workflow.record_temp import SignalRecord from qlib.workflow.task.manage import TaskManager, run_task +from qlib.data.dataset import Dataset +from qlib.model.base import Model def task_train(task_config: dict, experiment_name: str) -> Recorder: @@ -25,8 +27,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder: """ # model initiaiton - model = init_instance_by_config(task_config["model"]) - dataset = init_instance_by_config(task_config["dataset"]) + model: Model = init_instance_by_config(task_config["model"]) + dataset: Dataset = init_instance_by_config(task_config["dataset"]) # start exp with R.start(experiment_name=experiment_name): @@ -37,6 +39,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder: recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) R.save_objects(**{"task": task_config}) # keep the original format and datatype + # This dataset is saved for online inference. So the concrete data should not be dumped + dataset.config(dump_all=False, recursive=True) R.save_objects(**{"dataset": dataset}) # generate records: prediction, backtest, and analysis diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 2a34035f3..7e71ba76c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import os +import pickle import re import copy import json @@ -26,6 +27,7 @@ import pandas as pd from pathlib import Path from typing import Union, Tuple, Any, Text, Optional from types import ModuleType +from urllib.parse import urlparse from ..config import C from ..log import get_module_logger, set_log_with_config @@ -235,7 +237,10 @@ def init_instance_by_config( 'model_path': path, # It is optional if module is given } str example. - "ClassName": getattr(module, config)() will be used. + 1) specify a pickle object + - path like 'file:////obj.pkl' + 2) specify a class name + - "ClassName": getattr(module, config)() will be used. object example: instance of accept_types default_module : Python module @@ -257,6 +262,13 @@ def init_instance_by_config( if isinstance(config, accept_types): return config + if isinstance(config, str): + # path like 'file:////obj.pkl' + pr = urlparse(config) + if pr.scheme == "file": + with open(os.path.join(pr.netloc, pr.path), "rb") as f: + return pickle.load(f) + klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module) return klass(**cls_kwargs, **kwargs) diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 4bc57eccd..b94be464b 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -33,16 +33,40 @@ class Serializable: @property def exclude(self): """ - What attribute will be dumped + What attribute will not be dumped """ return getattr(self, "_exclude", []) - def config(self, dump_all: bool = None, exclude: list = None): - if dump_all is not None: - self._dump_all = dump_all + FLAG_KEY = "_qlib_serial_flag" - if exclude is not None: - self._exclude = exclude + def config(self, dump_all: bool = None, exclude: list = None, recursive=False): + """ + configure the serializable object + + Parameters + ---------- + dump_all : bool + will the object dump all object + exclude : list + What attribute will not be dumped + recursive : bool + will the configuration be recursive + """ + + params = {"dump_all": dump_all, "exclude": exclude} + + for k, v in params.items(): + if v is not None: + attr_name = f"_{k}" + setattr(self, attr_name, v) + + if recursive: + for obj in self.__dict__.values(): + # set flag to prevent endless loop + self.__dict__[self.FLAG_KEY] = True + if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__: + obj.config(**params, recursive=True) + del self.__dict__[self.FLAG_KEY] def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None): self.config(dump_all=dump_all, exclude=exclude) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 808aca302..324b790ac 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -186,6 +186,9 @@ class SigAnaRecord(SignalRecord): pred = self.load("pred.pkl") label = self.load("label.pkl") + if label is None or not isinstance(label, pd.DataFrame) or label.empty: + logger.warn(f"Empty label.") + return ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0]) metrics = { "IC": ic.mean(),