1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

move freq params to dataloader

This commit is contained in:
Young
2021-01-31 13:34:57 +00:00
committed by you-n-g
parent bdc70c192a
commit 802dac81c9
6 changed files with 51 additions and 31 deletions

View File

@@ -10,7 +10,6 @@ class HighFreqHandler(DataHandlerLP):
instruments="csi300",
start_time=None,
end_time=None,
freq="1min",
infer_processors=[],
learn_processors=[],
fit_start_time=None,
@@ -37,13 +36,13 @@ class HighFreqHandler(DataHandlerLP):
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq=freq,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
@@ -124,20 +123,19 @@ class HighFreqBacktestHandler(DataHandler):
instruments="csi300",
start_time=None,
end_time=None,
freq="1min",
):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
"freq": "1min",
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq=freq,
data_loader=data_loader,
)

View File

@@ -90,7 +90,6 @@ _default_config = {
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
"default_disk_cache": 1, # 0:skip/1:use
"disable_disk_cache": False, # disable disk cache; if High-frequency data generally disable_disk_cache=True
"mem_cache_size_limit": 500,
# memory cache expiry time in seconds; only used by 'DatasetURICache' and 'client D.calendar'
# default 1 hour

View File

@@ -961,8 +961,7 @@ class BaseProvider:
is a provider class.
"""
disk_cache = C.default_disk_cache if disk_cache is None else disk_cache
if C.disable_disk_cache:
disk_cache = False
fields = list(fields) # In case of tuple.
try:
return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache)
except TypeError:

View File

@@ -57,10 +57,10 @@ class DataHandler(Serializable):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
**kwargs,
):
"""
Parameters
@@ -71,14 +71,14 @@ class DataHandler(Serializable):
start_time of the original data.
end_time :
end_time of the original data.
freq :
frequency of data
data_loader : Tuple[dict, str, DataLoader]
data loader to load the data.
init_data :
initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
**kwargs:
it will be passed into data_loader
"""
# Set logger
self.logger = get_module_logger("DataHandler")
@@ -86,23 +86,43 @@ class DataHandler(Serializable):
# Setup data loader
assert data_loader is not None # to make start_time end_time could have None default value
# what data source to load data
self.data_loader = init_instance_by_config(
data_loader,
None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
accept_types=DataLoader,
**kwargs,
)
# what data to be loaded from data source
# For IDE auto-completion.
self.instruments = instruments
self.start_time = start_time
self.end_time = end_time
self.freq = freq
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
self.init()
super().__init__()
def init(self, enable_cache: bool = True):
def conf_data(self, **kwargs):
    """
    Reconfigure which data this handler loads.

    This method will be used when loading a pickled handler from a dataset,
    so that the data can be initialized with a different time range.

    Parameters
    ----------
    **kwargs :
        Any of ``instruments``, ``start_time`` and ``end_time``; each value is
        stored on the handler as the attribute of the same name.

    Raises
    ------
    KeyError
        If any keyword is not one of the supported attributes. The handler
        is left unmodified in that case.
    """
    attr_list = {"instruments", "start_time", "end_time"}
    # Validate every key before touching the handler, so that a rejected
    # call cannot leave it half-updated.
    if any(k not in attr_list for k in kwargs):
        raise KeyError("Such config is not supported.")
    for k, v in kwargs.items():
        setattr(self, k, v)
def init(self, enable_cache: bool = False):
"""
initialize the data.
If initialization is run multiple times, it will do nothing after the first time.
@@ -123,7 +143,7 @@ class DataHandler(Serializable):
# Setup data.
# _data may be with multiple column index level. The outer level indicates the feature set name
with TimeInspector.logt("Loading data"):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
@@ -262,7 +282,6 @@ class DataHandlerLP(DataHandler):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
@@ -328,7 +347,7 @@ class DataHandlerLP(DataHandler):
self.process_type = process_type
self.drop_raw = drop_raw
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
def get_all_processors(self):
return self.infer_processors + self.learn_processors

View File

@@ -21,7 +21,7 @@ class DataLoader(abc.ABC):
"""
@abc.abstractmethod
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
"""
load the data as pd.DataFrame.
@@ -78,6 +78,7 @@ class DLWParser(DataLoader):
<config> := <fields_info>
<fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
# NOTE: list and tuple will be treated as the same thing when parsing
"""
self.is_group = isinstance(config, dict)
@@ -87,18 +88,22 @@ class DLWParser(DataLoader):
self.fields = self._parse_fields_info(config)
def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
if isinstance(fields_info, list):
if len(fields_info) == 0:
raise ValueError("The size of fields must be greater than 0")
if not isinstance(fields_info, (list, tuple)):
raise TypeError("Unsupported type")
if isinstance(fields_info[0], str):
exprs = names = fields_info
elif isinstance(fields_info, tuple):
elif isinstance(fields_info[0], (list, tuple)):
exprs, names = fields_info
else:
raise NotImplementedError(f"This type of input is not supported")
return exprs, names
@abc.abstractmethod
def load_group_df(
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
) -> pd.DataFrame:
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -118,25 +123,25 @@ class DLWParser(DataLoader):
"""
pass
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
)
else:
exprs, names = self.fields
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
return df
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True, freq="day"):
"""
Parameters
----------
@@ -156,11 +161,10 @@ class QlibDataLoader(DLWParser):
self.filter_pipe = filter_pipe
self.swap_level = swap_level
self.freq = freq
super().__init__(config)
def load_group_df(
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
) -> pd.DataFrame:
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -169,7 +173,7 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time, freq)
df = D.features(instruments, exprs, start_time, end_time, self.freq)
df.columns = names
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
@@ -194,7 +198,7 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None:
df = self._data

View File

@@ -3,6 +3,7 @@
from contextlib import contextmanager
from .expm import MLflowExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
@@ -165,7 +166,7 @@ class QlibRecorder:
"""
return self.get_exp(experiment_id, experiment_name).list_recorders()
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will