1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

Update features for hyb nn

This commit is contained in:
Young
2021-04-09 13:48:01 +00:00
parent 18bf4b5477
commit a366c11d67
7 changed files with 111 additions and 20 deletions

View File

@@ -112,7 +112,7 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
super().__init__(**kwargs)
@@ -243,7 +243,7 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -272,9 +272,18 @@ class TSDataSampler:
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE: append last line with full NaN for better performance in `__getitem__`
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
kwargs = {"object": self.data}
if dtype is not None:
kwargs["dtype"] = dtype
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE:
# - append last line with full NaN for better performance in `__getitem__`
# - Keep the same dtype will result in a better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
)
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
@@ -282,13 +291,16 @@ class TSDataSampler:
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_df, self.idx_map = self.build_index(self.data)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
self.data_idx = deepcopy(self.data.index)
del self.data # save memory
def get_index(self):
    """
    Get the pandas index of the data, restricted to ``[start_idx, end_idx)``.

    It will be useful in following scenarios
    - Special sampler will be used (e.g. user want to sample day by day)

    Returns
    -------
    The slice ``self.data_idx[self.start_idx : self.end_idx]``.
    """
    # NOTE: use the copied index `data_idx`. The constructor code deletes `self.data`
    # to save memory, so `self.data.index` is no longer available here.
    return self.data_idx[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -461,7 +473,7 @@ class TSDatasetH(DatasetH):
cal = sorted(cal)
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
@@ -470,6 +482,14 @@ class TSDatasetH(DatasetH):
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
    """
    Build a `TSDataSampler` for the segment described by `slc`.

    The raw data is fetched by `_prepare_raw_seg`; keeping that step separate leaves a
    hook for preprocessing the data before the sampler is created.

    Parameters
    ----------
    slc : slice
        `slc.start` and `slc.stop` are passed to the sampler as `start` and `end`.
    **kwargs
        `dtype` (optional) is removed and forwarded to `TSDataSampler`;
        every other keyword is passed through to `_prepare_raw_seg`.

    Returns
    -------
    TSDataSampler
    """
    # `dtype` is optional: fall back to None instead of raising KeyError when it is not given
    dtype = kwargs.pop("dtype", None)
    start, end = slc.start, slc.stop
    data = self._prepare_raw_seg(slc=slc, **kwargs)
    tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)
    return tsds

View File

@@ -7,7 +7,7 @@ import bisect
import logging
import warnings
from inspect import getfullargspec
from typing import Union, Tuple, List, Iterator, Optional
from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -166,6 +166,7 @@ class DataHandler(Serializable):
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -188,6 +189,14 @@ class DataHandler(Serializable):
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
proc_func: Callable
- Give a hook for processing data before fetching
- An example to explain the necessity of the hook:
- A Dataset learned some processors to process data which is related to data segmentation
- It will apply them every time when preparing data.
- The learned processors require the dataframe to keep the same format when fitting and applying
- However, the data format will change according to the parameters.
- So the processors should be applied to the underlying data.
squeeze : bool
whether squeeze columns and index
@@ -196,8 +205,15 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(self._data, col_set)
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -481,6 +497,7 @@ class DataHandlerLP(DataHandler):
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -495,12 +512,18 @@ class DataHandlerLP(DataHandler):
select a set of meaningful columns.(e.g. features, columns).
data_key : str
the data to fetch: DK_*.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy in case `proc_func` changes the data in place.
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)

View File

@@ -13,6 +13,7 @@ from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.log import get_module_logger
class DataLoader(abc.ABC):
@@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader):
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
TODO: What makes this module not that easy to use:
- For online scenario
- The underlying data handler should be configured, but the data loader doesn't provide such an interface & hook.
"""
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
@@ -265,7 +270,7 @@ class DataLoaderDH(DataLoader):
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
LOG.warning(f"instruments[{instruments}] is ignored")
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
if self.is_group:
df = pd.concat(

View File

@@ -6,6 +6,8 @@ from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset import Dataset
from qlib.model.base import Model
def task_train(task_config: dict, experiment_name: str) -> Recorder:
@@ -25,8 +27,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
# model initiation
model = init_instance_by_config(task_config["model"])
dataset = init_instance_by_config(task_config["dataset"])
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# start exp
with R.start(experiment_name=experiment_name):
@@ -37,6 +39,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task": task_config}) # keep the original format and datatype
# This dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis

View File

@@ -6,6 +6,7 @@ from __future__ import division
from __future__ import print_function
import os
import pickle
import re
import copy
import json
@@ -26,6 +27,7 @@ import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -235,7 +237,10 @@ def init_instance_by_config(
'model_path': path, # It is optional if module is given
}
str example.
"ClassName": getattr(module, config)() will be used.
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'
2) specify a class name
- "ClassName": getattr(module, config)() will be used.
object example:
instance of accept_types
default_module : Python module
@@ -257,6 +262,13 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
if isinstance(config, str):
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)

View File

@@ -33,16 +33,40 @@ class Serializable:
@property
def exclude(self):
    """
    What attribute will not be dumped

    Falls back to an empty list when `_exclude` has not been set on the object.
    """
    try:
        return self._exclude
    except AttributeError:
        # `_exclude` was never configured
        return []
def config(self, dump_all: bool = None, exclude: list = None):
if dump_all is not None:
self._dump_all = dump_all
FLAG_KEY = "_qlib_serial_flag"
if exclude is not None:
self._exclude = exclude
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
    """
    configure the serializable object

    Parameters
    ----------
    dump_all : bool
        will the object dump all object; stored as `_dump_all` when not None
    exclude : list
        What attribute will not be dumped; stored as `_exclude` when not None
    recursive : bool
        will the configuration be recursive, i.e. also applied to every
        attribute value that is itself a `Serializable`
    """
    params = {"dump_all": dump_all, "exclude": exclude}

    # Only options that were explicitly given (non-None) overwrite the stored setting.
    for k, v in params.items():
        if v is not None:
            setattr(self, f"_{k}", v)

    if not recursive:
        return

    # Snapshot the children first: the flag below is inserted into and removed from
    # `self.__dict__`, which must not happen while iterating over the live dict view.
    children = [obj for obj in self.__dict__.values() if isinstance(obj, Serializable)]

    # set flag to prevent endless loop (objects may reference each other)
    self.__dict__[self.FLAG_KEY] = True
    try:
        for obj in children:
            if self.FLAG_KEY not in obj.__dict__:
                obj.config(**params, recursive=True)
    finally:
        # Always remove the marker, even when a nested `config` raises,
        # so that it never leaks into the object's state.
        del self.__dict__[self.FLAG_KEY]
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
self.config(dump_all=dump_all, exclude=exclude)

View File

@@ -186,6 +186,9 @@ class SigAnaRecord(SignalRecord):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),