mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
368 lines
14 KiB
Python
368 lines
14 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from ...data.dataset.handler import DataHandlerLP
|
|
from ...data.dataset.processor import Processor
|
|
from ...utils import get_cls_kwargs
|
|
from ...data.dataset import processor as processor_module
|
|
from ...log import TimeInspector
|
|
from inspect import getfullargspec
|
|
import copy
|
|
|
|
|
|
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
    """Normalize processor specs and inject the fit window where it is accepted.

    Each element of ``proc_l`` is either an already-built ``Processor`` instance,
    which is passed through unchanged, or a processor config. A config is
    resolved with ``get_cls_kwargs(p, processor_module)``. If the resolved class's
    constructor lists both ``fit_start_time`` and ``fit_end_time`` among its
    arguments, those two values are added to its kwargs.

    Parameters
    ----------
    proc_l : list
        Processor instances and/or processor configs.
    fit_start_time, fit_end_time :
        Fit window forwarded to processors whose constructor accepts it.

    Returns
    -------
    list
        A new list. Configs are rewritten as ``{"class": <class name>, "kwargs": <kwargs>}``.
        ``Processor`` instances are kept as-is.

    Raises
    ------
    AssertionError
        If a processor accepts the fit window but either bound is ``None``.
    """
    new_l = []
    for p in proc_l:
        if isinstance(p, Processor):
            new_l.append(p)
            continue

        klass, pkwargs = get_cls_kwargs(p, processor_module)
        args = getfullargspec(klass).args
        if "fit_start_time" in args and "fit_end_time" in args:
            assert (
                fit_start_time is not None and fit_end_time is not None
            ), "Make sure `fit_start_time` and `fit_end_time` are not None."
            # Build a fresh dict rather than calling ``pkwargs.update(...)``.
            # ``get_cls_kwargs`` may hand back the very dict stored in the
            # caller's config (e.g. the ``"kwargs": {}`` of a module-level default
            # processor list). Mutating it would leak this fit window into every
            # later caller that shares the same config object.
            pkwargs = {
                **pkwargs,
                "fit_start_time": fit_start_time,
                "fit_end_time": fit_end_time,
            }
        new_l.append({"class": klass.__name__, "kwargs": pkwargs})
    return new_l
|
|
|
|
|
|
_DEFAULT_LEARN_PROCESSORS = [
|
|
{"class": "DropnaLabel"},
|
|
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
|
]
|
|
_DEFAULT_INFER_PROCESSORS = [
|
|
{"class": "ProcessInf", "kwargs": {}},
|
|
{"class": "ZScoreNorm", "kwargs": {}},
|
|
{"class": "Fillna", "kwargs": {}},
|
|
]
|
|
|
|
|
|
class Alpha360(DataHandlerLP):
    """Data handler exposing 60 days of raw price/volume history per instrument.

    Features (see ``get_feature_config``):
    - For each of close, open, high, low and vwap, the value 59..1 days ago and
      the current value, each divided by the current ``$close``.
    - For volume, the same lags divided by the current ``$volume``.

    That is 6 * 60 = 360 columns. The default label is
    ``Ref($close, -2)/Ref($close, -1) - 1``.
    """

    def __init__(
        self,
        instruments="csi500",
        start_time=None,
        end_time=None,
        freq="day",
        infer_processors=_DEFAULT_INFER_PROCESSORS,
        learn_processors=_DEFAULT_LEARN_PROCESSORS,
        fit_start_time=None,
        fit_end_time=None,
        filter_pipe=None,
        process_type=None,
        **kwargs,
    ):
        """Build the ``QlibDataLoader`` config and hand everything to ``DataHandlerLP``.

        Parameters
        ----------
        instruments, start_time, end_time :
            Forwarded unchanged to ``DataHandlerLP``.
        freq : str
            Data frequency passed to ``QlibDataLoader``.
        infer_processors, learn_processors : list
            Processor instances/configs, normalized by ``check_transform_proc``.
        fit_start_time, fit_end_time :
            Fit window injected into processors whose constructor accepts it.
        filter_pipe :
            Passed to ``QlibDataLoader`` as ``filter_pipe``.
        process_type :
            Forwarded to ``DataHandlerLP`` when not ``None``. Previously a
            ``process_type`` argument landed in ``**kwargs`` and was silently
            ignored, unlike ``Alpha158``. Leaving it ``None`` keeps whatever
            default ``DataHandlerLP`` itself uses.
        **kwargs :
            Only ``label`` is read. It is an optional replacement for
            ``get_label_config()``, presumably in the same ``(fields, names)``
            form. Other keys are ignored.
        """
        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": self.get_feature_config(),
                    "label": kwargs.get("label", self.get_label_config()),
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
            },
        }

        # Only forward ``process_type`` when the caller set it, so the default
        # construction path is exactly what it was before this option existed.
        extra_handler_kwargs = {}
        if process_type is not None:
            extra_handler_kwargs["process_type"] = process_type

        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            learn_processors=learn_processors,
            infer_processors=infer_processors,
            **extra_handler_kwargs,
        )

    def get_label_config(self):
        """Return ``(fields, names)`` for the single close-based label column."""
        return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])

    def get_feature_config(self):
        """Return ``(fields, names)`` for the 360 Alpha360 feature expressions.

        Columns are grouped by base field in the order CLOSE, OPEN, HIGH, LOW,
        VWAP, VOLUME. Within a group the lag runs from 59 down to 0, so names go
        ``CLOSE59, ..., CLOSE1, CLOSE0``. Price fields are divided by the current
        ``$close``; volume is divided by the current ``$volume``.
        """
        # (base field, divisor) pairs; this tuple's order fixes the column order.
        groups = (
            ("close", "$close"),
            ("open", "$close"),
            ("high", "$close"),
            ("low", "$close"),
            ("vwap", "$close"),
            ("volume", "$volume"),
        )

        fields = []
        names = []
        for base, divisor in groups:
            for lag in range(59, 0, -1):
                fields.append("Ref($%s, %d)/%s" % (base, lag, divisor))
                names.append("%s%d" % (base.upper(), lag))
            # Lag 0 is spelled without ``Ref`` (e.g. "$open/$close").
            fields.append("$%s/%s" % (base, divisor))
            names.append("%s0" % base.upper())

        return fields, names
|
|
|
|
|
|
class Alpha360vwap(Alpha360):
    """``Alpha360`` features with a label computed from ``$vwap`` instead of ``$close``."""

    def get_label_config(self):
        """Return ``(fields, names)`` for the single VWAP-based label column."""
        label_expr = "Ref($vwap, -2)/Ref($vwap, -1) - 1"
        return [label_expr], ["LABEL0"]
|
|
|
|
|
|
class Alpha158(DataHandlerLP):
    """Data handler built from the hand-written "Alpha158" factor expressions.

    The features are produced by ``parse_config_to_fields`` from the config
    returned by ``get_feature_config``. The default label is
    ``Ref($close, -2)/Ref($close, -1) - 1``.
    """

    def __init__(
        self,
        instruments="csi500",
        start_time=None,
        end_time=None,
        freq="day",
        infer_processors=None,
        learn_processors=_DEFAULT_LEARN_PROCESSORS,
        fit_start_time=None,
        fit_end_time=None,
        process_type=DataHandlerLP.PTYPE_A,
        filter_pipe=None,
        **kwargs,
    ):
        """Build the ``QlibDataLoader`` config and hand everything to ``DataHandlerLP``.

        Parameters
        ----------
        instruments, start_time, end_time :
            Forwarded unchanged to ``DataHandlerLP``.
        freq : str
            Data frequency passed to ``QlibDataLoader``.
        infer_processors : list, optional
            Processor instances/configs. ``None`` (the default) means no
            processors; it replaces the former mutable ``[]`` default.
        learn_processors : list
            Processor instances/configs.
        fit_start_time, fit_end_time :
            Fit window injected by ``check_transform_proc`` into processors
            whose constructor accepts it.
        process_type :
            Forwarded unchanged to ``DataHandlerLP``.
        filter_pipe :
            Passed to ``QlibDataLoader`` as ``filter_pipe``.
        **kwargs :
            Only ``label`` is read. It is an optional replacement for
            ``get_label_config()``, presumably in the same ``(fields, names)``
            form. Other keys are ignored.
        """
        if infer_processors is None:
            infer_processors = []
        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "QlibDataLoader",
            "kwargs": {
                "config": {
                    "feature": self.get_feature_config(),
                    "label": kwargs.get("label", self.get_label_config()),
                },
                "filter_pipe": filter_pipe,
                "freq": freq,
            },
        }
        super().__init__(
            instruments=instruments,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
            process_type=process_type,
        )

    def get_feature_config(self):
        """Return ``(fields, names)`` for the default Alpha158 feature set.

        Includes the following:
        - All kbar features.
        - The window-0 OPEN/HIGH/LOW/VWAP price ratios.
        - Every rolling operator over the default windows.
        """
        conf = {
            "kbar": {},
            "price": {
                "windows": [0],
                "feature": ["OPEN", "HIGH", "LOW", "VWAP"],
            },
            "rolling": {},
        }
        return self.parse_config_to_fields(conf)

    def get_label_config(self):
        """Return ``(fields, names)`` for the single close-based label column."""
        return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])

    @staticmethod
    def parse_config_to_fields(config):
        """Create factor expressions from a config.

        config = {
            'kbar': {}, # whether to use some hard-code kbar features
            'price': { # whether to use raw price features
                'windows': [0, 1, 2, 3, 4], # use price at n days ago
                'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
            },
            'volume': { # whether to use raw volume features
                'windows': [0, 1, 2, 3, 4], # use volume at n days ago
            },
            'rolling': { # whether to use rolling operator based features
                'windows': [5, 10, 20, 30, 60], # rolling windows size
                'include': ['ROC', 'MA', 'STD'], # rolling operators to use
                # if include is None we will use all operators
                'exclude': ['RANK'], # rolling operators not to use
            }
        }

        Each top-level key only needs to be present to enable its section. The
        rolling operator that emits ``MIN<d>`` columns is selected by the key
        "MIN". The legacy key "LOW" is still honored as an alias for it.

        Returns
        -------
        tuple
            ``(fields, names)``: parallel lists of expressions and column names.
        """
        fields = []
        names = []
        if "kbar" in config:
            fields += [
                "($close-$open)/$open",
                "($high-$low)/$open",
                "($close-$open)/($high-$low+1e-12)",
                "($high-Greater($open, $close))/$open",
                "($high-Greater($open, $close))/($high-$low+1e-12)",
                "(Less($open, $close)-$low)/$open",
                "(Less($open, $close)-$low)/($high-$low+1e-12)",
                "(2*$close-$high-$low)/$open",
                "(2*$close-$high-$low)/($high-$low+1e-12)",
            ]
            names += [
                "KMID",
                "KLEN",
                "KMID2",
                "KUP",
                "KUP2",
                "KLOW",
                "KLOW2",
                "KSFT",
                "KSFT2",
            ]
        if "price" in config:
            windows = config["price"].get("windows", range(5))
            feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
            for field in feature:
                field = field.lower()
                fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
                names += [field.upper() + str(d) for d in windows]
        if "volume" in config:
            windows = config["volume"].get("windows", range(5))
            fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
            names += ["VOLUME" + str(d) for d in windows]
        if "rolling" in config:
            windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
            # `include`: operators to generate (None means all of them).
            # `exclude`: operators to skip; an explicit None is treated as empty.
            include = config["rolling"].get("include", None)
            exclude = config["rolling"].get("exclude") or []

            def use(*keys):
                """Return True if the operator known by any of ``keys`` should be emitted."""
                if any(k in exclude for k in keys):
                    return False
                return include is None or any(k in include for k in keys)

            if use("ROC"):
                fields += ["Ref($close, %d)/$close" % d for d in windows]
                names += ["ROC%d" % d for d in windows]
            if use("MA"):
                fields += ["Mean($close, %d)/$close" % d for d in windows]
                names += ["MA%d" % d for d in windows]
            if use("STD"):
                fields += ["Std($close, %d)/$close" % d for d in windows]
                names += ["STD%d" % d for d in windows]
            if use("BETA"):
                fields += ["Slope($close, %d)/$close" % d for d in windows]
                names += ["BETA%d" % d for d in windows]
            if use("RSQR"):
                fields += ["Rsquare($close, %d)" % d for d in windows]
                names += ["RSQR%d" % d for d in windows]
            if use("RESI"):
                fields += ["Resi($close, %d)/$close" % d for d in windows]
                names += ["RESI%d" % d for d in windows]
            if use("MAX"):
                fields += ["Max($high, %d)/$close" % d for d in windows]
                names += ["MAX%d" % d for d in windows]
            # This branch was keyed on "LOW", so include/exclude of "MIN" (the
            # emitted column prefix) had no effect. Accept both names.
            if use("MIN", "LOW"):
                fields += ["Min($low, %d)/$close" % d for d in windows]
                names += ["MIN%d" % d for d in windows]
            if use("QTLU"):
                fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
                names += ["QTLU%d" % d for d in windows]
            if use("QTLD"):
                fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
                names += ["QTLD%d" % d for d in windows]
            if use("RANK"):
                fields += ["Rank($close, %d)" % d for d in windows]
                names += ["RANK%d" % d for d in windows]
            if use("RSV"):
                fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
                names += ["RSV%d" % d for d in windows]
            if use("IMAX"):
                fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
                names += ["IMAX%d" % d for d in windows]
            if use("IMIN"):
                fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
                names += ["IMIN%d" % d for d in windows]
            if use("IMXD"):
                fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
                names += ["IMXD%d" % d for d in windows]
            if use("CORR"):
                fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
                names += ["CORR%d" % d for d in windows]
            if use("CORD"):
                fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
                names += ["CORD%d" % d for d in windows]
            if use("CNTP"):
                fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
                names += ["CNTP%d" % d for d in windows]
            if use("CNTN"):
                fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
                names += ["CNTN%d" % d for d in windows]
            if use("CNTD"):
                fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
                names += ["CNTD%d" % d for d in windows]
            if use("SUMP"):
                fields += [
                    "Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
                    for d in windows
                ]
                names += ["SUMP%d" % d for d in windows]
            if use("SUMN"):
                fields += [
                    "Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
                    for d in windows
                ]
                names += ["SUMN%d" % d for d in windows]
            if use("SUMD"):
                fields += [
                    "(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
                    "/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
                    for d in windows
                ]
                names += ["SUMD%d" % d for d in windows]
            if use("VMA"):
                fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
                names += ["VMA%d" % d for d in windows]
            if use("VSTD"):
                fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
                names += ["VSTD%d" % d for d in windows]
            if use("WVMA"):
                fields += [
                    "Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
                    % (d, d)
                    for d in windows
                ]
                names += ["WVMA%d" % d for d in windows]
            if use("VSUMP"):
                fields += [
                    "Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
                    % (d, d)
                    for d in windows
                ]
                names += ["VSUMP%d" % d for d in windows]
            if use("VSUMN"):
                fields += [
                    "Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
                    % (d, d)
                    for d in windows
                ]
                names += ["VSUMN%d" % d for d in windows]
            if use("VSUMD"):
                fields += [
                    "(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
                    "/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
                    for d in windows
                ]
                names += ["VSUMD%d" % d for d in windows]

        return fields, names
|
|
|
|
|
|
class Alpha158vwap(Alpha158):
    """``Alpha158`` features with a label computed from ``$vwap`` instead of ``$close``."""

    def get_label_config(self):
        """Return ``(fields, names)`` for the single VWAP-based label column."""
        label_expr = "Ref($vwap, -2)/Ref($vwap, -1) - 1"
        return [label_expr], ["LABEL0"]
|