mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
fix bug
This commit is contained in:
@@ -29,8 +29,8 @@ class HighFreqHandler(DataHandlerLP):
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = []
|
||||
learn_processors = []
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
@@ -179,8 +179,6 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
end_time=None,
|
||||
freq="1min",
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
@@ -207,7 +205,7 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
]
|
||||
names += ["$close0"]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"If(Eq({1}, np.nan), 0, If(Or(Gt({2}, Mul(1.001, {4})), Lt({2}, Mul(0.999, {3}))), 0, {1}))".format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
|
||||
@@ -11,7 +11,7 @@ class DayFirst(ElemOperator):
|
||||
super(DayFirst, self).__init__(feature, "day_first")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calender_day(freq=freq)[0]
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("first")
|
||||
|
||||
@@ -21,7 +21,7 @@ class DayLast(ElemOperator):
|
||||
super(DayLast, self).__init__(feature, "day_last")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calender_day(freq=freq)[0]
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.groupby(_calendar[series.index]).transform("last")
|
||||
|
||||
@@ -40,7 +40,7 @@ class Date(ElemOperator):
|
||||
super(Date, self).__init__(feature, "date")
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
_calendar = Cal.get_calender_day(freq=freq)[0]
|
||||
_calendar = Cal.get_calendar_day(freq=freq)[0]
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return pd.Series(_calendar[series.index], index=series.index)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.dataset.processor import Processor
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.data.dataset.utils import fetch_df_by_index
|
||||
|
||||
|
||||
@@ -11,8 +10,9 @@ class HighFreqNorm(Processor):
|
||||
self.fit_end_time = fit_end_time
|
||||
|
||||
def fit(self, df_features):
|
||||
fetch_df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df
|
||||
print("==============fit==============")
|
||||
fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
del df_features
|
||||
df_values = fetch_df.values
|
||||
names = {
|
||||
"price": slice(0, 10),
|
||||
@@ -23,17 +23,18 @@ class HighFreqNorm(Processor):
|
||||
self.feature_vmax = {}
|
||||
self.feature_vmin = {}
|
||||
for name, name_val in names.items():
|
||||
part_values = df_values[:, name_val]
|
||||
part_values = df_values[:, name_val].astype(np.float32)
|
||||
if name == "volume":
|
||||
df_features.loc(axis=1)[name_val] = np.log1p(part_values)
|
||||
part_values = np.log1p(part_values)
|
||||
self.feature_med[name] = np.nanmedian(part_values)
|
||||
part_values = part_values - self.feature_med # mean, copy
|
||||
part_values = part_values - self.feature_med[name] # mean, copy
|
||||
self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + 1e-12
|
||||
part_values = part_values / self.feature_std
|
||||
part_values = part_values / self.feature_std[name]
|
||||
self.feature_vmax[name] = np.nanmax(part_values)
|
||||
self.feature_vmin[name] = np.nanmin(part_values)
|
||||
|
||||
def __call__(self, df_features):
|
||||
print("==============call==============")
|
||||
df_features.set_index("date", append=True, drop=True, inplace=True)
|
||||
df_values = df_features.values
|
||||
names = {
|
||||
@@ -58,13 +59,12 @@ class HighFreqNorm(Processor):
|
||||
part_values[slice3] = -3.5
|
||||
# print("start_call_feature_reshape")
|
||||
idx = df_features.index.droplevel("datetime").drop_duplicates()
|
||||
idx.set_names(['instrument', 'datetime'], inplace=True)
|
||||
feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240)
|
||||
feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240)
|
||||
|
||||
df_new_features = pd.DataFrame(
|
||||
data=np.concatenate((feat, feat_1), axis=1),
|
||||
index=idx,
|
||||
columns=["FEATURE_%d" % i for i in range(12 * 240)],
|
||||
).sort_index()
|
||||
|
||||
return df_new_features
|
||||
|
||||
@@ -73,31 +73,36 @@ if __name__ == "__main__":
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull],
|
||||
redis_port=233,
|
||||
redis_port=-1,
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
)
|
||||
|
||||
MARKET = "csi300"
|
||||
MARKET = "test_10"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2019-01-01 00:00:00"
|
||||
end_time = "2019-12-31 15:00:00"
|
||||
train_end_time = "2019-05-31 15:00:00"
|
||||
test_start_time = "2019-06-01 00:00:00"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": "2017-01-01 00:00:00",
|
||||
"end_time": "2020-11-30 15:00:00",
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"fit_start_time": "2017-01-01 00:00:00",
|
||||
"fit_end_time": "2020-08-31 15:00:00",
|
||||
"instruments": "all",
|
||||
"fit_start_time": start_time,
|
||||
"fit_end_time": train_end_time,
|
||||
"instruments": MARKET,
|
||||
"infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor", "kwargs": {}}],
|
||||
}
|
||||
DATA_HANDLER_CONFIG1 = {
|
||||
"start_time": "2017-01-01 00:00:00",
|
||||
"end_time": "2020-11-30 15:00:00",
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"freq": "1min",
|
||||
"instruments": "all",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
task = {
|
||||
@@ -111,10 +116,10 @@ if __name__ == "__main__":
|
||||
"kwargs": DATA_HANDLER_CONFIG0,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2017-01-01 00:00:00", "2020-08-31 15:00:00"),
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
"2020-09-01 00:00:00",
|
||||
"2020-11-30 15:00:00",
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
@@ -127,19 +132,72 @@ if __name__ == "__main__":
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "HighFreqBacktestHandler",
|
||||
"module_path": "highfreq_hander",
|
||||
"module_path": "highfreq_handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG1,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2017-01-01 00:00:00", "2020-08-31 15:00:00"),
|
||||
"train": (start_time, train_end_time),
|
||||
"test": (
|
||||
"2020-09-01 00:00:00",
|
||||
"2020-11-30 15:00:00",
|
||||
test_start_time,
|
||||
end_time,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
Cal.get_calender_day(freq="1min") # TO FIX: load the calendar day for cache
|
||||
##=============load the calendar for cache=============
|
||||
Cal.calendar(freq="1min")
|
||||
Cal.get_calendar_day(freq="1min")
|
||||
|
||||
|
||||
##=============get data=============
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
dataset_backtest = init_instance_by_config(task["dataset_backtest"])
|
||||
xtrain, xtest = dataset.prepare(['train', 'test'])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(['train', 'test'])
|
||||
print(xtrain, xtest)
|
||||
print(backtest_train, backtest_test)
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl")
|
||||
dataset_backtest.to_pickle(path="dataset_backtest.pkl")
|
||||
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
file_dataset = open("dataset.pkl", "rb")
|
||||
dataset = pickle.load(file_dataset)
|
||||
file_dataset.close()
|
||||
|
||||
file_dataset_backtest = open("dataset_backtest.pkl", "rb")
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
file_dataset_backtest.close()
|
||||
|
||||
##=============reload_dataset=============
|
||||
dataset.init(init_type=DataHandlerLP.IT_LS)
|
||||
dataset_backtest.init(init_type=DataHandlerLP.IT_LS)
|
||||
|
||||
|
||||
|
||||
##=============reinit qlib=============
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
custom_ops=[DayFirst, DayLast, FFillNan, Date, Select, IsNull],
|
||||
redis_port=-1,
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
)
|
||||
|
||||
Cal.calendar(freq="1min") #load the calendar for cache
|
||||
Cal.get_calendar_day(freq="1min") #load the calendar for cache
|
||||
|
||||
##=============test dataset
|
||||
xtrain, xtest = dataset.prepare(['train', 'test'])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(['train', 'test'])
|
||||
|
||||
print(xtrain, xtest)
|
||||
print(backtest_train, backtest_test)
|
||||
del xtrain, xtest
|
||||
del backtest_train, backtest_test
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, redis_port=233)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
@@ -291,12 +291,12 @@ class QlibConfig(Config):
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
from .data.ops import register_custom_ops
|
||||
from .data.ops import register_all_ops
|
||||
from .data.data import register_all_wrappers
|
||||
from .workflow import R, QlibRecorder
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
|
||||
register_custom_ops(self)
|
||||
register_all_ops(self)
|
||||
register_all_wrappers(self)
|
||||
# set up QlibRecorder
|
||||
exp_manager = init_instance_by_config(self["exp_manager"])
|
||||
|
||||
@@ -123,7 +123,7 @@ class CalendarProvider(abc.ABC):
|
||||
H["c"][flag] = _calendar, _calendar_index
|
||||
return _calendar, _calendar_index
|
||||
|
||||
def get_calender_day(self, freq="day", future=False):
|
||||
def get_calendar_day(self, freq="day", future=False):
|
||||
flag = f"{freq}_future_{future}_day"
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
|
||||
@@ -87,34 +87,16 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, init_type: str = DataHandlerLP.IT_FIT_SEQ, enable_cache: bool = False):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
def init(self, **kwargs):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
init_type : str
|
||||
- if `init_type` == DataHandlerLP.IT_FIT_SEQ:
|
||||
|
||||
the input of `DataHandlerLP.fit` will be the output of the previous processor
|
||||
|
||||
- if `init_type` == DataHandlerLP.IT_FIT_IND:
|
||||
|
||||
the input of `DataHandlerLP.fit` will be the original df
|
||||
|
||||
- if `init_type` == DataHandlerLP.IT_LS:
|
||||
|
||||
The state of the object has been load by pickle
|
||||
|
||||
enable_cache : bool
|
||||
default value is false:
|
||||
|
||||
- if `enable_cache` == True:
|
||||
|
||||
the processed data will be saved on disk, and handler will load the cached data from the disk directly
|
||||
when we call `init` next time
|
||||
"""
|
||||
self.handler.init(init_type=init_type, enable_cache=enable_cache)
|
||||
logger = get_module_logger("DatasetH")
|
||||
handler_init_kwargs = {}
|
||||
for arg_key, arg_value in kwargs.items():
|
||||
if arg_key in getfullargspec(self.handler.init).args:
|
||||
handler_init_kwargs[arg_key] = arg_value
|
||||
else:
|
||||
logger.info(f"init arguments[{arg_key}] is ignored.")
|
||||
self.handler.init(**handler_init_kwargs)
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
|
||||
"""
|
||||
|
||||
@@ -433,6 +433,8 @@ class DataHandlerLP(DataHandler):
|
||||
except AttributeError:
|
||||
print("please set drop_raw = False if you want to use raw data")
|
||||
raise
|
||||
except:
|
||||
raise
|
||||
return df
|
||||
|
||||
def fetch(
|
||||
|
||||
@@ -147,7 +147,6 @@ class QlibDataLoader(DLWParser):
|
||||
"""
|
||||
self.filter_pipe = filter_pipe
|
||||
self.swap_level = swap_level
|
||||
print("swap level", swap_level)
|
||||
super().__init__(config)
|
||||
|
||||
def load_group_df(
|
||||
|
||||
@@ -17,11 +17,13 @@ from ..log import get_module_logger
|
||||
try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
|
||||
except ImportError as err:
|
||||
except ImportError:
|
||||
print(
|
||||
"#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####"
|
||||
)
|
||||
raise
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
np.seterr(invalid="ignore")
|
||||
@@ -1451,6 +1453,9 @@ class OpsWrapper(object):
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
def reset(self):
|
||||
self._ops = {}
|
||||
|
||||
def register(self, ops_list):
|
||||
for operator in ops_list:
|
||||
if not issubclass(operator, ExpressionOps):
|
||||
@@ -1469,12 +1474,15 @@ class OpsWrapper(object):
|
||||
|
||||
|
||||
Operators = OpsWrapper()
|
||||
Operators.register(OpsList)
|
||||
|
||||
|
||||
def register_custom_ops(C):
|
||||
"""register custom operator"""
|
||||
def register_all_ops(C):
|
||||
"""register all operator"""
|
||||
logger = get_module_logger("ops")
|
||||
|
||||
Operators.reset()
|
||||
Operators.register(OpsList)
|
||||
|
||||
if getattr(C, "custom_ops", None) is not None:
|
||||
Operators.register(C.custom_ops)
|
||||
logger.debug("register custom operator {}".format(C.custom_ops))
|
||||
|
||||
Reference in New Issue
Block a user