mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
add highfreq example
This commit is contained in:
0
examples/high_freq/__init__.py
Normal file
0
examples/high_freq/__init__.py
Normal file
220
examples/high_freq/highfreq_handler.py
Normal file
220
examples/high_freq/highfreq_handler.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
|
||||
class HighFreqHandler(DataHandlerLP):
    """Data handler that builds intraday price/volume expressions for a ``QlibDataLoader``.

    The expressions rely on the custom operators ``Select``, ``FFillNan``,
    ``DayLast``, ``IsNull`` and ``Date``; they must be registered with qlib
    before this handler loads any data.
    """

    def __init__(
        self,
        instruments="csi500",
        start_time=None,
        end_time=None,
        freq="1min",
        infer_processors=[],
        learn_processors=[],
        fit_start_time=None,
        fit_end_time=None,
        drop_raw=True,
    ):
        """
        Parameters
        ----------
        instruments :
            instrument pool forwarded to ``DataHandlerLP``.
        start_time, end_time :
            time range of the data forwarded to ``DataHandlerLP``.
        freq : str
            data frequency forwarded to ``DataHandlerLP``.
        infer_processors, learn_processors : list
            processor configs; ``fit_start_time`` / ``fit_end_time`` are injected
            into the ``kwargs`` of every dict-style config.
        fit_start_time, fit_end_time :
            fitting window handed to the processors.
        drop_raw : bool
            forwarded to ``DataHandlerLP``.
        """

        def check_transform_proc(proc_l):
            """Return copies of the processor configs with the fit window injected."""
            new_l = []
            for p in proc_l:
                if not isinstance(p, dict):
                    # NOTE(review): presumably an already-built processor instance,
                    # which carries its own arguments -- pass it through untouched.
                    new_l.append(p)
                    continue
                # Copy instead of updating in place so the caller's config dict
                # (and the shared default list) is never mutated.
                kwargs = dict(p.get("kwargs", {}))
                kwargs.update(
                    {
                        "fit_start_time": fit_start_time,
                        "fit_end_time": fit_end_time,
                    }
                )
                new_l.append({**p, "kwargs": kwargs})
            return new_l

        # The processors supplied by the caller must be kept: resetting them to
        # empty lists here would silently disable every configured processor.
        infer_processors = check_transform_proc(infer_processors)
        learn_processors = check_transform_proc(learn_processors)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": self.get_feature_config(),
                "swap_level": False,
            },
        }
        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            freq=freq,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            drop_raw=drop_raw,
        )

    def get_feature_config(self):
        """Build the expression strings and the column names for the data loader.

        Returns
        -------
        tuple of (list of str, list of str)
            ``(fields, names)`` in this order: ``$open``, ``$high``, ``$low``,
            ``$close``, ``$vwap``; the same five wrapped in ``Ref(..., 240)``
            with a ``_1`` suffix; ``$volume``; ``$volume_1``; ``date``.
        """
        template_if = "If(IsNull({1}), {0}, {1})"
        template_paused = "Select(Eq($paused, 0.0), {0})"
        template_fillnan = "FFillNan({0})"
        # vwap falls back to close ({0}) when it is null, infinite, zero, or
        # outside the [0.999 * low, 1.001 * high] band.
        template_vwap = "If(IsNull({1}), {0}, If(Or(Or(Or(Eq({1}, np.inf), Eq({1}, -np.inf)), Eq({1}, 0)), Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2})))), {0}, {1}))"
        # volume becomes 0 when it is null or when vwap is outside the same band.
        # ({0} is not referenced; close is still passed to keep the positions aligned.)
        template_volume = "If(IsNull({1}), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))"

        def paused(field):
            """Wrap ``field`` in the ``$paused == 0`` selection."""
            return template_paused.format(field)

        close = template_fillnan.format(paused("$close"))
        volume = template_volume.format(
            close, paused("$volume"), paused("$vwap"), paused("$low"), paused("$high")
        )
        price_exprs = [
            ("$open", template_if.format(close, paused("$open"))),
            ("$high", template_if.format(close, paused("$high"))),
            ("$low", template_if.format(close, paused("$low"))),
            ("$close", close),
            ("$vwap", template_vwap.format(close, paused("$vwap"), paused("$low"), paused("$high"))),
        ]

        # Denominators shared by the current and the shifted variants.
        price_base = "Ref(DayLast({0}), 240)".format(close)
        volume_base = "Ref(DayLast(Mean({0}, 7200)), 240)".format(volume)

        fields = []
        names = []
        for name, expr in price_exprs:
            fields.append("{0}/{1}".format(expr, price_base))
            names.append(name)
        for name, expr in price_exprs:
            fields.append("Ref({0}, 240)/{1}".format(expr, price_base))
            names.append(name + "_1")

        fields.append("{0}/{1}".format(volume, volume_base))
        names.append("$volume")
        fields.append("Ref({0}, 240)/{1}".format(volume, volume_base))
        names.append("$volume_1")

        fields.append(paused("Date($close)"))
        names.append("date")
        return fields, names
|
||||
|
||||
|
||||
class HighFreqBacktestHandler(DataHandler):
    """Plain data handler exposing forward-filled close and cleaned volume columns.

    The expressions rely on the custom operators ``Select``, ``FFillNan`` and
    ``IsNull``; they must be registered with qlib before data is loaded.
    """

    def __init__(
        self,
        instruments="csi300",
        start_time=None,
        end_time=None,
        freq="1min",
    ):
        """
        Parameters
        ----------
        instruments :
            instrument pool forwarded to ``DataHandler``.
        start_time, end_time :
            time range of the data forwarded to ``DataHandler``.
        freq : str
            data frequency forwarded to ``DataHandler``.
        """
        # This handler takes no processors, so there is nothing to transform
        # here (the names used for that do not exist in this scope).
        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": self.get_feature_config(),
                "swap_level": False,
            },
        }
        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            freq=freq,
            data_loader=data_loader,
        )

    def get_feature_config(self):
        """Build the expression strings and the column names for the data loader.

        Returns
        -------
        tuple of (list of str, list of str)
            ``(fields, names)`` with the two columns ``$close0`` and ``$volume0``.
        """
        template_paused = "Select(Eq($paused, 0.0), {0})"
        template_fillnan = "FFillNan({0})"
        # NaN never compares equal to anything, so a missing volume has to be
        # detected with IsNull rather than Eq(..., np.nan).
        # ({0} is not referenced; close is still passed to keep the positions aligned.)
        template_volume = "If(IsNull({1}), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))"

        close = template_fillnan.format(template_paused.format("$close"))
        volume = template_volume.format(
            close,
            template_paused.format("$volume"),
            template_paused.format("$vwap"),
            template_paused.format("$low"),
            template_paused.format("$high"),
        )

        fields = [close, volume]
        names = ["$close0", "$volume0"]
        return fields, names
|
||||
62
examples/high_freq/highfreq_ops.py
Normal file
62
examples/high_freq/highfreq_ops.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import importlib
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import C
|
||||
from qlib.data.data import Cal
|
||||
|
||||
|
||||
class DayFirst(ElemOperator):
    """Replace every value by the first value of the calendar day it belongs to."""

    def __init__(self, feature):
        super().__init__(feature, "day_first")

    def _load_internal(self, instrument, start_index, end_index, freq):
        # Element 0 of the returned pair is the per-entry calendar date array.
        day_calendar = Cal.get_calender_day(freq=freq)[0]
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        # presumably the series index holds calendar positions -- TODO confirm
        day_keys = day_calendar[loaded.index]
        return loaded.groupby(day_keys).transform("first")
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
    """Replace every value by the last value of the calendar day it belongs to."""

    def __init__(self, feature):
        super().__init__(feature, "day_last")

    def _load_internal(self, instrument, start_index, end_index, freq):
        # Element 0 of the returned pair is the per-entry calendar date array.
        day_calendar = Cal.get_calender_day(freq=freq)[0]
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        # presumably the series index holds calendar positions -- TODO confirm
        day_keys = day_calendar[loaded.index]
        return loaded.groupby(day_keys).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
    """Forward-fill the missing values of the wrapped feature."""

    def __init__(self, feature):
        super().__init__(feature, "fill_nan")

    def _load_internal(self, instrument, start_index, end_index, freq):
        series = self.feature.load(instrument, start_index, end_index, freq)
        # ``fillna(method="ffill")`` is deprecated in pandas; ``ffill`` is the
        # direct equivalent.
        return series.ffill()
|
||||
|
||||
|
||||
class Date(ElemOperator):
    """Map every entry of the wrapped feature to its calendar date."""

    def __init__(self, feature):
        super().__init__(feature, "date")

    def _load_internal(self, instrument, start_index, end_index, freq):
        # Element 0 of the returned pair is the per-entry calendar date array.
        day_calendar = Cal.get_calender_day(freq=freq)[0]
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        positions = loaded.index
        # Only the index of the loaded feature is used; its values are discarded.
        return pd.Series(day_calendar[positions], index=positions)
|
||||
|
||||
class Select(PairOperator):
    """Index the right-hand feature with the loaded left-hand condition via ``.loc``."""

    def __init__(self, condition, feature):
        super().__init__(condition, feature, "select")

    def _load_internal(self, instrument, start_index, end_index, freq):
        mask = self.feature_left.load(instrument, start_index, end_index, freq)
        candidates = self.feature_right.load(instrument, start_index, end_index, freq)
        # presumably ``mask`` is a boolean series aligned with ``candidates`` -- TODO confirm
        return candidates.loc[mask]
|
||||
|
||||
class IsNull(ElemOperator):
    """Flag the missing entries of the wrapped feature."""

    def __init__(self, feature):
        super().__init__(feature, "isnull")

    def _load_internal(self, instrument, start_index, end_index, freq):
        loaded = self.feature.load(instrument, start_index, end_index, freq)
        missing = loaded.isnull()
        return missing
|
||||
70
examples/high_freq/highfreq_processor.py
Normal file
70
examples/high_freq/highfreq_processor.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
|
||||
|
||||
class HighFreqNorm(Processor):
    """Robustly normalise the price and volume columns and reshape them per day.

    ``fit`` learns, per column group, the median, a MAD-based scale and the
    extreme normalised values inside the fit window.  ``__call__`` applies the
    normalisation, compresses the tails into ``[-3.5, 3.5]`` and reshapes the
    rows into blocks of 240 rows.
    """

    # Positional column ranges of the two feature groups.
    # NOTE(review): assumes the column order produced by
    # ``HighFreqHandler.get_feature_config`` (10 price columns, then 2 volume
    # columns) -- confirm if the handler changes.
    _COLUMN_GROUPS = {
        "price": slice(0, 10),
        "volume": slice(10, 12),
    }

    def __init__(self, fit_start_time, fit_end_time):
        """
        Parameters
        ----------
        fit_start_time, fit_end_time :
            bounds of the ``datetime`` slice used by ``fit``.
        """
        self.fit_start_time = fit_start_time
        self.fit_end_time = fit_end_time

    def fit(self, df_features):
        """Learn the normalisation statistics from the fit window of ``df_features``.

        Parameters
        ----------
        df_features : pd.DataFrame
            features with a ``datetime`` index level; it is not modified.
        """
        fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
        df_values = fetch_df.values
        self.feature_med = {}
        self.feature_std = {}
        self.feature_vmax = {}
        self.feature_vmin = {}
        for name, name_val in self._COLUMN_GROUPS.items():
            # ``__call__`` expects a "date" column, which presumably is still
            # part of the frame here and makes ``.values`` non-numeric; the
            # explicit cast also yields a private copy to work on.
            part_values = df_values[:, name_val].astype(np.float32)
            if name == "volume":
                part_values = np.log1p(part_values)
            self.feature_med[name] = np.nanmedian(part_values)
            part_values = part_values - self.feature_med[name]
            # 1.4826 * MAD estimates the standard deviation; the epsilon guards
            # against a zero scale.
            self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
            part_values = part_values / self.feature_std[name]
            self.feature_vmax[name] = np.nanmax(part_values)
            self.feature_vmin[name] = np.nanmin(part_values)

    def __call__(self, df_features):
        """Normalise ``df_features`` and return one row per 240-row block.

        Parameters
        ----------
        df_features : pd.DataFrame
            features including a ``date`` column; the column is moved into the
            index in place.

        Returns
        -------
        pd.DataFrame
            ``12 * 240`` columns named ``FEATURE_<i>``, indexed by the input
            index without its ``datetime`` level, sorted by index.
        """
        df_features.set_index("date", append=True, drop=True, inplace=True)
        df_values = df_features.values

        for name, name_val in self._COLUMN_GROUPS.items():
            # Basic slicing returns a view, so the in-place operations below
            # write straight into ``df_values``.
            part_values = df_values[:, name_val]
            if name == "volume":
                part_values[:] = np.log1p(part_values)
            part_values -= self.feature_med[name]
            part_values /= self.feature_std[name]

            # All masks are taken before any value is rewritten.
            slice0 = part_values > 3.0
            slice1 = part_values > 3.5
            slice2 = part_values < -3.0
            slice3 = part_values < -3.5

            # Tails beyond +/-3 are squeezed linearly towards +/-3.5; values
            # that were already beyond +/-3.5 are then pinned to the bound.
            part_values[slice0] = 3.0 + (part_values[slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
            part_values[slice1] = 3.5
            part_values[slice2] = -3.0 - (part_values[slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
            part_values[slice3] = -3.5

        idx = df_features.index.droplevel("datetime").drop_duplicates()
        # Columns 0-4 and 10 versus columns 5-9 and 11, each flattened over
        # blocks of 240 consecutive rows.
        feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
        feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)

        df_new_features = pd.DataFrame(
            data=np.concatenate((feat, feat_1), axis=1),
            index=idx,
            columns=["FEATURE_%d" % i for i in range(12 * 240)],
        ).sort_index()

        return df_new_features
|
||||
137
examples/high_freq/workflow.py
Normal file
137
examples/high_freq/workflow.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
|
||||
from highfreq_ops import DayFirst, DayLast, FFillNan, Date, Select, IsNull
|
||||
|
||||
def save_dataset(dataset, path: [Path, str]):
    """
    Persist ``dataset`` by delegating to its ``to_pickle`` method.

    Parameters
    ----------
    dataset :
        object exposing ``to_pickle(path=...)``
    path : [Path, str]
        destination handed to ``to_pickle``
    """
    destination = path
    dataset.to_pickle(path=destination)
|
||||
|
||||
def load_dataset(path: [Path, str], init_type=DataHandlerLP.IT_LS):
    """
    Load a pickled dataset from ``path`` and initialise it.

    Parameters
    ----------
    path : [Path, str]
        path of the pickle file to load

    init_type : str
        - if `init_type` == DataHandlerLP.IT_FIT_SEQ:

            the input of `DataHandlerLP.fit` will be the output of the previous processor

        - if `init_type` == DataHandlerLP.IT_FIT_IND:

            the input of `DataHandlerLP.fit` will be the original df

        - if `init_type` == DataHandlerLP.IT_LS:

            the state of the object has been loaded by pickle

    Returns
    -------
    the unpickled dataset, after ``dataset.init(init_type=init_type)`` was called
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # The context manager closes the file even when unpickling fails.
    with open(path, "rb") as fd:
        dataset = pickle.load(fd)
    dataset.init(init_type=init_type)
    return dataset
|
||||
|
||||
if __name__ == "__main__":

    # use default data
    provider_uri = "/mnt/v-xiabi/data/qlib/high_freq"  # target_dir
    # The custom operators must be registered here because the handlers'
    # expression strings refer to them by name.
    qlib.init(
        provider_uri=provider_uri,
        custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull],
        redis_port=233,
        region=REG_CN,
        auto_mount=False,
    )

    MARKET = "csi300"
    BENCHMARK = "SH000300"

    ###################################
    # train model
    ###################################
    DATA_HANDLER_CONFIG0 = {
        "start_time": "2017-01-01 00:00:00",
        "end_time": "2020-11-30 15:00:00",
        "freq": "1min",
        "fit_start_time": "2017-01-01 00:00:00",
        "fit_end_time": "2020-08-31 15:00:00",
        "instruments": "all",
        "infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
    }
    DATA_HANDLER_CONFIG1 = {
        "start_time": "2017-01-01 00:00:00",
        "end_time": "2020-11-30 15:00:00",
        "freq": "1min",
        "instruments": "all",
    }

    # Both datasets use the same train/test split; each config gets its own copy.
    SEGMENTS = {
        "train": ("2017-01-01 00:00:00", "2020-08-31 15:00:00"),
        "test": (
            "2020-09-01 00:00:00",
            "2020-11-30 15:00:00",
        ),
    }

    task = {
        "dataset": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "HighFreqHandler",
                    "module_path": "highfreq_handler",
                    "kwargs": DATA_HANDLER_CONFIG0,
                },
                "segments": dict(SEGMENTS),
            },
        },
        # You should record the data in specific sequence
        # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
        "dataset_backtest": {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "HighFreqBacktestHandler",
                    # Both handlers live in highfreq_handler.py.
                    "module_path": "highfreq_handler",
                    "kwargs": DATA_HANDLER_CONFIG1,
                },
                "segments": dict(SEGMENTS),
            },
        },
    }
    Cal.get_calender_day(freq="1min")  # TO FIX: load the calendar day for cache
    dataset = init_instance_by_config(task["dataset"])
    dataset_backtest = init_instance_by_config(task["dataset_backtest"])
|
||||
|
||||
@@ -49,6 +49,7 @@ class Alpha360(DataHandlerLP):
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
infer_processors=_DEFAULT_INFER_PROCESSORS,
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
@@ -69,9 +70,10 @@ class Alpha360(DataHandlerLP):
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
freq="day",
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
@@ -130,6 +132,7 @@ class Alpha158(DataHandlerLP):
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
infer_processors=[],
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
@@ -147,9 +150,10 @@ class Alpha158(DataHandlerLP):
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
freq=freq,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
|
||||
@@ -123,6 +123,16 @@ class CalendarProvider(abc.ABC):
|
||||
H["c"][flag] = _calendar, _calendar_index
|
||||
return _calendar, _calendar_index
|
||||
|
||||
def get_calender_day(self, freq="day", future=False):
    """Return the calendar reduced to dates, plus a date -> position lookup.

    The pair is cached in ``H["c"]`` under a key built from ``freq`` and
    ``future``.

    Returns
    -------
    tuple
        ``(dates, lookup)``: a numpy array with the ``.date()`` of every
        calendar entry, and a dict from date to position in that array (when a
        date occurs more than once, the last position wins).
    """
    cache_key = f"{freq}_future_{future}_day"
    if cache_key in H["c"]:
        dates, lookup = H["c"][cache_key]
        return dates, lookup

    dates = np.array([stamp.date() for stamp in self._load_calendar(freq, future)])
    lookup = {day: pos for pos, day in enumerate(dates)}  # for fast search
    H["c"][cache_key] = dates, lookup
    return dates, lookup
|
||||
|
||||
def _uri(self, start_time, end_time, freq, future=False):
    """Return the hash of the arguments, used as the uri of the calendar generation task."""
    task_uri = hash_args(start_time, end_time, freq, future)
    return task_uri
|
||||
@@ -686,7 +696,10 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
# 2) The the precision should be configurable
|
||||
try:
|
||||
series = series.astype(np.float32)
|
||||
if series.dtype == np.float64:
|
||||
series = series.astype(np.float32)
|
||||
elif series.dtype == np.bool:
|
||||
series = series.astype(np.int8)
|
||||
except ValueError:
|
||||
pass
|
||||
if not series.empty:
|
||||
|
||||
@@ -87,6 +87,36 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
|
||||
def init(self, init_type: str = DataHandlerLP.IT_FIT_SEQ, enable_cache: bool = False):
    """
    Initialize the dataset by forwarding both options to ``self.handler.init``.

    Parameters
    ----------
    init_type : str
        - if `init_type` == DataHandlerLP.IT_FIT_SEQ:

            the input of `DataHandlerLP.fit` will be the output of the previous processor

        - if `init_type` == DataHandlerLP.IT_FIT_IND:

            the input of `DataHandlerLP.fit` will be the original df

        - if `init_type` == DataHandlerLP.IT_LS:

            the state of the object has been loaded by pickle

    enable_cache : bool
        defaults to False:

        - if `enable_cache` == True:

            the processed data will be saved on disk, and the handler will load the
            cached data from disk directly the next time `init` is called
    """
    handler_options = {"init_type": init_type, "enable_cache": enable_cache}
    self.handler.init(**handler_options)
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
@@ -116,8 +146,8 @@ class DatasetH(Dataset):
|
||||
'outsample': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
"""
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments.copy()
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
@@ -127,7 +157,7 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self._handler.fetch(slc, **kwargs)
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@@ -150,7 +180,7 @@ class DatasetH(Dataset):
|
||||
- ['train', 'valid']
|
||||
|
||||
col_set : str
|
||||
The col_set will be passed to self._handler when fetching data.
|
||||
The col_set will be passed to self.handler when fetching data.
|
||||
data_key : str
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
@@ -166,16 +196,16 @@ class DatasetH(Dataset):
|
||||
logger = get_module_logger("DatasetH")
|
||||
fetch_kwargs = {"col_set": col_set}
|
||||
fetch_kwargs.update(kwargs)
|
||||
if "data_key" in getfullargspec(self._handler.fetch).args:
|
||||
if "data_key" in getfullargspec(self.handler.fetch).args:
|
||||
fetch_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Handle all kinds of segments format
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs)
|
||||
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
|
||||
elif isinstance(segments, slice):
|
||||
return self._prepare_seg(segments, **fetch_kwargs)
|
||||
else:
|
||||
@@ -409,7 +439,7 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
@@ -57,6 +57,7 @@ class DataHandler(Serializable):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
fetch_orig=True,
|
||||
@@ -70,6 +71,8 @@ class DataHandler(Serializable):
|
||||
start_time of the original data.
|
||||
end_time :
|
||||
end_time of the original data.
|
||||
freq :
|
||||
frequency of data
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
@@ -92,6 +95,7 @@ class DataHandler(Serializable):
|
||||
self.instruments = instruments
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
self.freq = freq
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
@@ -119,7 +123,7 @@ class DataHandler(Serializable):
|
||||
# Setup data.
|
||||
# _data may be with multiple column index level. The outer level indicates the feature set name
|
||||
with TimeInspector.logt("Loading data"):
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
|
||||
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
|
||||
# TODO: cache
|
||||
|
||||
CS_ALL = "__all" # return all columns with single-level index column
|
||||
@@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
drop_raw=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler):
|
||||
- self._learn will be processed by infer_processors + learn_processors
|
||||
|
||||
- (e.g. self._infer processed by learn_processors )
|
||||
drop_raw: bool
|
||||
Whether to drop the raw data
|
||||
"""
|
||||
|
||||
# Setup preprocessor
|
||||
@@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler):
|
||||
)
|
||||
|
||||
self.process_type = process_type
|
||||
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
|
||||
self.drop_raw = drop_raw
|
||||
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
|
||||
|
||||
def get_all_processors(self):
    """Return the inference processors followed by the learning processors."""
    for_inference = self.infer_processors
    for_learning = self.learn_processors
    return for_inference + for_learning
|
||||
@@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
# data for inference
|
||||
_infer_df = self._data
|
||||
if len(self.infer_processors) > 0: # avoid modifying the original data
|
||||
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
|
||||
_infer_df = _infer_df.copy()
|
||||
|
||||
for proc in self.infer_processors:
|
||||
@@ -378,6 +387,8 @@ class DataHandlerLP(DataHandler):
|
||||
_learn_df = proc(_learn_df)
|
||||
self._learn = _learn_df
|
||||
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
@@ -416,7 +427,11 @@ class DataHandlerLP(DataHandler):
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
try:
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
except AttributeError:
|
||||
print("please set drop_raw = False if you want to use raw data")
|
||||
raise
|
||||
return df
|
||||
|
||||
def fetch(
|
||||
|
||||
@@ -19,7 +19,7 @@ class DataLoader(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
"""
|
||||
load the data as pd.DataFrame.
|
||||
|
||||
@@ -94,7 +94,7 @@ class DLWParser(DataLoader):
|
||||
return exprs, names
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
"""
|
||||
load the dataframe for specific group
|
||||
|
||||
@@ -114,25 +114,25 @@ class DLWParser(DataLoader):
|
||||
"""
|
||||
pass
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
df = pd.concat(
|
||||
{
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
|
||||
for grp, (exprs, names) in self.fields.items()
|
||||
},
|
||||
axis=1,
|
||||
)
|
||||
else:
|
||||
exprs, names = self.fields
|
||||
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
|
||||
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
|
||||
return df
|
||||
|
||||
|
||||
class QlibDataLoader(DLWParser):
|
||||
"""Same as QlibDataLoader. The fields can be define by config"""
|
||||
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
|
||||
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -140,11 +140,15 @@ class QlibDataLoader(DLWParser):
|
||||
Please refer to the doc of DLWParser
|
||||
filter_pipe :
|
||||
Filter pipe for the instruments
|
||||
swap_level :
|
||||
Whether to swap level of MultiIndex
|
||||
"""
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
print("swap level", swap_level)
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
@@ -153,9 +157,10 @@ class QlibDataLoader(DLWParser):
|
||||
elif self.filter_pipe is not None:
|
||||
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
|
||||
|
||||
df = D.features(instruments, exprs, start_time, end_time)
|
||||
df = D.features(instruments, exprs, start_time, end_time, freq)
|
||||
df.columns = names
|
||||
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
|
||||
if self.swap_level:
|
||||
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
|
||||
return df
|
||||
|
||||
|
||||
@@ -177,7 +182,7 @@ class StaticDataLoader(DataLoader):
|
||||
self.join = join
|
||||
self._data = None
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
|
||||
self._maybe_load_raw_data()
|
||||
if instruments is None:
|
||||
df = self._data
|
||||
|
||||
Reference in New Issue
Block a user