1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/rolling_tests/test_update_pred.py
you-n-g be4646b4b7 Adjust rolling api (#1594)
* Intermediate version

* Fix yaml template & Successfully run rolling

* Be compatible with benchmark

* Get same results with previous linear model

* Black formatting

* Update black

* Update the placeholder mechanism

* Update CI

* Update CI

* Upgrade Black

* Fix CI and simplify code

* Fix CI

* Move the data processing caching mechanism into utils.

* Adjusting DDG-DA

* Organize imports
2023-07-14 12:16:12 +08:00

136 lines
4.5 KiB
Python

import copy
import unittest
import pytest
import fire
import pandas as pd
import qlib
from qlib.data import D
from qlib.model.trainer import task_train
from qlib.tests import TestAutoData
from qlib.tests.config import CSI300_GBDT_TASK
from qlib.workflow.online.utils import OnlineToolR
from qlib.workflow.online.update import LabelUpdater
class TestRolling(TestAutoData):
    """Tests for refreshing an online recorder's predictions (``OnlineToolR``) and labels (``LabelUpdater``)."""

    @staticmethod
    def _set_time_range(task: dict, latest_date) -> None:
        """Configure, in place, the dataset segments and handler time range of ``task`` relative to ``latest_date``.

        Segments (calendar days before ``latest_date``): train [61, 41], valid [40, 21], test [20, 0].
        The handler spans from the train start to ``latest_date``; its fit range equals the train range.
        Any existing handler kwargs on ``task`` are replaced.
        """
        train_start = latest_date - pd.Timedelta(days=61)
        train_end = latest_date - pd.Timedelta(days=41)
        task["dataset"]["kwargs"]["segments"] = {
            "train": (train_start, train_end),
            "valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
            "test": (latest_date - pd.Timedelta(days=20), latest_date),
        }
        task["dataset"]["kwargs"]["handler"]["kwargs"] = {
            "start_time": train_start,
            "end_time": latest_date,
            "fit_start_time": train_start,
            "fit_end_time": train_end,
            "instruments": "csi300",
        }

    @pytest.mark.slow
    def test_update_pred(self):
        """
        This test checks that ``update_online_pred`` does not raise an error when ``to_date`` is out of the
        calendar boundary. It also checks that a ranged update (``from_date``/``to_date``) only rewrites
        predictions inside that range.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)
        task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
        exp_name = "online_srv_test"
        latest_date = D.calendar()[-1]
        self._set_time_range(task, latest_date)

        rec = task_train(task, exp_name)
        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model

        # `to_date` is beyond the last calendar date; this call must not raise.
        online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))
        good_pred = rec.load_object("pred.pkl")

        # Overwrite two disjoint date ranges with sentinel values, then update only the first range.
        from_date = latest_date - pd.Timedelta(days=20)
        to_date = latest_date - pd.Timedelta(days=10)
        mod_range = slice(from_date, to_date)
        mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
        mod_pred = good_pred.copy()
        mod_pred.loc[mod_range] = -1
        mod_pred.loc[mod_range2] = -2
        rec.save_objects(**{"pred.pkl": mod_pred})
        online_tool.update_online_pred(to_date=to_date, from_date=from_date)
        updated_pred = rec.load_object("pred.pkl")

        # Inside [from_date, to_date]: the -1 sentinels must have been replaced by the regenerated predictions.
        self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
        # Outside the updated range: the -2 sentinels must be left untouched.
        self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())

    @pytest.mark.slow
    def test_update_label(self):
        """
        This test checks that ``update_online_pred`` extends the predictions without extending the labels.
        It then checks that ``LabelUpdater`` extends the labels up to the last predicted date.
        """
        task = copy.deepcopy(CSI300_GBDT_TASK)
        task["record"] = {
            "class": "SignalRecord",
            "module_path": "qlib.workflow.record_temp",
            "kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
        }
        exp_name = "online_srv_test"
        shift = 10
        # Anchor the task `shift` calendar entries before the end, leaving room for the online update.
        latest_date = D.calendar()[-1 - shift]
        self._set_time_range(task, latest_date)

        rec = task_train(task, exp_name)
        online_tool = OnlineToolR(exp_name)
        online_tool.reset_online_tag(rec)  # set to online model
        online_tool.update_online_pred()
        new_pred = rec.load_object("pred.pkl")
        label = rec.load_object("label.pkl")
        label_date = label.dropna().index.get_level_values("datetime").max()
        pred_date = new_pred.dropna().index.get_level_values("datetime").max()
        # The prediction is updated, but the label is not updated.
        self.assertLess(label_date, pred_date)

        # Update label now
        LabelUpdater(rec).update()
        new_label = rec.load_object("label.pkl")
        new_label_date = new_label.index.get_level_values("datetime").max()
        self.assertEqual(new_label_date, pred_date)  # make sure the label is updated now
if __name__ == "__main__":
    # Executing this file directly runs its test cases with the stdlib unittest runner.
    unittest.main(module="__main__")