From e2f58274ba91f1ef43e8bc87b6cc04e67416137b Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 10 Mar 2021 10:58:49 +0000 Subject: [PATCH] update task manager --- .../taskmanager/task_manager_rolling.ipynb | 297 +----------------- examples/taskmanager/task_manager_rolling.py | 9 +- qlib/model/trainer.py | 10 +- qlib/workflow/task/collect.py | 15 +- qlib/workflow/task/gen.py | 2 +- 5 files changed, 34 insertions(+), 299 deletions(-) diff --git a/examples/taskmanager/task_manager_rolling.ipynb b/examples/taskmanager/task_manager_rolling.ipynb index 43ae5b1d1..e8ec8d4a7 100644 --- a/examples/taskmanager/task_manager_rolling.ipynb +++ b/examples/taskmanager/task_manager_rolling.ipynb @@ -2,32 +2,11 @@ "cells": [ { "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "import mlflow\n", - "mlflow.end_run()" - ] - }, - { - "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": { "collapsed": true }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "[8348:MainThread](2021-03-09 14:55:48,543) INFO - qlib.Initialization - [config.py:279] - default_conf: client.\n", - "[8348:MainThread](2021-03-09 14:55:50,592) WARNING - qlib.Initialization - [config.py:295] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", - "[8348:MainThread](2021-03-09 14:55:50,597) INFO - qlib.Initialization - [__init__.py:48] - qlib successfully initialized based on client settings.\n", - "[8348:MainThread](2021-03-09 14:55:50,601) INFO - qlib.Initialization - [__init__.py:49] - data_path=C:\\Users\\lzh222333\\.qlib\\qlib_data\\cn_data\n" - ] - } - ], + "outputs": [], "source": [ "import qlib\n", "from qlib.config import REG_CN\n", @@ -96,109 +75,17 @@ "\n", "C[\"mongo\"] = {\n", " \"task_url\" : \"mongodb://localhost:27017/\", # maybe you need to change it to your url\n", - " \"task_db_name\" : \"rolling_db3\"\n", + " \"task_db_name\" : \"rolling_db\"\n", "}\n", "\n", - "exp_name = 'rolling_exp3' # experiment name, will be used as the experiment in MLflow\n", - "task_pool = 'rolling_task3' # task pool name, will be used as the document in MongoDB" + "exp_name = 'rolling_exp' # experiment name, will be used as the experiment in MLflow\n", + "task_pool = 'rolling_task' # task pool name, will be used as the document in MongoDB" ] }, { "cell_type": "code", - "execution_count": 25, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[{'dataset': {'class': 'DatasetH',\n", - " 'kwargs': {'handler': {'class': 'Alpha158',\n", - " 'kwargs': {'end_time': '2020-08-01',\n", - " 'fit_end_time': '2014-12-31',\n", - " 'fit_start_time': '2008-01-01',\n", - " 'instruments': 'csi100',\n", - " 'start_time': '2008-01-01'},\n", - " 'module_path': 'qlib.contrib.data.handler'},\n", - " 'segments': {'test': (Timestamp('2017-01-03 00:00:00'),\n", - " Timestamp('2019-04-08 00:00:00')),\n", - " 'train': (Timestamp('2008-01-02 00:00:00'),\n", - " Timestamp('2014-12-31 00:00:00')),\n", - " 'valid': (Timestamp('2015-01-05 00:00:00'),\n", - " Timestamp('2016-12-30 00:00:00'))}},\n", - " 'module_path': 'qlib.data.dataset'},\n", - " 'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'},\n", - " 'record': [{'class': 'SignalRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'},\n", - " {'class': 'SigAnaRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'}],\n", - " 'task_key': 1},\n", - " {'dataset': {'class': 'DatasetH',\n", - " 'kwargs': {'handler': {'class': 'Alpha158',\n", - " 'kwargs': {'end_time': '2020-08-01',\n", - " 'fit_end_time': '2014-12-31',\n", - " 'fit_start_time': '2008-01-01',\n", - " 'instruments': 'csi100',\n", - " 'start_time': '2008-01-01'},\n", - " 'module_path': 'qlib.contrib.data.handler'},\n", - " 'segments': {'test': (Timestamp('2019-04-09 00:00:00'),\n", - " Timestamp('2021-07-12 00:00:00')),\n", - " 'train': (Timestamp('2010-04-23 00:00:00'),\n", - " Timestamp('2017-05-24 00:00:00')),\n", - " 'valid': (Timestamp('2017-05-25 00:00:00'),\n", - " Timestamp('2019-04-08 00:00:00'))}},\n", - " 'module_path': 'qlib.data.dataset'},\n", - " 'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'},\n", - " 'record': [{'class': 'SignalRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'},\n", - " {'class': 'SigAnaRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'}],\n", - " 'task_key': 1},\n", - " {'dataset': {'class': 'DatasetH',\n", - " 'kwargs': {'handler': {'class': 'Alpha158',\n", - " 'kwargs': {'end_time': '2020-08-01',\n", - " 'fit_end_time': '2014-12-31',\n", - " 'fit_start_time': '2008-01-01',\n", - " 'instruments': 'csi100',\n", - " 'start_time': '2008-01-01'},\n", - " 'module_path': 'qlib.contrib.data.handler'},\n", - " 'segments': {'test': (Timestamp('2017-01-03 00:00:00'),\n", - " Timestamp('2019-04-08 00:00:00')),\n", - " 'train': (Timestamp('2008-01-02 00:00:00'),\n", - " Timestamp('2014-12-31 00:00:00')),\n", - " 'valid': (Timestamp('2015-01-05 00:00:00'),\n", - " Timestamp('2016-12-30 00:00:00'))}},\n", - " 'module_path': 'qlib.data.dataset'},\n", - " 'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'},\n", - " 'record': [{'class': 'SignalRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'},\n", - " {'class': 'SigAnaRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'}],\n", - " 'task_key': 'task_lgb'},\n", - " {'dataset': {'class': 'DatasetH',\n", - " 'kwargs': {'handler': {'class': 'Alpha158',\n", - " 'kwargs': {'end_time': '2020-08-01',\n", - " 'fit_end_time': '2014-12-31',\n", - " 'fit_start_time': '2008-01-01',\n", - " 'instruments': 'csi100',\n", - " 'start_time': '2008-01-01'},\n", - " 'module_path': 'qlib.contrib.data.handler'},\n", - " 'segments': {'test': (Timestamp('2019-04-09 00:00:00'),\n", - " Timestamp('2021-07-12 00:00:00')),\n", - " 'train': (Timestamp('2010-04-23 00:00:00'),\n", - " Timestamp('2017-05-24 00:00:00')),\n", - " 'valid': (Timestamp('2017-05-25 00:00:00'),\n", - " Timestamp('2019-04-08 00:00:00'))}},\n", - " 'module_path': 'qlib.data.dataset'},\n", - " 'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'},\n", - " 'record': [{'class': 'SignalRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'},\n", - " {'class': 'SigAnaRecord',\n", - " 'module_path': 'qlib.workflow.record_temp'}],\n", - " 'task_key': 'task_lgb'}]\n", - "Total Tasks, New Tasks: 4 0\n" - ] - } - ], + "execution_count": null, + "outputs": [], "source": [ "tasks = task_generator(\n", " xgboost_task_template, # default task name\n", @@ -206,8 +93,8 @@ " task_lgb=lgb_task_template # use \"task_lgb\" as the task name\n", ")\n", "# Uncomment next two lines to see the generated tasks\n", - "from pprint import pprint\n", - "pprint(tasks)\n", + "# from pprint import pprint\n", + "# pprint(tasks)\n", "tm = TaskManager(task_pool=task_pool)\n", "tm.create_task(tasks) # all tasks will be saved to MongoDB" ], @@ -220,133 +107,8 @@ }, { "cell_type": "code", - "execution_count": 26, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "2021-03-09 14:55:51.600 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2019-04-08 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 1}\n", - "[8348:MainThread](2021-03-09 14:56:46,051) INFO - qlib.timer - [log.py:81] - Time cost: 54.448s | Loading data Done\n", - "[8348:MainThread](2021-03-09 14:56:46,440) INFO - qlib.timer - [log.py:81] - Time cost: 0.322s | DropnaLabel Done\n", - "[8348:MainThread](2021-03-09 14:56:52,461) INFO - qlib.timer - [log.py:81] - Time cost: 6.019s | CSZScoreNorm Done\n", - "[8348:MainThread](2021-03-09 14:56:52,464) INFO - qlib.timer - [log.py:81] - Time cost: 6.411s | fit & process data Done\n", - "[8348:MainThread](2021-03-09 14:56:52,468) INFO - qlib.timer - [log.py:81] - Time cost: 60.865s | Init data Done\n", - "[8348:MainThread](2021-03-09 14:56:52,471) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", - "[8348:MainThread](2021-03-09 14:56:52,500) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", - "[8348:MainThread](2021-03-09 14:56:52,567) INFO - qlib.workflow - [recorder.py:233] - Recorder dd6bceb6d319493686ab6565633c0b5a starts running under Experiment 2 ...\n", - "[0]\ttrain-rmse:1.05165\tvalid-rmse:1.05565\n", - "[20]\ttrain-rmse:0.97071\tvalid-rmse:1.00077\n", - "[40]\ttrain-rmse:0.95124\tvalid-rmse:1.00609\n", - "[59]\ttrain-rmse:0.93833\tvalid-rmse:1.00945\n", - "[8348:MainThread](2021-03-09 14:59:37,266) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", - "'The following are prediction results of the XGBModel model.'\n", - " score\n", - "datetime instrument \n", - "2017-01-03 SH600000 -0.103259\n", - " SH600010 -0.084365\n", - " SH600015 -0.107433\n", - " SH600016 -0.064723\n", - " SH600018 -0.038639\n", - "{'IC': 0.05347474869798698,\n", - " 'ICIR': 0.29781294430945265,\n", - " 'Rank IC': 0.0484064337863249,\n", - " 'Rank ICIR': 0.36035393716962033}\n", - "2021-03-09 14:59:38.633 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2010-04-23 00:00:00'), Timestamp('2017-05-24 00:00:00')), 'valid': (Timestamp('2017-05-25 00:00:00'), Timestamp('2019-04-08 00:00:00')), 'test': (Timestamp('2019-04-09 00:00:00'), Timestamp('2021-07-12 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 1}\n", - "[8348:MainThread](2021-03-09 15:00:36,591) INFO - qlib.timer - [log.py:81] - Time cost: 57.954s | Loading data Done\n", - "[8348:MainThread](2021-03-09 15:00:36,997) INFO - qlib.timer - [log.py:81] - Time cost: 0.338s | DropnaLabel Done\n", - "[8348:MainThread](2021-03-09 15:00:43,728) INFO - qlib.timer - [log.py:81] - Time cost: 6.728s | CSZScoreNorm Done\n", - "[8348:MainThread](2021-03-09 15:00:43,731) INFO - qlib.timer - [log.py:81] - Time cost: 7.137s | fit & process data Done\n", - "[8348:MainThread](2021-03-09 15:00:43,734) INFO - qlib.timer - [log.py:81] - Time cost: 65.097s | Init data Done\n", - "[8348:MainThread](2021-03-09 15:00:43,740) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", - "[8348:MainThread](2021-03-09 15:00:43,768) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", - "[8348:MainThread](2021-03-09 15:00:43,851) INFO - qlib.workflow - [recorder.py:233] - Recorder de2f892b569c436ba642a23e99f4f2b0 starts running under Experiment 2 ...\n", - "[0]\ttrain-rmse:1.05178\tvalid-rmse:1.05345\n", - "[20]\ttrain-rmse:0.96764\tvalid-rmse:0.99546\n", - "[40]\ttrain-rmse:0.94957\tvalid-rmse:0.99798\n", - "[57]\ttrain-rmse:0.93592\tvalid-rmse:1.00030\n", - "[8348:MainThread](2021-03-09 15:03:12,764) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", - "'The following are prediction results of the XGBModel model.'\n", - " score\n", - "datetime instrument \n", - "2019-04-09 SH600000 0.006996\n", - " SH600009 -0.102482\n", - " SH600010 0.016398\n", - " SH600011 0.004459\n", - " SH600015 -0.128315\n", - "{'IC': 0.013224093132176661,\n", - " 'ICIR': 0.08254897170570956,\n", - " 'Rank IC': 0.02472594591723197,\n", - " 'Rank ICIR': 0.16330982475433398}\n", - "2021-03-09 15:03:13.593 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2019-04-08 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 'task_lgb'}\n", - "[8348:MainThread](2021-03-09 15:04:06,545) INFO - qlib.timer - [log.py:81] - Time cost: 52.814s | Loading data Done\n", - "[8348:MainThread](2021-03-09 15:04:06,919) INFO - qlib.timer - [log.py:81] - Time cost: 0.312s | DropnaLabel Done\n", - "[8348:MainThread](2021-03-09 15:04:12,850) INFO - qlib.timer - [log.py:81] - Time cost: 5.928s | CSZScoreNorm Done\n", - "[8348:MainThread](2021-03-09 15:04:12,853) INFO - qlib.timer - [log.py:81] - Time cost: 6.305s | fit & process data Done\n", - "[8348:MainThread](2021-03-09 15:04:12,856) INFO - qlib.timer - [log.py:81] - Time cost: 59.125s | Init data Done\n", - "[8348:MainThread](2021-03-09 15:04:12,859) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", - "[8348:MainThread](2021-03-09 15:04:12,888) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", - "[8348:MainThread](2021-03-09 15:04:12,958) INFO - qlib.workflow - [recorder.py:233] - Recorder 15df799127a74656829978c1b9352e60 starts running under Experiment 2 ...\n", - "Training until validation scores don't improve for 50 rounds\n", - "[20]\ttrain's l2: 0.970491\tvalid's l2: 0.987723\n", - "[40]\ttrain's l2: 0.957984\tvalid's l2: 0.990056\n", - "[60]\ttrain's l2: 0.947201\tvalid's l2: 0.991459\n", - "Early stopping, best iteration is:\n", - "[18]\ttrain's l2: 0.971834\tvalid's l2: 0.987481\n", - "[8348:MainThread](2021-03-09 15:04:19,847) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", - "'The following are prediction results of the LGBModel model.'\n", - " score\n", - "datetime instrument \n", - "2017-01-03 SH600000 -0.013089\n", - " SH600010 -0.006642\n", - " SH600015 -0.035137\n", - " SH600016 -0.034634\n", - " SH600018 -0.029493\n", - "{'IC': 0.05704431372255674,\n", - " 'ICIR': 0.28879437007622133,\n", - " 'Rank IC': 0.05181220321608411,\n", - " 'Rank ICIR': 0.3233833799543165}\n", - "2021-03-09 15:04:21.111 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2010-04-23 00:00:00'), Timestamp('2017-05-24 00:00:00')), 'valid': (Timestamp('2017-05-25 00:00:00'), Timestamp('2019-04-08 00:00:00')), 'test': (Timestamp('2019-04-09 00:00:00'), Timestamp('2021-07-12 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 'task_lgb'}\n", - "[8348:MainThread](2021-03-09 15:05:16,072) INFO - qlib.timer - [log.py:81] - Time cost: 54.958s | Loading data Done\n", - "[8348:MainThread](2021-03-09 15:05:16,466) INFO - qlib.timer - [log.py:81] - Time cost: 0.334s | DropnaLabel Done\n", - "[8348:MainThread](2021-03-09 15:05:22,281) INFO - qlib.timer - [log.py:81] - Time cost: 5.812s | CSZScoreNorm Done\n", - "[8348:MainThread](2021-03-09 15:05:22,283) INFO - qlib.timer - [log.py:81] - Time cost: 6.209s | fit & process data Done\n", - "[8348:MainThread](2021-03-09 15:05:22,286) INFO - qlib.timer - [log.py:81] - Time cost: 61.172s | Init data Done\n", - "[8348:MainThread](2021-03-09 15:05:22,291) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", - "[8348:MainThread](2021-03-09 15:05:22,317) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", - "[8348:MainThread](2021-03-09 15:05:22,386) INFO - qlib.workflow - [recorder.py:233] - Recorder 0c814539f55842b9b6310843fc5ec708 starts running under Experiment 2 ...\n", - "Training until validation scores don't improve for 50 rounds\n", - "[20]\ttrain's l2: 0.969033\tvalid's l2: 0.98571\n", - "[40]\ttrain's l2: 0.955399\tvalid's l2: 0.986164\n", - "[60]\ttrain's l2: 0.943514\tvalid's l2: 0.986301\n", - "Early stopping, best iteration is:\n", - "[26]\ttrain's l2: 0.964587\tvalid's l2: 0.985356\n", - "[8348:MainThread](2021-03-09 15:05:29,546) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", - "'The following are prediction results of the LGBModel model.'\n", - " score\n", - "datetime instrument \n", - "2019-04-09 SH600000 0.029586\n", - " SH600009 0.004306\n", - " SH600010 -0.004411\n", - " SH600011 0.002707\n", - " SH600015 -0.029124\n", - "{'IC': 0.020784811232504984,\n", - " 'ICIR': 0.11590182186569555,\n", - " 'Rank IC': 0.028925697036767055,\n", - " 'Rank ICIR': 0.16388058980901396}\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "True" - ] - }, - "metadata": {}, - "execution_count": 26 - } - ], + "execution_count": null, + "outputs": [], "source": [ "from qlib.workflow.task.manage import run_task\n", "from qlib.workflow.task.collect import TaskCollector\n", @@ -363,43 +125,12 @@ }, { "cell_type": "code", - "execution_count": 27, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Loading data: 100%|██████████| 4/4 [00:00<00:00, 37.38it/s]\n" - ] - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "{('task_lgb', '2019-04-08'): datetime instrument\n", - " 2017-01-03 SH600000 -0.013089\n", - " SH600010 -0.006642\n", - " SH600015 -0.035137\n", - " SH600016 -0.034634\n", - " SH600018 -0.029493\n", - " ... \n", - " 2019-04-08 SZ002415 0.049199\n", - " SZ002450 -0.013450\n", - " SZ002594 0.022395\n", - " SZ002736 0.091433\n", - " SZ300059 -0.016237\n", - " Name: score, Length: 55000, dtype: float64}" - ] - }, - "metadata": {}, - "execution_count": 27 - } - ], + "execution_count": null, + "outputs": [], "source": [ "def get_task_key(task):\n", " task_key = task[\"task_key\"]\n", " rolling_end_timestamp = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n", - " #rolling_end_datatime = rolling_end_timestamp.to_pydatetime()\n", " return task_key, rolling_end_timestamp.strftime('%Y-%m-%d')\n", "\n", "def my_filter(task):\n", @@ -410,7 +141,7 @@ " return False\n", "\n", "# name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n", - "pred_rolling = TaskCollector.collect(exp_name, get_task_key, my_filter) \n", + "pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter) \n", "pred_rolling" ], "metadata": { diff --git a/examples/taskmanager/task_manager_rolling.py b/examples/taskmanager/task_manager_rolling.py index 7519bc4be..db5d1817f 100644 --- a/examples/taskmanager/task_manager_rolling.py +++ b/examples/taskmanager/task_manager_rolling.py @@ -85,7 +85,7 @@ tm = TaskManager(task_pool=task_pool) tm.create_task(tasks) # all tasks will be saved to MongoDB from qlib.workflow.task.manage import run_task -from qlib.workflow.task.collect import RollingCollector +from qlib.workflow.task.collect import TaskCollector from qlib.model.trainer import task_train run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method @@ -93,7 +93,6 @@ run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be tr def get_task_key(task_config): task_key = task_config["task_key"] rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1] - #rolling_end_datatime = rolling_end_timestamp.to_pydatetime() return task_key, rolling_end_timestamp.strftime('%Y-%m-%d') def my_filter(task_config): @@ -103,6 +102,6 @@ def my_filter(task_config): return True return False -collector = RollingCollector(get_task_key, my_filter) -pred_rolling = collector(exp_name) # name tasks by "get_task_key" and filter tasks by "my_filter" -print(pred_rolling) \ No newline at end of file +# name tasks by "get_task_key" and filter tasks by "my_filter" +pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter) +pred_rolling \ No newline at end of file diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 91061636d..82d770b96 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -6,7 +6,7 @@ from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord -def task_train(task_config: dict, experiment_name: str): +def task_train(task_config: dict, experiment_name: str) -> str: """ task based training @@ -16,6 +16,11 @@ def task_train(task_config: dict, experiment_name: str): A dict describes a task setting. experiment_name: str The name of experiment + + Returns + ---------- + rid : str + The id of the recorder of this task """ # model initiaiton @@ -29,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str): model.fit(dataset) recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) - R.save_objects(param=task_config) # keep the original format and datatype + R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype # generate records: prediction, backtest, and analysis records = task_config.get("record", []) @@ -48,3 +53,4 @@ def task_train(task_config: dict, experiment_name: str): record["kwargs"].update(rconf) ar = init_instance_by_config(record) ar.generate() + return record.info["id"] diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 4562a1cec..834189561 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,7 +1,7 @@ from qlib.workflow import R import pandas as pd from typing import Union -from tqdm.auto import tqdm +from qlib import get_module_logger class TaskCollector: @@ -10,10 +10,8 @@ class TaskCollector: """ @staticmethod - def collect( - experiment_name: str, - get_key_func, - filter_func=None, + def collect_predictions( + experiment_name: str, get_key_func, filter_func=None, ): """ @@ -34,8 +32,8 @@ class TaskCollector: recs = exp.list_recorders() recs_flt = {} - for rid, rec in tqdm(recs.items(), desc="Loading data"): - params = rec.load_object("param") + for rid, rec in recs.items(): + params = rec.load_object("task.pkl") if rec.status == rec.STATUS_FI: if filter_func is None or filter_func(params): rec.params = params @@ -57,6 +55,7 @@ class TaskCollector: pred = pd.concat(pred_l).sort_index() reduce_group[k] = pred + get_module_logger("TaskCollector").info(f"Collect {len(reduce_group)} predictions in {experiment_name}") return reduce_group @@ -82,7 +81,7 @@ class RollingCollector: recs_flt = {} for rid, rec in tqdm(recs.items(), desc="Loading data"): - params = rec.load_object("param") + params = rec.load_object("task.pkl") if rec.status == rec.STATUS_FI: if self.flt_func is None or self.flt_func(params): rec.params = params diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index b1c2e0ce2..19793c485 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -168,7 +168,7 @@ class RollingGen(TaskGen): # 1) prepare the end point segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1] - # 2) and the init test segments + # 2) and init test segments test_start_idx = self.ta.align_idx(segments[self.test_key][0]) segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) else: