mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Merge remote-tracking branch 'microsoft/qlib/main' into online_srv
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,22 +17,28 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish following steps:
|
||||
|
||||
- init the sub instance and the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure and parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -39,7 +46,7 @@ class Dataset(Serializable):
|
||||
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
- User load the Dataset object from the disk.
|
||||
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
@@ -47,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -76,44 +83,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
- arguments of DataHandler.init, such as 'enable_cache', etc.
|
||||
|
||||
segment_kwargs : dict
|
||||
Config of segments which is same as 'segments' in DatasetH.setup_data
|
||||
|
||||
"""
|
||||
if handler_kwargs:
|
||||
if not isinstance(handler_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
|
||||
kwargs_init = {}
|
||||
kwargs_conf_data = {}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time"}
|
||||
for k, v in handler_kwargs.items():
|
||||
if k in conf_data_arg:
|
||||
kwargs_conf_data.update({k: v})
|
||||
else:
|
||||
kwargs_init.update({k: v})
|
||||
|
||||
self.handler.conf_data(**kwargs_conf_data)
|
||||
self.handler.init(**kwargs_init)
|
||||
|
||||
if segment_kwargs:
|
||||
if not isinstance(segment_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -144,6 +114,49 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
@@ -433,15 +446,19 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
|
||||
@@ -6,6 +6,7 @@ import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
@@ -16,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index, fetch_df_by_index
|
||||
from .utils import fetch_df_by_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -102,10 +103,10 @@ class DataHandler(Serializable):
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -118,13 +119,16 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
Set Up the data in case of running intialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -412,14 +416,28 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if processor_kwargs is not None:
|
||||
for processor in self.get_all_processors():
|
||||
processor.config(**processor_kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
Set up the data in case of running intialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -434,7 +452,7 @@ class DataHandlerLP(DataHandler):
|
||||
when we call `init` next time
|
||||
"""
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
|
||||
@@ -217,3 +217,64 @@ class StaticDataLoader(DataLoader):
|
||||
join=self.join,
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
|
||||
"""DataLoaderDH
|
||||
DataLoader based on (D)ata (H)andler
|
||||
It is designed to load multiple data from data handler
|
||||
- If you just want to load data from single datahandler, you can write them in single data handler
|
||||
"""
|
||||
|
||||
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler_config : dict
|
||||
handler_config will be used to describe the handlers
|
||||
|
||||
.. code-block::
|
||||
|
||||
<handler_config> := {
|
||||
"group_name1": <handler>
|
||||
"group_name2": <handler>
|
||||
}
|
||||
or
|
||||
<handler_config> := <handler>
|
||||
<handler> := DataHandler Instance | DataHandler Config
|
||||
|
||||
fetch_kwargs : dict
|
||||
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
|
||||
|
||||
is_group: bool
|
||||
is_group will be used to describe whether the key of handler_config is group
|
||||
|
||||
"""
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
|
||||
if is_group:
|
||||
self.handlers = {
|
||||
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
|
||||
}
|
||||
else:
|
||||
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
|
||||
|
||||
self.is_group = is_group
|
||||
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
|
||||
self.fetch_kwargs.update(fetch_kwargs)
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is not None:
|
||||
LOG.warning(f"instruments[{instruments}] is ignored")
|
||||
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
for grp, dh in self.handlers.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
|
||||
return df
|
||||
|
||||
11
qlib/data/dataset/processor.py
Executable file → Normal file
11
qlib/data/dataset/processor.py
Executable file → Normal file
@@ -72,6 +72,17 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
|
||||
attr_list = {"fit_start_time", "fit_end_time"}
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list and hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, fields_group=None):
|
||||
|
||||
Reference in New Issue
Block a user