1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

Merge pull request #222 from bxdd/rl-highfreq-include-examples

Qlib Highfreq Support & Highfreq DataHanlder/Operator/Processor Examples
This commit is contained in:
you-n-g
2021-01-29 00:08:10 +08:00
committed by GitHub
16 changed files with 642 additions and 93 deletions

View File

View File

@@ -0,0 +1,174 @@
from qlib.data.dataset.handler import DataHandler, DataHandlerLP
from qlib.data.dataset.processor import Processor
from qlib.utils import get_cls_kwargs
from qlib.log import TimeInspector
class HighFreqHandler(DataHandlerLP):
def __init__(
self,
instruments="csi300",
start_time=None,
end_time=None,
freq="1min",
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
drop_raw=True,
):
def check_transform_proc(proc_l):
new_l = []
for p in proc_l:
p["kwargs"].update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append(p)
return new_l
infer_processors = check_transform_proc(infer_processors)
learn_processors = check_transform_proc(learn_processors)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq=freq,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
drop_raw=drop_raw,
)
def get_feature_config(self):
fields = []
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
template_fillnan = "BFillNan(FFillNan({0}))"
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
def get_normalized_price_feature(price_field, shift=0):
"""Get normalized price feature ops"""
if shift == 0:
template_norm = "{0}/Ref(DayLast({1}), 240)"
else:
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
feature_ops = template_norm.format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(price_field),
),
template_fillnan.format(template_paused.format("$close")),
)
return feature_ops
fields += [get_normalized_price_feature("$open", 0)]
fields += [get_normalized_price_feature("$high", 0)]
fields += [get_normalized_price_feature("$low", 0)]
fields += [get_normalized_price_feature("$close", 0)]
fields += [get_normalized_price_feature(simpson_vwap, 0)]
names += ["$open", "$high", "$low", "$close", "$vwap"]
fields += [get_normalized_price_feature("$open", 240)]
fields += [get_normalized_price_feature("$high", 240)]
fields += [get_normalized_price_feature("$low", 240)]
fields += [get_normalized_price_feature("$close", 240)]
fields += [get_normalized_price_feature(simpson_vwap, 240)]
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
fields += [
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
)
]
names += ["$volume"]
fields += [
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
)
]
names += ["$volume_1"]
fields += [template_paused.format("Date($close)")]
names += ["date"]
return fields, names
class HighFreqBacktestHandler(DataHandler):
def __init__(
self,
instruments="csi300",
start_time=None,
end_time=None,
freq="1min",
):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": self.get_feature_config(),
"swap_level": False,
},
}
super().__init__(
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq=freq,
data_loader=data_loader,
)
def get_feature_config(self):
fields = []
names = []
template_if = "If(IsNull({1}), {0}, {1})"
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
template_fillnan = "BFillNan(FFillNan({0}))"
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
fields += [
template_fillnan.format(template_paused.format("$close")),
]
names += ["$close0"]
fields += [
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(simpson_vwap),
)
]
names += ["$vwap0"]
fields += [
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
template_paused.format("$volume"),
template_paused.format(simpson_vwap),
template_paused.format("$low"),
template_paused.format("$high"),
)
]
names += ["$volume0"]
return fields, names

View File

@@ -0,0 +1,56 @@
import numpy as np
import pandas as pd
import importlib
from qlib.data.ops import ElemOperator, PairOperator
from qlib.config import C
from qlib.data.cache import H
from qlib.data.data import Cal
def get_calendar_day(freq="day", future=False):
flag = f"{freq}_future_{future}_day"
if flag in H["c"]:
_calendar = H["c"][flag]
else:
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
H["c"][flag] = _calendar
return _calendar
class DayLast(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
return series.groupby(_calendar[series.index]).transform("last")
class FFillNan(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="ffill")
class BFillNan(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="bfill")
class Date(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
return pd.Series(_calendar[series.index], index=series.index)
class Select(PairOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
return series_feature.loc[series_condition]
class IsNull(ElemOperator):
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.isnull()

View File

@@ -0,0 +1,72 @@
import numpy as np
import pandas as pd
from qlib.data.dataset.processor import Processor
from qlib.data.dataset.utils import fetch_df_by_index
class HighFreqNorm(Processor):
def __init__(self, fit_start_time, fit_end_time):
self.fit_start_time = fit_start_time
self.fit_end_time = fit_end_time
def fit(self, df_features):
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
del df_features
df_values = fetch_df.values
names = {
"price": slice(0, 10),
"volume": slice(10, 12),
}
self.feature_med = {}
self.feature_std = {}
self.feature_vmax = {}
self.feature_vmin = {}
for name, name_val in names.items():
part_values = df_values[:, name_val].astype(np.float32)
if name == "volume":
part_values = np.log1p(part_values)
self.feature_med[name] = np.nanmedian(part_values)
part_values = part_values - self.feature_med[name]
self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
part_values = part_values / self.feature_std[name]
self.feature_vmax[name] = np.nanmax(part_values)
self.feature_vmin[name] = np.nanmin(part_values)
def __call__(self, df_features):
df_features.set_index("date", append=True, drop=True, inplace=True)
df_values = df_features.values
names = {
"price": slice(0, 10),
"volume": slice(10, 12),
}
for name, name_val in names.items():
if name == "volume":
df_values[:, name_val] = np.log1p(df_values[:, name_val])
df_values[:, name_val] -= self.feature_med[name]
df_values[:, name_val] /= self.feature_std[name]
slice0 = df_values[:, name_val] > 3.0
slice1 = df_values[:, name_val] > 3.5
slice2 = df_values[:, name_val] < -3.0
slice3 = df_values[:, name_val] < -3.5
df_values[:, name_val][slice0] = (
3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
)
df_values[:, name_val][slice1] = 3.5
df_values[:, name_val][slice2] = (
-3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
)
df_values[:, name_val][slice3] = -3.5
idx = df_features.index.droplevel("datetime").drop_duplicates()
idx.set_names(["instrument", "datetime"], inplace=True)
# Reshape is specifically for adapting to RL high-freq executor
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
df_new_features = pd.DataFrame(
data=np.concatenate((feat, feat_1), axis=1),
index=idx,
columns=["FEATURE_%d" % i for i in range(12 * 240)],
).sort_index()
return df_new_features

View File

@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import fire
from pathlib import Path
import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import init_instance_by_config, exists_qlib_data
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
class HighfreqWorkflow(object):
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
MARKET = "all"
BENCHMARK = "SH000300"
start_time = "2020-09-14 00:00:00"
end_time = "2021-01-18 16:00:00"
train_end_time = "2020-11-30 16:00:00"
test_start_time = "2020-12-01 00:00:00"
DATA_HANDLER_CONFIG0 = {
"start_time": start_time,
"end_time": end_time,
"freq": "1min",
"fit_start_time": start_time,
"fit_end_time": train_end_time,
"instruments": MARKET,
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
}
DATA_HANDLER_CONFIG1 = {
"start_time": start_time,
"end_time": end_time,
"freq": "1min",
"instruments": MARKET,
}
task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "HighFreqHandler",
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG0,
},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
"dataset_backtest": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "HighFreqBacktestHandler",
"module_path": "highfreq_handler",
"kwargs": DATA_HANDLER_CONFIG1,
},
"segments": {
"train": (start_time, train_end_time),
"test": (
test_start_time,
end_time,
),
},
},
},
}
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
qlib.init(**QLIB_INIT_CONFIG)
def _prepare_calender_cache(self):
"""preload the calendar for cache"""
# This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess
# This code may accelerate, but may be not useful on Windows and Mac Os
Cal.calendar(freq="1min")
get_calendar_day(freq="1min")
def get_data(self):
"""use dataset to get highreq data"""
self._init_qlib()
self._prepare_calender_cache()
dataset = init_instance_by_config(self.task["dataset"])
xtrain, xtest = dataset.prepare(["train", "test"])
print(xtrain, xtest)
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
def dump_and_load_dataset(self):
"""dump and load dataset state on disk"""
self._init_qlib()
self._prepare_calender_cache()
dataset = init_instance_by_config(self.task["dataset"])
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
##=============dump dataset=============
dataset.to_pickle(path="dataset.pkl")
dataset_backtest.to_pickle(path="dataset_backtest.pkl")
del dataset, dataset_backtest
##=============reload dataset=============
with open("dataset.pkl", "rb") as file_dataset:
dataset = pickle.load(file_dataset)
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
dataset_backtest = pickle.load(file_dataset_backtest)
self._prepare_calender_cache()
##=============reload_dataset=============
dataset.init(init_type=DataHandlerLP.IT_LS)
dataset_backtest.init()
##=============get data=============
xtrain, xtest = dataset.prepare(["train", "test"])
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
print(xtrain, xtest)
print(backtest_train, backtest_test)
del xtrain, xtest
del backtest_train, backtest_test
if __name__ == "__main__":
fire.Fire(HighfreqWorkflow)

View File

@@ -17,7 +17,7 @@ from qlib.contrib.evaluate import (
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
@@ -25,9 +25,6 @@ if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -193,6 +193,12 @@ MODE_CONF = {
},
}
HIGH_FREQ_CONFIG = {
"provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min",
"dataset_cache": None,
"expression_cache": "DiskExpressionCache",
"region": REG_CN,
}
_default_region_config = {
REG_CN: {
@@ -291,12 +297,12 @@ class QlibConfig(Config):
def register(self):
from .utils import init_instance_by_config
from .data.ops import register_custom_ops
from .data.ops import register_all_ops
from .data.data import register_all_wrappers
from .workflow import R, QlibRecorder
from .workflow.utils import experiment_exit_handler
register_custom_ops(self)
register_all_ops(self)
register_all_wrappers(self)
# set up QlibRecorder
exp_manager = init_instance_by_config(self["exp_manager"])

View File

@@ -49,6 +49,7 @@ class Alpha360(DataHandlerLP):
instruments="csi500",
start_time=None,
end_time=None,
freq="day",
infer_processors=_DEFAULT_INFER_PROCESSORS,
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
@@ -69,9 +70,10 @@ class Alpha360(DataHandlerLP):
}
super().__init__(
instruments,
start_time,
end_time,
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq="day",
data_loader=data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
@@ -130,6 +132,7 @@ class Alpha158(DataHandlerLP):
instruments="csi500",
start_time=None,
end_time=None,
freq="day",
infer_processors=[],
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
@@ -147,9 +150,10 @@ class Alpha158(DataHandlerLP):
},
}
super().__init__(
instruments,
start_time,
end_time,
instruments=instruments,
start_time=start_time,
end_time=end_time,
freq=freq,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,

View File

@@ -157,7 +157,7 @@ class Expression(abc.ABC):
@abc.abstractmethod
def _load_internal(self, instrument, start_index, end_index, freq):
pass
raise NotImplementedError("This function must be implemented in your newly defined feature")
@abc.abstractmethod
def get_longest_back_rolling(self):

View File

@@ -117,7 +117,7 @@ class CalendarProvider(abc.ABC):
if flag in H["c"]:
_calendar, _calendar_index = H["c"][flag]
else:
_calendar = np.array(self._load_calendar(freq, future))
_calendar = np.array(self.load_calendar(freq, future))
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
H["c"][flag] = _calendar, _calendar_index
return _calendar, _calendar_index
@@ -504,7 +504,7 @@ class LocalCalendarProvider(CalendarProvider):
"""Calendar file uri."""
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def _load_calendar(self, freq, future):
def load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
Parameters
@@ -672,6 +672,8 @@ class LocalExpressionProvider(ExpressionProvider):
series = series.astype(np.float32)
except ValueError:
pass
except TypeError:
pass
if not series.empty:
series = series.loc[start_index:end_index]
return series

View File

@@ -87,6 +87,10 @@ class DatasetH(Dataset):
"""
super().__init__(handler, segments)
def init(self, **kwargs):
"""Initialize the DatasetH, Only parameters belonging to handler.init will be passed in"""
self.handler.init(**kwargs)
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
"""
Setup the underlying data.
@@ -116,8 +120,8 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
self._segments = segments.copy()
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
def _prepare_seg(self, slc: slice, **kwargs):
"""
@@ -127,7 +131,7 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self._handler.fetch(slc, **kwargs)
return self.handler.fetch(slc, **kwargs)
def prepare(
self,
@@ -150,7 +154,7 @@ class DatasetH(Dataset):
- ['train', 'valid']
col_set : str
The col_set will be passed to self._handler when fetching data.
The col_set will be passed to self.handler when fetching data.
data_key : str
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
@@ -166,16 +170,16 @@ class DatasetH(Dataset):
logger = get_module_logger("DatasetH")
fetch_kwargs = {"col_set": col_set}
fetch_kwargs.update(kwargs)
if "data_key" in getfullargspec(self._handler.fetch).args:
if "data_key" in getfullargspec(self.handler.fetch).args:
fetch_kwargs["data_key"] = data_key
else:
logger.info(f"data_key[{data_key}] is ignored.")
# Handle all kinds of segments format
if isinstance(segments, (list, tuple)):
return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
elif isinstance(segments, str):
return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs)
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
elif isinstance(segments, slice):
return self._prepare_seg(segments, **fetch_kwargs)
else:
@@ -409,7 +413,7 @@ class TSDatasetH(DatasetH):
def setup_data(self, *args, **kwargs):
super().setup_data(*args, **kwargs)
cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique()
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
# Get the datatime index for building timestamp
self.cal = cal

View File

@@ -57,6 +57,7 @@ class DataHandler(Serializable):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
@@ -70,6 +71,8 @@ class DataHandler(Serializable):
start_time of the original data.
end_time :
end_time of the original data.
freq :
frequency of data
data_loader : Tuple[dict, str, DataLoader]
data loader to load the data.
init_data :
@@ -92,6 +95,7 @@ class DataHandler(Serializable):
self.instruments = instruments
self.start_time = start_time
self.end_time = end_time
self.freq = freq
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
@@ -119,7 +123,7 @@ class DataHandler(Serializable):
# Setup data.
# _data may be with multiple column index level. The outer level indicates the feature set name
with TimeInspector.logt("Loading data"):
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time, self.freq)
# TODO: cache
CS_ALL = "__all" # return all columns with single-level index column
@@ -258,10 +262,12 @@ class DataHandlerLP(DataHandler):
instruments=None,
start_time=None,
end_time=None,
freq="day",
data_loader: Tuple[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
drop_raw=False,
**kwargs,
):
"""
@@ -303,6 +309,8 @@ class DataHandlerLP(DataHandler):
- self._learn will be processed by infer_processors + learn_processors
- (e.g. self._infer processed by learn_processors )
drop_raw: bool
Whether to drop the raw data
"""
# Setup preprocessor
@@ -319,7 +327,8 @@ class DataHandlerLP(DataHandler):
)
self.process_type = process_type
super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
self.drop_raw = drop_raw
super().__init__(instruments, start_time, end_time, freq, data_loader, **kwargs)
def get_all_processors(self):
return self.infer_processors + self.learn_processors
@@ -348,7 +357,7 @@ class DataHandlerLP(DataHandler):
"""
# data for inference
_infer_df = self._data
if len(self.infer_processors) > 0: # avoid modifying the original data
if len(self.infer_processors) > 0 and not self.drop_raw: # avoid modifying the original data
_infer_df = _infer_df.copy()
for proc in self.infer_processors:
@@ -378,6 +387,9 @@ class DataHandlerLP(DataHandler):
_learn_df = proc(_learn_df)
self._learn = _learn_df
if self.drop_raw:
del self._data
# init type
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
@@ -416,6 +428,10 @@ class DataHandlerLP(DataHandler):
# TODO: Be able to cache handler data. Save the memory for data processing
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
if data_key == self.DK_R and self.drop_raw:
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
)
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
return df

View File

@@ -19,7 +19,7 @@ class DataLoader(abc.ABC):
"""
@abc.abstractmethod
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
"""
load the data as pd.DataFrame.
@@ -94,7 +94,9 @@ class DLWParser(DataLoader):
return exprs, names
@abc.abstractmethod
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
) -> pd.DataFrame:
"""
load the dataframe for specific group
@@ -114,25 +116,25 @@ class DLWParser(DataLoader):
"""
pass
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
if self.is_group:
df = pd.concat(
{
grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
grp: self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
for grp, (exprs, names) in self.fields.items()
},
axis=1,
)
else:
exprs, names = self.fields
df = self.load_group_df(instruments, exprs, names, start_time, end_time)
df = self.load_group_df(instruments, exprs, names, start_time, end_time, freq)
return df
class QlibDataLoader(DLWParser):
"""Same as QlibDataLoader. The fields can be define by config"""
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None, swap_level=True):
"""
Parameters
----------
@@ -140,11 +142,16 @@ class QlibDataLoader(DLWParser):
Please refer to the doc of DLWParser
filter_pipe :
Filter pipe for the instruments
swap_level :
Whether to swap level of MultiIndex
"""
self.filter_pipe = filter_pipe
self.swap_level = swap_level
super().__init__(config)
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
def load_group_df(
self, instruments, exprs: list, names: list, start_time=None, end_time=None, freq="day"
) -> pd.DataFrame:
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
@@ -153,9 +160,10 @@ class QlibDataLoader(DLWParser):
elif self.filter_pipe is not None:
warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
df = D.features(instruments, exprs, start_time, end_time)
df = D.features(instruments, exprs, start_time, end_time, freq)
df.columns = names
df = df.swaplevel().sort_index() # NOTE: always return <datetime, instrument>
if self.swap_level:
df = df.swaplevel().sort_index() # NOTE: if swaplevel, return <datetime, instrument>
return df
@@ -177,7 +185,7 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
def load(self, instruments=None, start_time=None, end_time=None, freq="day") -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None:
df = self._data

View File

@@ -6,6 +6,7 @@ from __future__ import division
from __future__ import print_function
import sys
import abc
import numpy as np
import pandas as pd
@@ -17,7 +18,7 @@ from ..log import get_module_logger
try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
except ImportError as err:
except ImportError:
print(
"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####"
)
@@ -32,12 +33,39 @@ np.seterr(invalid="ignore")
class ElemOperator(ExpressionOps):
"""Element-wise Operator
Parameters
----------
feature : Expression
feature instance
Returns
----------
Expression
feature operation output
"""
def __init__(self, feature):
self.feature = feature
def __str__(self):
return "{}({})".format(type(self).__name__, self.feature)
def get_longest_back_rolling(self):
return self.feature.get_longest_back_rolling()
def get_extended_window_size(self):
return self.feature.get_extended_window_size()
class NpElemOperator(ElemOperator):
"""Numpy Element-wise Operator
Parameters
----------
feature : Expression
feature instance
func : str
feature operation method
numpy feature operation method
Returns
----------
@@ -48,22 +76,14 @@ class ElemOperator(ExpressionOps):
def __init__(self, feature, func):
self.feature = feature
self.func = func
def __str__(self):
return "{}({})".format(type(self).__name__, self.feature)
super(NpElemOperator, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return getattr(np, self.func)(series)
def get_longest_back_rolling(self):
return self.feature.get_longest_back_rolling()
def get_extended_window_size(self):
return self.feature.get_extended_window_size()
class Abs(ElemOperator):
class Abs(NpElemOperator):
"""Feature Absolute Value
Parameters
@@ -81,7 +101,7 @@ class Abs(ElemOperator):
super(Abs, self).__init__(feature, "abs")
class Sign(ElemOperator):
class Sign(NpElemOperator):
"""Feature Sign
Parameters
@@ -108,7 +128,7 @@ class Sign(ElemOperator):
return getattr(np, self.func)(series)
class Log(ElemOperator):
class Log(NpElemOperator):
"""Feature Log
Parameters
@@ -126,7 +146,7 @@ class Log(ElemOperator):
super(Log, self).__init__(feature, "log")
class Power(ElemOperator):
class Power(NpElemOperator):
"""Feature Power
Parameters
@@ -152,7 +172,7 @@ class Power(ElemOperator):
return getattr(np, self.func)(series, self.exponent)
class Mask(ElemOperator):
class Mask(NpElemOperator):
"""Feature Mask
Parameters
@@ -179,7 +199,7 @@ class Mask(ElemOperator):
return self.feature.load(self.instrument, start_index, end_index, freq)
class Not(ElemOperator):
class Not(NpElemOperator):
"""Not Operator
Parameters
@@ -218,28 +238,13 @@ class PairOperator(ExpressionOps):
two features' operation output
"""
def __init__(self, feature_left, feature_right, func):
def __init__(self, feature_left, feature_right):
self.feature_left = feature_left
self.feature_right = feature_right
self.func = func
def __str__(self):
return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right)
def _load_internal(self, instrument, start_index, end_index, freq):
assert any(
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
), "at least one of two inputs is Expression instance"
if isinstance(self.feature_left, Expression):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
else:
series_left = self.feature_left # numeric value
if isinstance(self.feature_right, Expression):
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
else:
series_right = self.feature_right
return getattr(np, self.func)(series_left, series_right)
def get_longest_back_rolling(self):
if isinstance(self.feature_left, Expression):
left_br = self.feature_left.get_longest_back_rolling()
@@ -265,7 +270,46 @@ class PairOperator(ExpressionOps):
return max(ll, rl), max(lr, rr)
class Add(PairOperator):
class NpPairOperator(PairOperator):
"""Numpy Pair-wise operator
Parameters
----------
feature_left : Expression
feature instance or numeric value
feature_right : Expression
feature instance or numeric value
func : str
operator function
Returns
----------
Feature:
two features' operation output
"""
def __init__(self, feature_left, feature_right, func):
self.feature_left = feature_left
self.feature_right = feature_right
self.func = func
super(NpPairOperator, self).__init__(feature_left, feature_right)
def _load_internal(self, instrument, start_index, end_index, freq):
assert any(
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
), "at least one of two inputs is Expression instance"
if isinstance(self.feature_left, Expression):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
else:
series_left = self.feature_left # numeric value
if isinstance(self.feature_right, Expression):
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
else:
series_right = self.feature_right
return getattr(np, self.func)(series_left, series_right)
class Add(NpPairOperator):
"""Add Operator
Parameters
@@ -285,7 +329,7 @@ class Add(PairOperator):
super(Add, self).__init__(feature_left, feature_right, "add")
class Sub(PairOperator):
class Sub(NpPairOperator):
"""Subtract Operator
Parameters
@@ -305,7 +349,7 @@ class Sub(PairOperator):
super(Sub, self).__init__(feature_left, feature_right, "subtract")
class Mul(PairOperator):
class Mul(NpPairOperator):
"""Multiply Operator
Parameters
@@ -325,7 +369,7 @@ class Mul(PairOperator):
super(Mul, self).__init__(feature_left, feature_right, "multiply")
class Div(PairOperator):
class Div(NpPairOperator):
"""Division Operator
Parameters
@@ -345,7 +389,7 @@ class Div(PairOperator):
super(Div, self).__init__(feature_left, feature_right, "divide")
class Greater(PairOperator):
class Greater(NpPairOperator):
"""Greater Operator
Parameters
@@ -365,7 +409,7 @@ class Greater(PairOperator):
super(Greater, self).__init__(feature_left, feature_right, "maximum")
class Less(PairOperator):
class Less(NpPairOperator):
"""Less Operator
Parameters
@@ -385,7 +429,7 @@ class Less(PairOperator):
super(Less, self).__init__(feature_left, feature_right, "minimum")
class Gt(PairOperator):
class Gt(NpPairOperator):
"""Greater Than Operator
Parameters
@@ -405,7 +449,7 @@ class Gt(PairOperator):
super(Gt, self).__init__(feature_left, feature_right, "greater")
class Ge(PairOperator):
class Ge(NpPairOperator):
"""Greater Equal Than Operator
Parameters
@@ -425,7 +469,7 @@ class Ge(PairOperator):
super(Ge, self).__init__(feature_left, feature_right, "greater_equal")
class Lt(PairOperator):
class Lt(NpPairOperator):
"""Less Than Operator
Parameters
@@ -445,7 +489,7 @@ class Lt(PairOperator):
super(Lt, self).__init__(feature_left, feature_right, "less")
class Le(PairOperator):
class Le(NpPairOperator):
"""Less Equal Than Operator
Parameters
@@ -465,7 +509,7 @@ class Le(PairOperator):
super(Le, self).__init__(feature_left, feature_right, "less_equal")
class Eq(PairOperator):
class Eq(NpPairOperator):
"""Equal Operator
Parameters
@@ -485,7 +529,7 @@ class Eq(PairOperator):
super(Eq, self).__init__(feature_left, feature_right, "equal")
class Ne(PairOperator):
class Ne(NpPairOperator):
"""Not Equal Operator
Parameters
@@ -505,7 +549,7 @@ class Ne(PairOperator):
super(Ne, self).__init__(feature_left, feature_right, "not_equal")
class And(PairOperator):
class And(NpPairOperator):
"""And Operator
Parameters
@@ -525,7 +569,7 @@ class And(PairOperator):
super(And, self).__init__(feature_left, feature_right, "bitwise_and")
class Or(PairOperator):
class Or(NpPairOperator):
"""Or Operator
Parameters
@@ -1451,6 +1495,9 @@ class OpsWrapper(object):
def __init__(self):
self._ops = {}
def reset(self):
self._ops = {}
def register(self, ops_list):
for operator in ops_list:
if not issubclass(operator, ExpressionOps):
@@ -1469,12 +1516,15 @@ class OpsWrapper(object):
Operators = OpsWrapper()
Operators.register(OpsList)
def register_custom_ops(C):
"""register custom operator"""
def register_all_ops(C):
"""register all operator"""
logger = get_module_logger("ops")
Operators.reset()
Operators.register(OpsList)
if getattr(C, "custom_ops", None) is not None:
Operators.register(C.custom_ops)
logger.debug("register custom operator {}".format(C.custom_ops))

View File

@@ -66,7 +66,7 @@ class TestDataset(TestAutoData):
# Check the data
# Get data from DataFrame Directly
data_from_df = (
tsdh._handler.fetch(data_key=DataHandlerLP.DK_L)
tsdh.handler.fetch(data_key=DataHandlerLP.DK_L)
.loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"]
.iloc[-30:]
.values

View File

@@ -26,9 +26,6 @@ class Diff(ElemOperator):
a feature instance with first difference
"""
def __init__(self, feature):
super(Diff, self).__init__(feature, "diff")
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.diff()
@@ -50,9 +47,6 @@ class Distance(PairOperator):
a feature instance with distance
"""
def __init__(self, feature_left, feature_right):
super(Distance, self).__init__(feature_left, feature_right, "distance")
def _load_internal(self, instrument, start_index, end_index, freq):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
series_right = self.feature_right.load(instrument, start_index, end_index, freq)