1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
This commit is contained in:
bxdd
2021-01-28 14:25:55 +00:00
parent f6dd006c35
commit ffa68fd010
5 changed files with 21 additions and 57 deletions

View File

@@ -60,9 +60,14 @@ class HighFreqHandler(DataHandlerLP):
# Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap
simpson_vwap = "($open + 2*$high + 2*$low + $close)/6"
def get_04_price_feature(price_field):
def get_normalized_price_feature(price_field, shift=0):
"""Get 0~4 column price feature ops"""
feature_ops = "{0}/Ref(DayLast({1}), 240)".format(
if shift == 0:
template_norm = "{0}/Ref(DayLast({1}), 240)"
else:
template_norm = "Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240)"
feature_ops = template_norm.format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(price_field),
@@ -71,29 +76,18 @@ class HighFreqHandler(DataHandlerLP):
)
return feature_ops
fields += [get_04_price_feature("$open")]
fields += [get_04_price_feature("$high")]
fields += [get_04_price_feature("$low")]
fields += [get_04_price_feature("$close")]
fields += [get_04_price_feature(simpson_vwap)]
fields += [get_normalized_price_feature("$open", 0)]
fields += [get_normalized_price_feature("$high", 0)]
fields += [get_normalized_price_feature("$low", 0)]
fields += [get_normalized_price_feature("$close", 0)]
fields += [get_normalized_price_feature(simpson_vwap, 0)]
names += ["$open", "$high", "$low", "$close", "$vwap"]
def get_59_price_feature(price_field):
"""Get 5~9 column price feature ops"""
feature_ops = "Ref({0}, 240)/Ref(DayLast({1}), 240)".format(
template_if.format(
template_fillnan.format(template_paused.format("$close")),
template_paused.format(price_field),
),
template_fillnan.format(template_paused.format("$close")),
)
return feature_ops
fields += [get_59_price_feature("$open")]
fields += [get_59_price_feature("$high")]
fields += [get_59_price_feature("$low")]
fields += [get_59_price_feature("$close")]
fields += [get_59_price_feature(simpson_vwap)]
fields += [get_normalized_price_feature("$open", 240)]
fields += [get_normalized_price_feature("$high", 240)]
fields += [get_normalized_price_feature("$low", 240)]
fields += [get_normalized_price_feature("$close", 240)]
fields += [get_normalized_price_feature(simpson_vwap, 240)]
names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
fields += [

View File

@@ -18,9 +18,6 @@ def get_calendar_day(freq="day", future=False):
class DayLast(ElemOperator):
def __init__(self, feature):
super(DayLast, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -28,27 +25,18 @@ class DayLast(ElemOperator):
class FFillNan(ElemOperator):
def __init__(self, feature):
super(FFillNan, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="ffill")
class BFillNan(ElemOperator):
def __init__(self, feature):
super(BFillNan, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.fillna(method="bfill")
class Date(ElemOperator):
def __init__(self, feature):
super(Date, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
_calendar = get_calendar_day(freq=freq)
series = self.feature.load(instrument, start_index, end_index, freq)
@@ -56,9 +44,6 @@ class Date(ElemOperator):
class Select(PairOperator):
def __init__(self, condition, feature):
super(Select, self).__init__(condition, feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series_condition = self.feature_left.load(instrument, start_index, end_index, freq)
series_feature = self.feature_right.load(instrument, start_index, end_index, freq)
@@ -66,9 +51,6 @@ class Select(PairOperator):
class IsNull(ElemOperator):
def __init__(self, feature):
super(IsNull, self).__init__(feature)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.isnull()

View File

@@ -18,11 +18,11 @@ from qlib.contrib.evaluate import (
risk_analysis,
)
from qlib.utils import init_instance_by_config
from qlib.utils import init_instance_by_config, exists_qlib_data
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull
@@ -102,9 +102,6 @@ class HighfreqWorkflow(object):
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
qlib.init(**QLIB_INIT_CONFIG)

View File

@@ -17,7 +17,7 @@ from qlib.contrib.evaluate import (
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
@@ -25,12 +25,9 @@ if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN, redis_port=-1)
market = "csi300"
benchmark = "SH000300"

View File

@@ -26,9 +26,6 @@ class Diff(ElemOperator):
a feature instance with first difference
"""
def __init__(self, feature):
super(Diff, self).__init__(feature, "diff")
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
return series.diff()
@@ -50,9 +47,6 @@ class Distance(PairOperator):
a feature instance with distance
"""
def __init__(self, feature_left, feature_right):
super(Distance, self).__init__(feature_left, feature_right, "distance")
def _load_internal(self, instrument, start_index, end_index, freq):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
series_right = self.feature_right.load(instrument, start_index, end_index, freq)