mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
update
This commit is contained in:
@@ -56,88 +56,44 @@ class HighFreqHandler(DataHandlerLP):
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
|
||||
# template_paused="{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_fillnan = "BFillNan(FFillNan({0}))"
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
fields += [
|
||||
"{0}/Ref(DayLast({1}), 240)".format(
|
||||
|
||||
def get_04_price_feature(price_field):
|
||||
"""Get 0~4 column price feature ops"""
|
||||
feature_ops = "{0}/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$open"),
|
||||
template_paused.format(price_field),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [
|
||||
"{0}/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$high"),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [
|
||||
"{0}/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$low"),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += ["{0}/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close")))]
|
||||
fields += [
|
||||
"{0}/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
return feature_ops
|
||||
|
||||
fields += [get_04_price_feature("$open")]
|
||||
fields += [get_04_price_feature("$high")]
|
||||
fields += [get_04_price_feature("$low")]
|
||||
fields += [get_04_price_feature("$close")]
|
||||
fields += [get_04_price_feature(simpson_vwap)]
|
||||
names += ["$open", "$high", "$low", "$close", "$vwap"]
|
||||
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
|
||||
def get_59_price_feature(price_field):
|
||||
"""Get 5~9 column price feature ops"""
|
||||
feature_ops = "Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$open"),
|
||||
template_paused.format(price_field),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$high"),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format("$low"),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast({0}), 240)".format(template_fillnan.format(template_paused.format("$close")))
|
||||
]
|
||||
return feature_ops
|
||||
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
),
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
)
|
||||
]
|
||||
fields += [get_59_price_feature("$open")]
|
||||
fields += [get_59_price_feature("$high")]
|
||||
fields += [get_59_price_feature("$low")]
|
||||
fields += [get_59_price_feature("$close")]
|
||||
fields += [get_59_price_feature(simpson_vwap)]
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [
|
||||
@@ -197,19 +153,20 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
|
||||
template_if = "If(IsNull({1}), {0}, {1})"
|
||||
template_paused = "Select(Or(IsNull($paused), Eq($paused, 0.0)), {0})"
|
||||
# template_paused="{0}"
|
||||
template_fillnan = "FFillNan({0})"
|
||||
template_fillnan = "BFillNan(FFillNan({0}))"
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
# fields += [
|
||||
# template_fillnan.format(template_paused.format("$close")),
|
||||
# ]
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
fields += [
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
]
|
||||
names += ["$vwap_0"]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
@@ -218,6 +175,6 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
template_paused.format("$high"),
|
||||
)
|
||||
]
|
||||
names += ["$volume_0"]
|
||||
names += ["$volume0"]
|
||||
|
||||
return fields, names
|
||||
@@ -3,51 +3,61 @@ import pandas as pd
|
||||
import importlib
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import C
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
|
||||
|
||||
class DayFirst(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(DayFirst, self).__init__(feature, "day_first")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("first")
|
||||
def get_calendar_day(freq="day", future=False):
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future))))
|
||||
H["c"][flag] = _calendar
|
||||
return _calendar
|
||||
|
||||
|
||||
class DayLast(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(DayLast, self).__init__(feature, "day_last")
|
||||
super(DayLast, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
|
||||
|
||||
class FFillNan(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(FFillNan, self).__init__(feature, "fill_nan")
|
||||
super(FFillNan, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="ffill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
class BFillNan(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(Date, self).__init__(feature, "date")
|
||||
super(BFillNan, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.fillna(method="bfill")
|
||||
|
||||
|
||||
class Date(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(Date, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = get_calendar_day(freq=freq)
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return pd.Series(_calendar[series.index], index=series.index)
|
||||
|
||||
|
||||
class Select(PairOperator):
|
||||
def __init__(self, condition, feature):
|
||||
super(Select, self).__init__(condition, feature, "select")
|
||||
super(Select, self).__init__(condition, feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
@@ -57,7 +67,7 @@ class Select(PairOperator):
|
||||
|
||||
class IsNull(ElemOperator):
|
||||
def __init__(self, feature):
|
||||
super(IsNull, self).__init__(feature, "isnull")
|
||||
super(IsNull, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
@@ -26,7 +26,7 @@ class HighFreqNorm(Processor):
|
||||
if name == "volume":
|
||||
part_values = np.log1p(part_values)
|
||||
self.feature_med[name] = np.nanmedian(part_values)
|
||||
part_values = part_values - self.feature_med[name] # mean, copy
|
||||
part_values = part_values - self.feature_med[name]
|
||||
self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
|
||||
part_values = part_values / self.feature_std[name]
|
||||
self.feature_vmax[name] = np.nanmax(part_values)
|
||||
@@ -41,23 +41,27 @@ class HighFreqNorm(Processor):
|
||||
}
|
||||
|
||||
for name, name_val in names.items():
|
||||
part_values = df_values[:, name_val]
|
||||
if name == "volume":
|
||||
part_values[:] = np.log1p(part_values)
|
||||
part_values -= self.feature_med[name]
|
||||
part_values /= self.feature_std[name]
|
||||
slice0 = part_values > 3.0
|
||||
slice1 = part_values > 3.5
|
||||
slice2 = part_values < -3.0
|
||||
slice3 = part_values < -3.5
|
||||
df_values[:, name_val] = np.log1p(df_values[:, name_val])
|
||||
df_values[:, name_val] -= self.feature_med[name]
|
||||
df_values[:, name_val] /= self.feature_std[name]
|
||||
slice0 = df_values[:, name_val] > 3.0
|
||||
slice1 = df_values[:, name_val] > 3.5
|
||||
slice2 = df_values[:, name_val] < -3.0
|
||||
slice3 = df_values[:, name_val] < -3.5
|
||||
|
||||
part_values[slice0] = 3.0 + (part_values[slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
|
||||
part_values[slice1] = 3.5
|
||||
part_values[slice2] = -3.0 - (part_values[slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
|
||||
part_values[slice3] = -3.5
|
||||
# print("start_call_feature_reshape")
|
||||
df_values[:, name_val][slice0] = (
|
||||
3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5
|
||||
)
|
||||
df_values[:, name_val][slice1] = 3.5
|
||||
df_values[:, name_val][slice2] = (
|
||||
-3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5
|
||||
)
|
||||
df_values[:, name_val][slice3] = -3.5
|
||||
idx = df_features.index.droplevel("datetime").drop_duplicates()
|
||||
idx.set_names(["instrument", "datetime"], inplace=True)
|
||||
|
||||
# Reshape is specifically for adapting to RL high-freq executor
|
||||
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
|
||||
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
|
||||
df_new_features = pd.DataFrame(
|
||||
@@ -2,13 +2,14 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
@@ -23,42 +24,22 @@ from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
from highfreq_ops import DayFirst, DayLast, FFillNan, Date, Select, IsNull
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull],
|
||||
redis_port=-1,
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
)
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
DROP_LOAD_DATASET = False # flag wether to test [drop and load dataset]
|
||||
|
||||
# start_time = "2019-01-01 00:00:00"
|
||||
# end_time = "2019-12-31 15:00:00"
|
||||
# train_end_time = "2019-05-31 15:00:00"
|
||||
# test_start_time = "2019-06-01 00:00:00"
|
||||
start_time = "2020-09-14 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
@@ -94,8 +75,6 @@ if __name__ == "__main__":
|
||||
},
|
||||
},
|
||||
},
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
"dataset_backtest": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -115,26 +94,50 @@ if __name__ == "__main__":
|
||||
},
|
||||
},
|
||||
}
|
||||
##=============load the calendar for cache=============
|
||||
# unnecessary, but may accelerate
|
||||
Cal.calendar(freq="1min") # load the calendar for cache
|
||||
Cal.get_calendar_day(freq="1min") # load the calendar for cache
|
||||
|
||||
##=============get data=============
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
print(xtrain, xtest)
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
dataset_backtest = init_instance_by_config(task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
print(backtest_train, backtest_test)
|
||||
def _prepare_calender_cache(self):
|
||||
"""preload the calendar for cache"""
|
||||
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
# This code used the copy-on-write feature of Linux to avoid calculating the calendar multiple times in the subprocess
|
||||
# This code may accelerate, but may be not useful on Windows and Mac Os
|
||||
Cal.calendar(freq="1min")
|
||||
get_calendar_day(freq="1min")
|
||||
|
||||
## example to show how to save the dataset and reload it, and how to use different data
|
||||
if DROP_LOAD_DATASET:
|
||||
def get_data(self):
|
||||
"""use dataset to get highreq data"""
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
print(xtrain, xtest)
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
print(backtest_train, backtest_test)
|
||||
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
|
||||
def dump_and_load_dataset(self):
|
||||
"""dump and load dataset state on disk"""
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl")
|
||||
@@ -142,33 +145,18 @@ if __name__ == "__main__":
|
||||
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
file_dataset = open("dataset.pkl", "rb")
|
||||
dataset = pickle.load(file_dataset)
|
||||
file_dataset.close()
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
file_dataset_backtest = open("dataset_backtest.pkl", "rb")
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
file_dataset_backtest.close()
|
||||
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reload_dataset=============
|
||||
dataset.init(init_type=DataHandlerLP.IT_LS)
|
||||
dataset_backtest.init(init_type=DataHandlerLP.IT_LS)
|
||||
dataset_backtest.init()
|
||||
|
||||
##=============reinit qlib=============
|
||||
## Unless you want to modify the provider_uri and other configurations, reinit is unnecessary
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull],
|
||||
redis_port=-1,
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
)
|
||||
|
||||
Cal.calendar(freq="1min") # load the calendar for cache
|
||||
Cal.get_calendar_day(freq="1min") # load the calendar for cache
|
||||
|
||||
##=============test dataset=============
|
||||
##=============get data=============
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
|
||||
@@ -176,3 +164,7 @@ if __name__ == "__main__":
|
||||
print(backtest_train, backtest_test)
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(HighfreqWorkflow)
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, redis_port=233)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
@@ -193,6 +193,12 @@ MODE_CONF = {
|
||||
},
|
||||
}
|
||||
|
||||
HIGH_FREQ_CONFIG = {
|
||||
"provider_uri": "~/.qlib/qlib_data/yahoo_cn_1min",
|
||||
"dataset_cache": None,
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"region": REG_CN,
|
||||
}
|
||||
|
||||
_default_region_config = {
|
||||
REG_CN: {
|
||||
|
||||
@@ -157,7 +157,7 @@ class Expression(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
pass
|
||||
raise NotImplementedError("This function must be implemented in your newly defined feature")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_longest_back_rolling(self):
|
||||
|
||||
@@ -117,17 +117,7 @@ class CalendarProvider(abc.ABC):
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(self._load_calendar(freq, future))
|
||||
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
|
||||
H["c"][flag] = _calendar, _calendar_index
|
||||
return _calendar, _calendar_index
|
||||
|
||||
def get_calendar_day(self, freq="day", future=False):
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
else:
|
||||
_calendar = np.array(list(map(lambda x: x.date(), self._load_calendar(freq, future))))
|
||||
_calendar = np.array(self.load_calendar(freq, future))
|
||||
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
|
||||
H["c"][flag] = _calendar, _calendar_index
|
||||
return _calendar, _calendar_index
|
||||
@@ -514,7 +504,7 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
"""Calendar file uri."""
|
||||
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
|
||||
|
||||
def _load_calendar(self, freq, future):
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
|
||||
Parameters
|
||||
@@ -679,12 +669,11 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
# 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
|
||||
# 2) The the precision should be configurable
|
||||
try:
|
||||
if series.dtype == np.float64:
|
||||
series = series.astype(np.float32)
|
||||
elif series.dtype == np.bool:
|
||||
series = series.astype(np.int8)
|
||||
series = series.astype(np.float32)
|
||||
except ValueError:
|
||||
pass
|
||||
except TypeError:
|
||||
pass
|
||||
if not series.empty:
|
||||
series = series.loc[start_index:end_index]
|
||||
return series
|
||||
|
||||
@@ -88,15 +88,8 @@ class DatasetH(Dataset):
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, **kwargs):
|
||||
|
||||
logger = get_module_logger("DatasetH")
|
||||
handler_init_kwargs = {}
|
||||
for arg_key, arg_value in kwargs.items():
|
||||
if arg_key in getfullargspec(self.handler.init).args:
|
||||
handler_init_kwargs[arg_key] = arg_value
|
||||
else:
|
||||
logger.info(f"init arguments[{arg_key}] is ignored.")
|
||||
self.handler.init(**handler_init_kwargs)
|
||||
"""Initialize the DatasetH, Only parameters belonging to handler.init will be passed in"""
|
||||
self.handler.init(**kwargs)
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
|
||||
"""
|
||||
|
||||
@@ -428,13 +428,11 @@ class DataHandlerLP(DataHandler):
|
||||
# TODO: Be able to cache handler data. Save the memory for data processing
|
||||
|
||||
def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
|
||||
try:
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
except AttributeError:
|
||||
print("please set drop_raw = False if you want to use raw data")
|
||||
raise
|
||||
except:
|
||||
raise
|
||||
if data_key == self.DK_R and self.drop_raw:
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
)
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
return df
|
||||
|
||||
def fetch(
|
||||
|
||||
138
qlib/data/ops.py
138
qlib/data/ops.py
@@ -6,6 +6,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -22,8 +23,6 @@ except ImportError:
|
||||
"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####"
|
||||
)
|
||||
raise
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
np.seterr(invalid="ignore")
|
||||
@@ -34,12 +33,39 @@ np.seterr(invalid="ignore")
|
||||
class ElemOperator(ExpressionOps):
|
||||
"""Element-wise Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
feature operation output
|
||||
"""
|
||||
|
||||
def __init__(self, feature):
|
||||
self.feature = feature
|
||||
|
||||
def __str__(self):
|
||||
return "{}({})".format(type(self).__name__, self.feature)
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
return self.feature.get_longest_back_rolling()
|
||||
|
||||
def get_extended_window_size(self):
|
||||
return self.feature.get_extended_window_size()
|
||||
|
||||
|
||||
class NpElemOperator(ElemOperator):
|
||||
"""Numpy Element-wise Operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
feature instance
|
||||
func : str
|
||||
feature operation method
|
||||
numpy feature operation method
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -50,22 +76,14 @@ class ElemOperator(ExpressionOps):
|
||||
def __init__(self, feature, func):
|
||||
self.feature = feature
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "{}({})".format(type(self).__name__, self.feature)
|
||||
super(NpElemOperator, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return getattr(np, self.func)(series)
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
return self.feature.get_longest_back_rolling()
|
||||
|
||||
def get_extended_window_size(self):
|
||||
return self.feature.get_extended_window_size()
|
||||
|
||||
|
||||
class Abs(ElemOperator):
|
||||
class Abs(NpElemOperator):
|
||||
"""Feature Absolute Value
|
||||
|
||||
Parameters
|
||||
@@ -83,7 +101,7 @@ class Abs(ElemOperator):
|
||||
super(Abs, self).__init__(feature, "abs")
|
||||
|
||||
|
||||
class Sign(ElemOperator):
|
||||
class Sign(NpElemOperator):
|
||||
"""Feature Sign
|
||||
|
||||
Parameters
|
||||
@@ -110,7 +128,7 @@ class Sign(ElemOperator):
|
||||
return getattr(np, self.func)(series)
|
||||
|
||||
|
||||
class Log(ElemOperator):
|
||||
class Log(NpElemOperator):
|
||||
"""Feature Log
|
||||
|
||||
Parameters
|
||||
@@ -128,7 +146,7 @@ class Log(ElemOperator):
|
||||
super(Log, self).__init__(feature, "log")
|
||||
|
||||
|
||||
class Power(ElemOperator):
|
||||
class Power(NpElemOperator):
|
||||
"""Feature Power
|
||||
|
||||
Parameters
|
||||
@@ -154,7 +172,7 @@ class Power(ElemOperator):
|
||||
return getattr(np, self.func)(series, self.exponent)
|
||||
|
||||
|
||||
class Mask(ElemOperator):
|
||||
class Mask(NpElemOperator):
|
||||
"""Feature Mask
|
||||
|
||||
Parameters
|
||||
@@ -181,7 +199,7 @@ class Mask(ElemOperator):
|
||||
return self.feature.load(self.instrument, start_index, end_index, freq)
|
||||
|
||||
|
||||
class Not(ElemOperator):
|
||||
class Not(NpElemOperator):
|
||||
"""Not Operator
|
||||
|
||||
Parameters
|
||||
@@ -220,28 +238,13 @@ class PairOperator(ExpressionOps):
|
||||
two features' operation output
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, func):
|
||||
def __init__(self, feature_left, feature_right):
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
assert any(
|
||||
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
|
||||
), "at least one of two inputs is Expression instance"
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_left = self.feature_left # numeric value
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
return getattr(np, self.func)(series_left, series_right)
|
||||
|
||||
def get_longest_back_rolling(self):
|
||||
if isinstance(self.feature_left, Expression):
|
||||
left_br = self.feature_left.get_longest_back_rolling()
|
||||
@@ -267,7 +270,46 @@ class PairOperator(ExpressionOps):
|
||||
return max(ll, rl), max(lr, rr)
|
||||
|
||||
|
||||
class Add(PairOperator):
|
||||
class NpPairOperator(PairOperator):
|
||||
"""Numpy Pair-wise operator
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature_left : Expression
|
||||
feature instance or numeric value
|
||||
feature_right : Expression
|
||||
feature instance or numeric value
|
||||
func : str
|
||||
operator function
|
||||
|
||||
Returns
|
||||
----------
|
||||
Feature:
|
||||
two features' operation output
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, func):
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.func = func
|
||||
super(NpPairOperator, self).__init__(feature_left, feature_right)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
assert any(
|
||||
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
|
||||
), "at least one of two inputs is Expression instance"
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_left = self.feature_left # numeric value
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
return getattr(np, self.func)(series_left, series_right)
|
||||
|
||||
|
||||
class Add(NpPairOperator):
|
||||
"""Add Operator
|
||||
|
||||
Parameters
|
||||
@@ -287,7 +329,7 @@ class Add(PairOperator):
|
||||
super(Add, self).__init__(feature_left, feature_right, "add")
|
||||
|
||||
|
||||
class Sub(PairOperator):
|
||||
class Sub(NpPairOperator):
|
||||
"""Subtract Operator
|
||||
|
||||
Parameters
|
||||
@@ -307,7 +349,7 @@ class Sub(PairOperator):
|
||||
super(Sub, self).__init__(feature_left, feature_right, "subtract")
|
||||
|
||||
|
||||
class Mul(PairOperator):
|
||||
class Mul(NpPairOperator):
|
||||
"""Multiply Operator
|
||||
|
||||
Parameters
|
||||
@@ -327,7 +369,7 @@ class Mul(PairOperator):
|
||||
super(Mul, self).__init__(feature_left, feature_right, "multiply")
|
||||
|
||||
|
||||
class Div(PairOperator):
|
||||
class Div(NpPairOperator):
|
||||
"""Division Operator
|
||||
|
||||
Parameters
|
||||
@@ -347,7 +389,7 @@ class Div(PairOperator):
|
||||
super(Div, self).__init__(feature_left, feature_right, "divide")
|
||||
|
||||
|
||||
class Greater(PairOperator):
|
||||
class Greater(NpPairOperator):
|
||||
"""Greater Operator
|
||||
|
||||
Parameters
|
||||
@@ -367,7 +409,7 @@ class Greater(PairOperator):
|
||||
super(Greater, self).__init__(feature_left, feature_right, "maximum")
|
||||
|
||||
|
||||
class Less(PairOperator):
|
||||
class Less(NpPairOperator):
|
||||
"""Less Operator
|
||||
|
||||
Parameters
|
||||
@@ -387,7 +429,7 @@ class Less(PairOperator):
|
||||
super(Less, self).__init__(feature_left, feature_right, "minimum")
|
||||
|
||||
|
||||
class Gt(PairOperator):
|
||||
class Gt(NpPairOperator):
|
||||
"""Greater Than Operator
|
||||
|
||||
Parameters
|
||||
@@ -407,7 +449,7 @@ class Gt(PairOperator):
|
||||
super(Gt, self).__init__(feature_left, feature_right, "greater")
|
||||
|
||||
|
||||
class Ge(PairOperator):
|
||||
class Ge(NpPairOperator):
|
||||
"""Greater Equal Than Operator
|
||||
|
||||
Parameters
|
||||
@@ -427,7 +469,7 @@ class Ge(PairOperator):
|
||||
super(Ge, self).__init__(feature_left, feature_right, "greater_equal")
|
||||
|
||||
|
||||
class Lt(PairOperator):
|
||||
class Lt(NpPairOperator):
|
||||
"""Less Than Operator
|
||||
|
||||
Parameters
|
||||
@@ -447,7 +489,7 @@ class Lt(PairOperator):
|
||||
super(Lt, self).__init__(feature_left, feature_right, "less")
|
||||
|
||||
|
||||
class Le(PairOperator):
|
||||
class Le(NpPairOperator):
|
||||
"""Less Equal Than Operator
|
||||
|
||||
Parameters
|
||||
@@ -467,7 +509,7 @@ class Le(PairOperator):
|
||||
super(Le, self).__init__(feature_left, feature_right, "less_equal")
|
||||
|
||||
|
||||
class Eq(PairOperator):
|
||||
class Eq(NpPairOperator):
|
||||
"""Equal Operator
|
||||
|
||||
Parameters
|
||||
@@ -487,7 +529,7 @@ class Eq(PairOperator):
|
||||
super(Eq, self).__init__(feature_left, feature_right, "equal")
|
||||
|
||||
|
||||
class Ne(PairOperator):
|
||||
class Ne(NpPairOperator):
|
||||
"""Not Equal Operator
|
||||
|
||||
Parameters
|
||||
@@ -507,7 +549,7 @@ class Ne(PairOperator):
|
||||
super(Ne, self).__init__(feature_left, feature_right, "not_equal")
|
||||
|
||||
|
||||
class And(PairOperator):
|
||||
class And(NpPairOperator):
|
||||
"""And Operator
|
||||
|
||||
Parameters
|
||||
@@ -527,7 +569,7 @@ class And(PairOperator):
|
||||
super(And, self).__init__(feature_left, feature_right, "bitwise_and")
|
||||
|
||||
|
||||
class Or(PairOperator):
|
||||
class Or(NpPairOperator):
|
||||
"""Or Operator
|
||||
|
||||
Parameters
|
||||
|
||||
Reference in New Issue
Block a user