mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Update features for hyb nn
This commit is contained in:
@@ -112,7 +112,7 @@ class DatasetH(Dataset):
|
||||
'outsample': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -243,7 +243,7 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
|
||||
@@ -272,9 +272,18 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
|
||||
kwargs = {"object": self.data}
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE:
|
||||
# - append last line with full NaN for better performance in `__getitem__`
|
||||
# - Keep the same dtype will result in a better performance
|
||||
self.data_arr = np.append(
|
||||
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
# the data type will be changed
|
||||
@@ -282,13 +291,16 @@ class TSDataSampler:
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
self.data_idx = deepcopy(self.data.index)
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
- Special sampler will be used (e.g. user want to sample day by day)
|
||||
"""
|
||||
return self.data.index[self.start_idx : self.end_idx]
|
||||
return self.data_idx[self.start_idx : self.end_idx]
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
@@ -461,7 +473,7 @@ class TSDatasetH(DatasetH):
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
@@ -470,6 +482,14 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
"""
|
||||
dtype = kwargs.pop("dtype")
|
||||
start, end = slc.start, slc.stop
|
||||
data = self._prepare_raw_seg(slc=slc, **kwargs)
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)
|
||||
return tsds
|
||||
|
||||
@@ -7,7 +7,7 @@ import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -166,6 +166,7 @@ class DataHandler(Serializable):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -188,6 +189,14 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
- A Dataset learned some processors to process data which is related to data segmentation
|
||||
- It will apply them every time when preparing data.
|
||||
- The learned processor require the dataframe remains the same format when fitting and applying
|
||||
- However the data format will change according to the parameters.
|
||||
- So the processors should be applied to the underlayer data.
|
||||
|
||||
squeeze : bool
|
||||
whether squeeze columns and index
|
||||
@@ -196,8 +205,15 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(self._data, col_set)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -481,6 +497,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -495,12 +512,18 @@ class DataHandlerLP(DataHandler):
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
@@ -13,6 +13,7 @@ from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class DataLoader(abc.ABC):
|
||||
@@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader):
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
|
||||
TODO: What make this module not that easy to use.
|
||||
- For online scenario
|
||||
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
@@ -265,7 +270,7 @@ class DataLoaderDH(DataLoader):
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
LOG.warning(f"instruments[{instruments}] is ignored")
|
||||
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
|
||||
@@ -6,6 +6,8 @@ from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import Model
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
@@ -25,8 +27,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
@@ -37,6 +39,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
# This dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
@@ -26,6 +27,7 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, Any, Text, Optional
|
||||
from types import ModuleType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -235,7 +237,10 @@ def init_instance_by_config(
|
||||
'model_path': path, # It is optional if module is given
|
||||
}
|
||||
str example.
|
||||
"ClassName": getattr(module, config)() will be used.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, config)() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
default_module : Python module
|
||||
@@ -257,6 +262,13 @@ def init_instance_by_config(
|
||||
if isinstance(config, accept_types):
|
||||
return config
|
||||
|
||||
if isinstance(config, str):
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@@ -33,16 +33,40 @@ class Serializable:
|
||||
@property
|
||||
def exclude(self):
|
||||
"""
|
||||
What attribute will be dumped
|
||||
What attribute will not be dumped
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None):
|
||||
if dump_all is not None:
|
||||
self._dump_all = dump_all
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
if exclude is not None:
|
||||
self._exclude = exclude
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dump_all : bool
|
||||
will the object dump all object
|
||||
exclude : list
|
||||
What attribute will not be dumped
|
||||
recursive : bool
|
||||
will the configuration be recursive
|
||||
"""
|
||||
|
||||
params = {"dump_all": dump_all, "exclude": exclude}
|
||||
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
attr_name = f"_{k}"
|
||||
setattr(self, attr_name, v)
|
||||
|
||||
if recursive:
|
||||
for obj in self.__dict__.values():
|
||||
# set flag to prevent endless loop
|
||||
self.__dict__[self.FLAG_KEY] = True
|
||||
if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
|
||||
obj.config(**params, recursive=True)
|
||||
del self.__dict__[self.FLAG_KEY]
|
||||
|
||||
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
|
||||
@@ -186,6 +186,9 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
|
||||
Reference in New Issue
Block a user