mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 01:21:18 +08:00
add Cut ops
This commit is contained in:
@@ -62,9 +62,9 @@ class HighFreqHandler(DataHandlerLP):
|
||||
def get_normalized_price_feature(price_field, shift=0):
|
||||
"""Get normalized price feature ops"""
|
||||
if shift == 0:
|
||||
template_norm = "{0}/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)"
|
||||
else:
|
||||
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
|
||||
template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)"
|
||||
|
||||
feature_ops = template_norm.format(
|
||||
template_if.format(
|
||||
@@ -90,7 +90,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
|
||||
|
||||
fields += [
|
||||
"{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut({0}/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -101,7 +101,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume"]
|
||||
fields += [
|
||||
"Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format(
|
||||
"Cut(Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240), 240, None)".format(
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
@@ -112,7 +112,7 @@ class HighFreqHandler(DataHandlerLP):
|
||||
]
|
||||
names += ["$volume_1"]
|
||||
|
||||
fields += [template_paused.format("Date($close)")]
|
||||
fields += ["Cut({0}, 240, None)".format(template_paused.format("Date($close)"))]
|
||||
names += ["date"]
|
||||
return fields, names
|
||||
|
||||
@@ -149,18 +149,20 @@ class HighFreqBacktestHandler(DataHandler):
|
||||
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
|
||||
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
|
||||
fields += [
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
"Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))),
|
||||
]
|
||||
names += ["$close0"]
|
||||
fields += [
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
"Cut({0}, 240, None)".format(
|
||||
template_if.format(
|
||||
template_fillnan.format(template_paused.format("$close")),
|
||||
template_paused.format(simpson_vwap),
|
||||
)
|
||||
)
|
||||
]
|
||||
names += ["$vwap0"]
|
||||
fields += [
|
||||
"If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0}))".format(
|
||||
"Cut(If(IsNull({0}), 0, If(Or(Gt({1}, Mul(1.001, {3})), Lt({1}, Mul(0.999, {2}))), 0, {0})), 240, None)".format(
|
||||
template_paused.format("$volume"),
|
||||
template_paused.format(simpson_vwap),
|
||||
template_paused.format("$low"),
|
||||
|
||||
@@ -54,3 +54,25 @@ class IsNull(ElemOperator):
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.isnull()
|
||||
|
||||
|
||||
class Cut(ElemOperator):
|
||||
def __init__(self, feature, l=None, r=None):
|
||||
self.l = l
|
||||
self.r = r
|
||||
if (self.l != None and self.l <= 0) or (self.r != None and self.r >= 0):
|
||||
raise ValueError("Cut operator l shoud > 0 and r should < 0")
|
||||
|
||||
super(Cut, self).__init__(feature)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
return series.iloc[self.l : self.r]
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll = 0 if self.l == None else self.l
|
||||
rr = 0 if self.r == None else abs(self.r)
|
||||
lft_etd, rght_etd = self.feature.get_extended_window_size()
|
||||
lft_etd = lft_etd + ll
|
||||
rght_etd = rght_etd + rr
|
||||
return lft_etd, rght_etd
|
||||
|
||||
@@ -24,17 +24,17 @@ from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull], "expression_cache": None}
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
MARKET = ["SH600004"]
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-14 00:00:00"
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
|
||||
Reference in New Issue
Block a user