mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
format code and add example
This commit is contained in:
@@ -1,176 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import qlib\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.workflow.task.gen import RollingGen, task_generator\n",
|
||||
"from qlib.workflow.task.manage import TaskManager\n",
|
||||
"from qlib.config import C\n",
|
||||
"\n",
|
||||
"data_handler_template = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": 'csi100',\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"dataset_template = {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": data_handler_template,\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"record_template = [\n",
|
||||
" {\n",
|
||||
" \"class\": \"SignalRecord\",\n",
|
||||
" \"module_path\": \"qlib.workflow.record_temp\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"class\": \"SigAnaRecord\",\n",
|
||||
" \"module_path\": \"qlib.workflow.record_temp\",\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# use lgb\n",
|
||||
"lgb_task_template = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" },\n",
|
||||
" \"dataset\": dataset_template,\n",
|
||||
" \"record\": record_template,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# use xgboost\n",
|
||||
"xgboost_task_template = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"XGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.xgboost\",\n",
|
||||
" },\n",
|
||||
" \"dataset\": dataset_template,\n",
|
||||
" \"record\": record_template,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)\n",
|
||||
"\n",
|
||||
"C[\"mongo\"] = {\n",
|
||||
" \"task_url\" : \"mongodb://localhost:27017/\", # maybe you need to change it to your url\n",
|
||||
" \"task_db_name\" : \"rolling_db\"\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"exp_name = 'rolling_exp' # experiment name, will be used as the experiment in MLflow\n",
|
||||
"task_pool = 'rolling_task' # task pool name, will be used as the document in MongoDB"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tasks = task_generator(\n",
|
||||
" xgboost_task_template, # default task name\n",
|
||||
" RollingGen(step=550,rtype=RollingGen.ROLL_SD), # generate different date segment\n",
|
||||
" task_lgb=lgb_task_template # use \"task_lgb\" as the task name\n",
|
||||
")\n",
|
||||
"# Uncomment next two lines to see the generated tasks\n",
|
||||
"# from pprint import pprint\n",
|
||||
"# pprint(tasks)\n",
|
||||
"tm = TaskManager(task_pool=task_pool)\n",
|
||||
"tm.create_task(tasks) # all tasks will be saved to MongoDB"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.workflow.task.manage import run_task\n",
|
||||
"from qlib.workflow.task.collect import TaskCollector\n",
|
||||
"from qlib.model.trainer import task_train\n",
|
||||
"\n",
|
||||
"run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using \"task_train\" method"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_task_key(task):\n",
|
||||
" task_key = task[\"task_key\"]\n",
|
||||
" rolling_end_timestamp = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n",
|
||||
" return task_key, rolling_end_timestamp.strftime('%Y-%m-%d')\n",
|
||||
"\n",
|
||||
"def my_filter(task):\n",
|
||||
" # only choose the results of \"task_lgb\" and test segment end in 2019 from all tasks\n",
|
||||
" task_key, rolling_end = get_task_key(task)\n",
|
||||
" if task_key==\"task_lgb\" and rolling_end.startswith('2019'):\n",
|
||||
" return True\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
"# name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n",
|
||||
"pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter) \n",
|
||||
"pred_rolling"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "3.6.5-final"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -16,7 +16,11 @@ dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {"class": "Alpha158", "module_path": "qlib.contrib.data.handler", "kwargs": data_handler_config,},
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
@@ -26,20 +30,32 @@ dataset_config = {
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{"class": "SignalRecord", "module_path": "qlib.workflow.record_temp",},
|
||||
{"class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp",},
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb
|
||||
task_lgb_config = {
|
||||
"model": {"class": "LGBModel", "module_path": "qlib.contrib.model.gbdt",},
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost
|
||||
task_xgboost_config = {
|
||||
"model": {"class": "XGBModel", "module_path": "qlib.contrib.model.xgboost",},
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
244
examples/taskmanager/task_manager_rolling_with_updating.py
Normal file
244
examples/taskmanager/task_manager_rolling_with_updating.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import qlib
|
||||
import fire
|
||||
import mlflow
|
||||
from qlib.config import C
|
||||
from qlib.workflow import R
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
|
||||
|
||||
# Keyword arguments handed to the "Alpha158" handler referenced by dataset_config.
data_handler_config = {
    "start_time": "2013-01-01",
    "end_time": "2020-09-25",
    "fit_start_time": "2013-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi100",
}

# Dataset description: which dataset class to build, which handler feeds it,
# and the initial train/valid/test segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config,
        },
        "segments": {
            "train": ("2013-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2015-12-31"),
            "test": ("2016-01-01", "2017-01-01"),
        },
    },
}

# Record templates attached to every task.
record_config = [
    {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp"},
    {"class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp"},
]

# Task template using the LightGBM model.
# Note: both task templates share the same dataset_config / record_config objects.
task_lgb_config = {
    "model": {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
    },
    "dataset": dataset_config,
    "record": record_config,
}

# Task template using the XGBoost model.
task_xgboost_config = {
    "model": {
        "class": "XGBModel",
        "module_path": "qlib.contrib.model.xgboost",
    },
    "dataset": dataset_config,
    "record": record_config,
}
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(**kwargs):
    """Generate rolling tasks from the given task templates and print them.

    A ``RollingGen`` is built from the module-level ``rolling_step`` with
    ``RollingGen.ROLL_EX`` and passed, together with the templates, to
    ``task_generator``.

    Parameters
    ----------
    **kwargs
        forwarded unchanged to ``task_generator``; callers in this file pass
        ``task name -> task config`` pairs

    Returns
    -------
    whatever ``task_generator`` returns (the generated tasks)
    """
    from pprint import pprint

    print("========================================= task_generating =========================================")

    generated = task_generator(RollingGen(step=rolling_step, rtype=RollingGen.ROLL_EX), **kwargs)

    # Pretty-print the generated tasks so they are easy to inspect
    pprint(generated)

    return generated
|
||||
|
||||
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks):
    """Store the given tasks in the task pool named by the module-level ``task_pool``.

    Parameters
    ----------
    tasks
        the tasks to store, passed as-is to ``TaskManager.create_task``
    """
    print("========================================= task_storing =========================================")
    # all tasks will be saved to MongoDB
    TaskManager(task_pool=task_pool).create_task(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running():
    """Run the tasks of ``task_pool`` with ``task_train`` under the experiment ``exp_name``."""
    banner = "========================================= task_running ========================================="
    print(banner)
    # every task in the pool is trained by the "task_train" method
    run_task(task_train, task_pool, experiment_name=exp_name)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
    """Collect the predictions of the "task_lgb" tasks of ``exp_name`` and print them.

    Predictions are named by ``get_task_key`` and filtered by ``lgb_filter``,
    both defined locally and handed to ``TaskCollector.collect_predictions``.
    """
    print("========================================= task_collecting =========================================")

    def get_task_key(task_config):
        """Return ``(task_key, end of the test segment formatted as "%Y-%m-%d")``."""
        task_key = task_config["task_key"]
        rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1]
        # An open-ended test segment (end is None) falls back to TimeAdjuster().last_date()
        if rolling_end_timestamp is None:
            rolling_end_timestamp = TimeAdjuster().last_date()
        return task_key, rolling_end_timestamp.strftime("%Y-%m-%d")

    def lgb_filter(task_config):
        """Keep only the results of "task_lgb"."""
        # Only the task key matters here, so the rolling end date is not computed.
        return task_config["task_key"] == "task_lgb"

    task_collector = TaskCollector(exp_name)
    # name tasks by "get_task_key" and filter tasks by "lgb_filter"
    pred_rolling = task_collector.collect_predictions(get_task_key, lgb_filter)
    print(pred_rolling)
|
||||
|
||||
|
||||
# Reset everything to its initial state; be careful to save important data first
|
||||
def reset(force_end=False):
    """Remove the stored tasks and the finished recorders of the experiment.

    Parameters
    ----------
    force_end : bool
        if True, additionally call ``mlflow.end_run()`` after the clean-up
    """
    print("========================================= reset =========================================")
    TaskManager(task_pool=task_pool).remove()

    exp = R.get_exp(experiment_name=exp_name)
    recs = TaskCollector(exp_name).list_recorders(only_finished=True)

    for rid in recs:
        exp.delete_recorder(rid)

    if force_end:
        # Ending the run stays best-effort: a failure must not abort the reset,
        # but it is reported instead of being silently swallowed.
        try:
            mlflow.end_run()
        except Exception as err:
            print(f"Failed to end the active MLflow run: {err!r}")
|
||||
|
||||
|
||||
def set_online_model_to_latest():
    """Mark the latest records of the experiment ``exp_name`` as the online models.

    If no records can be collected, a message is printed and nothing is changed.
    """
    print(
        "========================================= set_online_model_to_latest ========================================="
    )
    model_updater = ModelUpdater(experiment_name=exp_name)
    # The second element (the latest test segment) is not needed here.
    latest_records, _ = model_updater.collect_latest_records()
    # NOTE(review): the collector appears to return (None, None) when no recorder
    # is found; guard so that case does not fail on `.values()`.
    if latest_records is None:
        print("No latest records were collected, so no online model is set.")
        return
    model_updater.reset_online_model(latest_records.values())
|
||||
|
||||
|
||||
# Run this first to see the whole workflow of Task Management
|
||||
def first_run():
    """Run the workflow from scratch.

    Steps, in order: reset, generate tasks, store them, run them, collect the
    predictions and set the latest models as the online models.
    """
    print("========================================= first_run =========================================")
    reset(force_end=True)

    # the keyword names "task_xgboost" and "task_lgb" are used as the task names
    generated = task_generating(task_xgboost=task_xgboost_config, task_lgb=task_lgb_config)
    task_storing(generated)
    task_running()
    task_collecting()
    set_online_model_to_latest()
|
||||
|
||||
|
||||
# Update the predictions of online model
|
||||
def update_predictions():
    """Update the predictions of the online models of the experiment ``exp_name``."""
    print("========================================= update_predictions =========================================")
    ModelUpdater(experiment_name=exp_name).update_online_pred()
|
||||
|
||||
|
||||
# Update the models using the latest date and set them to online model
|
||||
def update_model():
    """Retrain the latest models when needed and set them as the online models.

    The latest records are collected first. If the interval between the latest
    calendar date and the start of the latest test segment exceeds the
    module-level ``rolling_step``, new tasks are generated from the old task
    configs (test segment extended to the latest date), stored, run and
    collected. In both cases the latest records are finally set as the online
    models. If no records exist yet, a message is printed and nothing is done.
    """
    print("========================================= update_model =========================================")
    # get the latest recorders
    model_updater = ModelUpdater(experiment_name=exp_name)
    latest_records, latest_test = model_updater.collect_latest_records()
    # NOTE(review): the collector appears to return (None, None) when no recorder
    # is found; guard so that case does not fail on `latest_test[0]`.
    if latest_records is None:
        print("No trained records were found. Run first_run before update_model.")
        return

    # date adjustment based on trade day of Calendar in Qlib
    time_adjuster = TimeAdjuster()
    calendar_latest = time_adjuster.last_date()
    print("The latest date is ", calendar_latest)
    if time_adjuster.cal_interval(calendar_latest, latest_test[0]) > rolling_step:
        print("Need update models!")
        tasks = {}
        for rec in latest_records.values():
            old_task = rec.task
            test_begin = old_task["dataset"]["kwargs"]["segments"]["test"][0]
            # modify the test segment to generate new tasks
            old_task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
            tasks[old_task["task_key"]] = old_task

        # retrain the latest model
        new_tasks = task_generating(**tasks)
        task_storing(new_tasks)
        task_running()
        task_collecting()
        latest_records, _ = model_updater.collect_latest_records()

    # set the latest model to online model
    model_updater.reset_online_model(latest_records.values())
|
||||
|
||||
|
||||
# Run whole workflow completely
|
||||
def whole_workflow():
    """Run the complete workflow once: first run, prediction update, model update."""
    print("========================================= whole_workflow =========================================")
    stages = (
        first_run,  # run this at the first time
        update_predictions,  # run this every day
        update_model,  # run this every "rolling_step" days
    )
    for stage in stages:
        stage()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Commands:
    #   train the first version of the models:
    #       python task_manager_rolling_with_updating.py first_run
    #   update the models using the latest date and set them as the online models:
    #       python task_manager_rolling_with_updating.py update_model
    #   update the predictions to the latest date:
    #       python task_manager_rolling_with_updating.py update_predictions
    #   run the whole workflow completely:
    #       python task_manager_rolling_with_updating.py whole_workflow

    # ------------------ finish the configuration below before running ------------------

    provider_uri = "~/.qlib/qlib_data/cn_data"  # data_dir
    qlib.init(provider_uri=provider_uri, region=REG_CN)

    C["mongo"] = dict(
        task_url="mongodb://localhost:27017/",  # your MongoDB url
        task_db_name="rolling_db",  # database name
    )

    exp_name = "rolling_exp"  # experiment name, will be used as the experiment in MLflow
    task_pool = "rolling_task"  # task pool name, will be used as the document in MongoDB
    rolling_step = 550

    # -----------------------------------------------------------------------------------

    fire.Fire()
|
||||
@@ -5,12 +5,12 @@ from qlib.config import REG_CN
|
||||
import fire
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
@@ -44,22 +44,26 @@ task = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"record": {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp",},
|
||||
"record": {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
}
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
|
||||
def first_train(experiment_name="online_svr"):
    """Train the module-level ``task`` and set the result as the online model.

    Parameters
    ----------
    experiment_name : str
        the experiment used for training and for the ``ModelUpdater``
    """
    qlib.init(provider_uri=provider_uri, region=REG_CN)
    model_updater = ModelUpdater(experiment_name)

    rid = task_train(task_config=task, experiment_name=experiment_name)
    # reset_online_model is annotated to take a list and iterates over its
    # argument, so the single result is wrapped in a list instead of passed bare.
    model_updater.reset_online_model([rid])
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_svr"):
|
||||
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
model_updater = ModelUpdater(experiment_name)
|
||||
|
||||
@@ -68,8 +72,9 @@ def update_online_pred(experiment_name="online_svr"):
|
||||
print(rid)
|
||||
|
||||
model_updater.update_online_pred()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire()
|
||||
# to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
|
||||
@@ -18,7 +18,7 @@ class TaskCollector:
|
||||
|
||||
def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False):
|
||||
"""
|
||||
Return a dict of {rid:Recorder} by recorder filter and task filter. It is not necessary to use those filter.
|
||||
Return a dict of {rid:Recorder} by recorder filter and task filter. It is not necessary to use those filter.
|
||||
If you don't train with "task_train", then there is no "task" which includes the task config.
|
||||
If there is a "task", then it will become rec.task which can be get simply.
|
||||
|
||||
@@ -48,7 +48,7 @@ class TaskCollector:
|
||||
if task_filter_func is not None:
|
||||
only_have_task = True
|
||||
for rid, rec in recs.items():
|
||||
if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False:
|
||||
if (only_finished and rec.status == rec.STATUS_FI) or only_finished == False:
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
task = None
|
||||
try:
|
||||
@@ -60,7 +60,7 @@ class TaskCollector:
|
||||
if task_filter_func is None or task_filter_func(task):
|
||||
rec.task = task
|
||||
recs_flt[rid] = rec
|
||||
|
||||
|
||||
return recs_flt
|
||||
|
||||
def collect_predictions(
|
||||
@@ -83,7 +83,7 @@ class TaskCollector:
|
||||
dict
|
||||
the dict of predictions
|
||||
"""
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
@@ -108,18 +108,17 @@ class TaskCollector:
|
||||
self,
|
||||
task_filter_func=None,
|
||||
):
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True)
|
||||
|
||||
recs_flt = self.list_recorders(task_filter_func=task_filter_func, only_have_task=True)
|
||||
|
||||
if len(recs_flt) == 0:
|
||||
self.logger.warning("Can not collect any recorders...")
|
||||
return None, None
|
||||
max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values())
|
||||
max_test = max(rec.task["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
|
||||
|
||||
latest_record = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.task['dataset']['kwargs']['segments']['test'] == max_test:
|
||||
if rec.task["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_record[rid] = rec
|
||||
|
||||
|
||||
self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}")
|
||||
return latest_record, max_test
|
||||
|
||||
@@ -21,6 +21,7 @@ from .utils import get_mongodb
|
||||
from qlib import auto_init
|
||||
from qlib import get_module_logger
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""TaskManager
|
||||
here is the what will a task looks like
|
||||
@@ -361,4 +362,3 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Union,List
|
||||
from typing import Union, List
|
||||
from qlib.workflow import R
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.data import D
|
||||
@@ -10,6 +10,7 @@ from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import TaskCollector
|
||||
|
||||
|
||||
class ModelUpdater(TaskCollector):
|
||||
"""
|
||||
The model updater to re-train model or update predictions
|
||||
@@ -31,7 +32,7 @@ class ModelUpdater(TaskCollector):
|
||||
self.exp = R.get_exp(experiment_name=experiment_name)
|
||||
self.logger = get_module_logger("ModelUpdater")
|
||||
|
||||
def set_online_model(self, recorder: Union[str,Recorder]):
|
||||
def set_online_model(self, recorder: Union[str, Recorder]):
|
||||
"""online model will be identified at the tags of the record
|
||||
|
||||
Parameters
|
||||
@@ -39,12 +40,12 @@ class ModelUpdater(TaskCollector):
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
if isinstance(recorder,str):
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_TRUE})
|
||||
|
||||
def cancel_online_model(self, recorder: Union[str,Recorder]):
|
||||
if isinstance(recorder,str):
|
||||
def cancel_online_model(self, recorder: Union[str, Recorder]):
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_FALSE})
|
||||
|
||||
@@ -53,7 +54,7 @@ class ModelUpdater(TaskCollector):
|
||||
for rid, rec in recs.items():
|
||||
self.cancel_online_model(rec)
|
||||
|
||||
def reset_online_model(self, recorders: List[Union[str,Recorder]]):
|
||||
def reset_online_model(self, recorders: List[Union[str, Recorder]]):
|
||||
"""cancel all online model and reset the given model to online model
|
||||
|
||||
Parameters
|
||||
@@ -65,7 +66,7 @@ class ModelUpdater(TaskCollector):
|
||||
for rec_or_rid in recorders:
|
||||
self.set_online_model(rec_or_rid)
|
||||
|
||||
def update_pred(self, recorder: Union[str,Recorder]):
|
||||
def update_pred(self, recorder: Union[str, Recorder]):
|
||||
"""update predictions to the latest day in Calendar based on rid
|
||||
|
||||
Parameters
|
||||
@@ -73,17 +74,19 @@ class ModelUpdater(TaskCollector):
|
||||
recorder: Union[str,Recorder]
|
||||
the id of a Recorder or the Recorder instance
|
||||
"""
|
||||
if isinstance(recorder,str):
|
||||
if isinstance(recorder, str):
|
||||
recorder = self.exp.get_recorder(recorder_id=recorder)
|
||||
old_pred = recorder.load_object("pred.pkl")
|
||||
last_end = old_pred.index.get_level_values("datetime").max()
|
||||
task_config = recorder.load_object("task") # recorder.task
|
||||
task_config = recorder.load_object("task") # recorder.task
|
||||
|
||||
# updated to the latest trading day
|
||||
cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None)
|
||||
|
||||
if len(cal) == 0:
|
||||
self.logger.info(f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update.")
|
||||
self.logger.info(
|
||||
f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update."
|
||||
)
|
||||
return
|
||||
|
||||
start_time, end_time = cal[0], cal[-1]
|
||||
@@ -100,7 +103,9 @@ class ModelUpdater(TaskCollector):
|
||||
|
||||
recorder.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}.")
|
||||
self.logger.info(
|
||||
f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}."
|
||||
)
|
||||
|
||||
def update_all_pred(self, rec_filter_func=None):
|
||||
"""update all predictions in this experiment after filter.
|
||||
@@ -126,7 +131,7 @@ class ModelUpdater(TaskCollector):
|
||||
the count of updated record
|
||||
|
||||
"""
|
||||
recs = self.list_recorders(rec_filter_func=rec_filter_func,only_have_task=True)
|
||||
recs = self.list_recorders(rec_filter_func=rec_filter_func, only_have_task=True)
|
||||
for rid, rec in recs.items():
|
||||
self.update_pred(rec)
|
||||
return len(recs)
|
||||
@@ -150,5 +155,5 @@ class ModelUpdater(TaskCollector):
|
||||
dict
|
||||
{rid : recorder of the online model}
|
||||
"""
|
||||
|
||||
|
||||
return self.list_recorders(rec_filter_func=self.online_filter)
|
||||
|
||||
@@ -50,6 +50,7 @@ class TimeAdjuster:
|
||||
if idx >= len(self.cals):
|
||||
return None
|
||||
return self.cals[idx]
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
(Deprecated)
|
||||
|
||||
Reference in New Issue
Block a user