diff --git a/docs/advanced/task_managment.rst b/docs/advanced/task_managment.rst new file mode 100644 index 000000000..78ac62410 --- /dev/null +++ b/docs/advanced/task_managment.rst @@ -0,0 +1,67 @@ +.. _task_managment: + +================================= +Task Management +================================= +.. currentmodule:: qlib + + +Introduction +============= + +The `Workflow <../component/introduction.html>`_ part introduce how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``. To automatically generate and execute different tasks, Task Management module provide a whole process including `Task Generating`_, `Task Storing`_, `Task Running`_ and `Task Collecting`_. +With this module, users can run their ``task`` automatically at different periods, in different losses or even by different models. + +An example of the entire process is shown `here <>`_. + +Task Generating +=============== +A ``task`` consists of `Model`, `Dataset`, `Record` or anything added by users. +The specific task template can be viewed in +`Task Section <../component/workflow.html#task-section>`_. +Even though the task template is fixed, Users can use ``TaskGen`` to generate different ``task`` by task template. + +Here is the base class of TaskGen: + +.. autoclass:: qlib.workflow.task.gen.TaskGen + :members: + +``Qlib`` provider a class `RollingGen`_ to generate a list of ``task`` of dataset in different date segments. +This allows users to verify the effect of data from different periods on the model in one experiment. + +Task Storing +=============== +In order to achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_. +Users **MUST** finished the configuration of `MongoDB `_ when using this module. + +Users need to provide the url and database of ``task`` storing like this. + + .. code-block:: python + + from qlib.config import C + C["mongo"] = { + "task_url" : "mongodb://localhost:27017/", # maybe you need to change it to your url + "task_db_name" : "rolling_db" # you can custom database name + } + +The CRUD methods of ``task`` can be found in TaskManager. More methods can be seen in the `Github`_. + +.. autoclass:: qlib.workflow.task.manage.TaskManager + :members: + +Task Running +=============== +After generating and storing those ``task``, it's time to run the ``task`` in the *WAITING* status. +``qlib`` provide a method to run those ``task`` in task pool, however users can also customize how tasks are executed. +An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly. +It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*. + +.. autofunction:: qlib.workflow.task.manage.run_task + +Task Collecting +=============== +To see the results of ``task`` after running, ``Qlib`` provide a task collector to collect the tasks by filter condition (optional). +The collector will return a dict of filtered key (users defined by task config) and value (predict scores from ``pred.pkl``). + +.. autoclass:: qlib.workflow.task.collect.TaskCollector + :members: \ No newline at end of file diff --git a/examples/taskmanager/task_manager_rolling.ipynb b/examples/taskmanager/task_manager_rolling.ipynb new file mode 100644 index 000000000..43ae5b1d1 --- /dev/null +++ b/examples/taskmanager/task_manager_rolling.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import mlflow\n", + "mlflow.end_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "collapsed": true + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[8348:MainThread](2021-03-09 14:55:48,543) INFO - qlib.Initialization - [config.py:279] - default_conf: client.\n", + "[8348:MainThread](2021-03-09 14:55:50,592) WARNING - qlib.Initialization - [config.py:295] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n", + "[8348:MainThread](2021-03-09 14:55:50,597) INFO - qlib.Initialization - [__init__.py:48] - qlib successfully initialized based on client settings.\n", + "[8348:MainThread](2021-03-09 14:55:50,601) INFO - qlib.Initialization - [__init__.py:49] - data_path=C:\\Users\\lzh222333\\.qlib\\qlib_data\\cn_data\n" + ] + } + ], + "source": [ + "import qlib\n", + "from qlib.config import REG_CN\n", + "from qlib.workflow.task.gen import RollingGen, task_generator\n", + "from qlib.workflow.task.manage import TaskManager\n", + "from qlib.config import C\n", + "\n", + "data_handler_template = {\n", + " \"start_time\": \"2008-01-01\",\n", + " \"end_time\": \"2020-08-01\",\n", + " \"fit_start_time\": \"2008-01-01\",\n", + " \"fit_end_time\": \"2014-12-31\",\n", + " \"instruments\": 'csi100',\n", + "}\n", + "\n", + "dataset_template = {\n", + " \"class\": \"DatasetH\",\n", + " \"module_path\": \"qlib.data.dataset\",\n", + " \"kwargs\": {\n", + " \"handler\": {\n", + " \"class\": \"Alpha158\",\n", + " \"module_path\": \"qlib.contrib.data.handler\",\n", + " \"kwargs\": data_handler_template,\n", + " },\n", + " \"segments\": {\n", + " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", + " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", + " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", + " },\n", + " },\n", + " }\n", + "\n", + "record_template = [\n", + " {\n", + " \"class\": \"SignalRecord\",\n", + " \"module_path\": \"qlib.workflow.record_temp\",\n", + " },\n", + " {\n", + " \"class\": \"SigAnaRecord\",\n", + " \"module_path\": \"qlib.workflow.record_temp\",\n", + " }\n", + "]\n", + "\n", + "# use lgb\n", + "lgb_task_template = {\n", + " \"model\": {\n", + " \"class\": \"LGBModel\",\n", + " \"module_path\": \"qlib.contrib.model.gbdt\",\n", + " },\n", + " \"dataset\": dataset_template,\n", + " \"record\": record_template,\n", + "}\n", + "\n", + "# use xgboost\n", + "xgboost_task_template = {\n", + " \"model\": {\n", + " \"class\": \"XGBModel\",\n", + " \"module_path\": \"qlib.contrib.model.xgboost\",\n", + " },\n", + " \"dataset\": dataset_template,\n", + " \"record\": record_template,\n", + "}\n", + "\n", + "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", + "qlib.init(provider_uri=provider_uri, region=REG_CN)\n", + "\n", + "C[\"mongo\"] = {\n", + " \"task_url\" : \"mongodb://localhost:27017/\", # maybe you need to change it to your url\n", + " \"task_db_name\" : \"rolling_db3\"\n", + "}\n", + "\n", + "exp_name = 'rolling_exp3' # experiment name, will be used as the experiment in MLflow\n", + "task_pool = 'rolling_task3' # task pool name, will be used as the document in MongoDB" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[{'dataset': {'class': 'DatasetH',\n", + " 'kwargs': {'handler': {'class': 'Alpha158',\n", + " 'kwargs': {'end_time': '2020-08-01',\n", + " 'fit_end_time': '2014-12-31',\n", + " 'fit_start_time': '2008-01-01',\n", + " 'instruments': 'csi100',\n", + " 'start_time': '2008-01-01'},\n", + " 'module_path': 'qlib.contrib.data.handler'},\n", + " 'segments': {'test': (Timestamp('2017-01-03 00:00:00'),\n", + " Timestamp('2019-04-08 00:00:00')),\n", + " 'train': (Timestamp('2008-01-02 00:00:00'),\n", + " Timestamp('2014-12-31 00:00:00')),\n", + " 'valid': (Timestamp('2015-01-05 00:00:00'),\n", + " Timestamp('2016-12-30 00:00:00'))}},\n", + " 'module_path': 'qlib.data.dataset'},\n", + " 'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'},\n", + " 'record': [{'class': 'SignalRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'},\n", + " {'class': 'SigAnaRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'}],\n", + " 'task_key': 1},\n", + " {'dataset': {'class': 'DatasetH',\n", + " 'kwargs': {'handler': {'class': 'Alpha158',\n", + " 'kwargs': {'end_time': '2020-08-01',\n", + " 'fit_end_time': '2014-12-31',\n", + " 'fit_start_time': '2008-01-01',\n", + " 'instruments': 'csi100',\n", + " 'start_time': '2008-01-01'},\n", + " 'module_path': 'qlib.contrib.data.handler'},\n", + " 'segments': {'test': (Timestamp('2019-04-09 00:00:00'),\n", + " Timestamp('2021-07-12 00:00:00')),\n", + " 'train': (Timestamp('2010-04-23 00:00:00'),\n", + " Timestamp('2017-05-24 00:00:00')),\n", + " 'valid': (Timestamp('2017-05-25 00:00:00'),\n", + " Timestamp('2019-04-08 00:00:00'))}},\n", + " 'module_path': 'qlib.data.dataset'},\n", + " 'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'},\n", + " 'record': [{'class': 'SignalRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'},\n", + " {'class': 'SigAnaRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'}],\n", + " 'task_key': 1},\n", + " {'dataset': {'class': 'DatasetH',\n", + " 'kwargs': {'handler': {'class': 'Alpha158',\n", + " 'kwargs': {'end_time': '2020-08-01',\n", + " 'fit_end_time': '2014-12-31',\n", + " 'fit_start_time': '2008-01-01',\n", + " 'instruments': 'csi100',\n", + " 'start_time': '2008-01-01'},\n", + " 'module_path': 'qlib.contrib.data.handler'},\n", + " 'segments': {'test': (Timestamp('2017-01-03 00:00:00'),\n", + " Timestamp('2019-04-08 00:00:00')),\n", + " 'train': (Timestamp('2008-01-02 00:00:00'),\n", + " Timestamp('2014-12-31 00:00:00')),\n", + " 'valid': (Timestamp('2015-01-05 00:00:00'),\n", + " Timestamp('2016-12-30 00:00:00'))}},\n", + " 'module_path': 'qlib.data.dataset'},\n", + " 'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'},\n", + " 'record': [{'class': 'SignalRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'},\n", + " {'class': 'SigAnaRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'}],\n", + " 'task_key': 'task_lgb'},\n", + " {'dataset': {'class': 'DatasetH',\n", + " 'kwargs': {'handler': {'class': 'Alpha158',\n", + " 'kwargs': {'end_time': '2020-08-01',\n", + " 'fit_end_time': '2014-12-31',\n", + " 'fit_start_time': '2008-01-01',\n", + " 'instruments': 'csi100',\n", + " 'start_time': '2008-01-01'},\n", + " 'module_path': 'qlib.contrib.data.handler'},\n", + " 'segments': {'test': (Timestamp('2019-04-09 00:00:00'),\n", + " Timestamp('2021-07-12 00:00:00')),\n", + " 'train': (Timestamp('2010-04-23 00:00:00'),\n", + " Timestamp('2017-05-24 00:00:00')),\n", + " 'valid': (Timestamp('2017-05-25 00:00:00'),\n", + " Timestamp('2019-04-08 00:00:00'))}},\n", + " 'module_path': 'qlib.data.dataset'},\n", + " 'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'},\n", + " 'record': [{'class': 'SignalRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'},\n", + " {'class': 'SigAnaRecord',\n", + " 'module_path': 'qlib.workflow.record_temp'}],\n", + " 'task_key': 'task_lgb'}]\n", + "Total Tasks, New Tasks: 4 0\n" + ] + } + ], + "source": [ + "tasks = task_generator(\n", + " xgboost_task_template, # default task name\n", + " RollingGen(step=550,rtype=RollingGen.ROLL_SD), # generate different date segment\n", + " task_lgb=lgb_task_template # use \"task_lgb\" as the task name\n", + ")\n", + "# Uncomment next two lines to see the generated tasks\n", + "from pprint import pprint\n", + "pprint(tasks)\n", + "tm = TaskManager(task_pool=task_pool)\n", + "tm.create_task(tasks) # all tasks will be saved to MongoDB" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2021-03-09 14:55:51.600 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2019-04-08 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 1}\n", + "[8348:MainThread](2021-03-09 14:56:46,051) INFO - qlib.timer - [log.py:81] - Time cost: 54.448s | Loading data Done\n", + "[8348:MainThread](2021-03-09 14:56:46,440) INFO - qlib.timer - [log.py:81] - Time cost: 0.322s | DropnaLabel Done\n", + "[8348:MainThread](2021-03-09 14:56:52,461) INFO - qlib.timer - [log.py:81] - Time cost: 6.019s | CSZScoreNorm Done\n", + "[8348:MainThread](2021-03-09 14:56:52,464) INFO - qlib.timer - [log.py:81] - Time cost: 6.411s | fit & process data Done\n", + "[8348:MainThread](2021-03-09 14:56:52,468) INFO - qlib.timer - [log.py:81] - Time cost: 60.865s | Init data Done\n", + "[8348:MainThread](2021-03-09 14:56:52,471) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", + "[8348:MainThread](2021-03-09 14:56:52,500) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", + "[8348:MainThread](2021-03-09 14:56:52,567) INFO - qlib.workflow - [recorder.py:233] - Recorder dd6bceb6d319493686ab6565633c0b5a starts running under Experiment 2 ...\n", + "[0]\ttrain-rmse:1.05165\tvalid-rmse:1.05565\n", + "[20]\ttrain-rmse:0.97071\tvalid-rmse:1.00077\n", + "[40]\ttrain-rmse:0.95124\tvalid-rmse:1.00609\n", + "[59]\ttrain-rmse:0.93833\tvalid-rmse:1.00945\n", + "[8348:MainThread](2021-03-09 14:59:37,266) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", + "'The following are prediction results of the XGBModel model.'\n", + " score\n", + "datetime instrument \n", + "2017-01-03 SH600000 -0.103259\n", + " SH600010 -0.084365\n", + " SH600015 -0.107433\n", + " SH600016 -0.064723\n", + " SH600018 -0.038639\n", + "{'IC': 0.05347474869798698,\n", + " 'ICIR': 0.29781294430945265,\n", + " 'Rank IC': 0.0484064337863249,\n", + " 'Rank ICIR': 0.36035393716962033}\n", + "2021-03-09 14:59:38.633 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'XGBModel', 'module_path': 'qlib.contrib.model.xgboost'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2010-04-23 00:00:00'), Timestamp('2017-05-24 00:00:00')), 'valid': (Timestamp('2017-05-25 00:00:00'), Timestamp('2019-04-08 00:00:00')), 'test': (Timestamp('2019-04-09 00:00:00'), Timestamp('2021-07-12 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 1}\n", + "[8348:MainThread](2021-03-09 15:00:36,591) INFO - qlib.timer - [log.py:81] - Time cost: 57.954s | Loading data Done\n", + "[8348:MainThread](2021-03-09 15:00:36,997) INFO - qlib.timer - [log.py:81] - Time cost: 0.338s | DropnaLabel Done\n", + "[8348:MainThread](2021-03-09 15:00:43,728) INFO - qlib.timer - [log.py:81] - Time cost: 6.728s | CSZScoreNorm Done\n", + "[8348:MainThread](2021-03-09 15:00:43,731) INFO - qlib.timer - [log.py:81] - Time cost: 7.137s | fit & process data Done\n", + "[8348:MainThread](2021-03-09 15:00:43,734) INFO - qlib.timer - [log.py:81] - Time cost: 65.097s | Init data Done\n", + "[8348:MainThread](2021-03-09 15:00:43,740) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", + "[8348:MainThread](2021-03-09 15:00:43,768) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", + "[8348:MainThread](2021-03-09 15:00:43,851) INFO - qlib.workflow - [recorder.py:233] - Recorder de2f892b569c436ba642a23e99f4f2b0 starts running under Experiment 2 ...\n", + "[0]\ttrain-rmse:1.05178\tvalid-rmse:1.05345\n", + "[20]\ttrain-rmse:0.96764\tvalid-rmse:0.99546\n", + "[40]\ttrain-rmse:0.94957\tvalid-rmse:0.99798\n", + "[57]\ttrain-rmse:0.93592\tvalid-rmse:1.00030\n", + "[8348:MainThread](2021-03-09 15:03:12,764) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", + "'The following are prediction results of the XGBModel model.'\n", + " score\n", + "datetime instrument \n", + "2019-04-09 SH600000 0.006996\n", + " SH600009 -0.102482\n", + " SH600010 0.016398\n", + " SH600011 0.004459\n", + " SH600015 -0.128315\n", + "{'IC': 0.013224093132176661,\n", + " 'ICIR': 0.08254897170570956,\n", + " 'Rank IC': 0.02472594591723197,\n", + " 'Rank ICIR': 0.16330982475433398}\n", + "2021-03-09 15:03:13.593 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')), 'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')), 'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2019-04-08 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 'task_lgb'}\n", + "[8348:MainThread](2021-03-09 15:04:06,545) INFO - qlib.timer - [log.py:81] - Time cost: 52.814s | Loading data Done\n", + "[8348:MainThread](2021-03-09 15:04:06,919) INFO - qlib.timer - [log.py:81] - Time cost: 0.312s | DropnaLabel Done\n", + "[8348:MainThread](2021-03-09 15:04:12,850) INFO - qlib.timer - [log.py:81] - Time cost: 5.928s | CSZScoreNorm Done\n", + "[8348:MainThread](2021-03-09 15:04:12,853) INFO - qlib.timer - [log.py:81] - Time cost: 6.305s | fit & process data Done\n", + "[8348:MainThread](2021-03-09 15:04:12,856) INFO - qlib.timer - [log.py:81] - Time cost: 59.125s | Init data Done\n", + "[8348:MainThread](2021-03-09 15:04:12,859) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", + "[8348:MainThread](2021-03-09 15:04:12,888) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", + "[8348:MainThread](2021-03-09 15:04:12,958) INFO - qlib.workflow - [recorder.py:233] - Recorder 15df799127a74656829978c1b9352e60 starts running under Experiment 2 ...\n", + "Training until validation scores don't improve for 50 rounds\n", + "[20]\ttrain's l2: 0.970491\tvalid's l2: 0.987723\n", + "[40]\ttrain's l2: 0.957984\tvalid's l2: 0.990056\n", + "[60]\ttrain's l2: 0.947201\tvalid's l2: 0.991459\n", + "Early stopping, best iteration is:\n", + "[18]\ttrain's l2: 0.971834\tvalid's l2: 0.987481\n", + "[8348:MainThread](2021-03-09 15:04:19,847) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", + "'The following are prediction results of the LGBModel model.'\n", + " score\n", + "datetime instrument \n", + "2017-01-03 SH600000 -0.013089\n", + " SH600010 -0.006642\n", + " SH600015 -0.035137\n", + " SH600016 -0.034634\n", + " SH600018 -0.029493\n", + "{'IC': 0.05704431372255674,\n", + " 'ICIR': 0.28879437007622133,\n", + " 'Rank IC': 0.05181220321608411,\n", + " 'Rank ICIR': 0.3233833799543165}\n", + "2021-03-09 15:04:21.111 | INFO | qlib.workflow.task.manage:run_task:355 - {'model': {'class': 'LGBModel', 'module_path': 'qlib.contrib.model.gbdt'}, 'dataset': {'class': 'DatasetH', 'module_path': 'qlib.data.dataset', 'kwargs': {'handler': {'class': 'Alpha158', 'module_path': 'qlib.contrib.data.handler', 'kwargs': {'start_time': '2008-01-01', 'end_time': '2020-08-01', 'fit_start_time': '2008-01-01', 'fit_end_time': '2014-12-31', 'instruments': 'csi100'}}, 'segments': {'train': (Timestamp('2010-04-23 00:00:00'), Timestamp('2017-05-24 00:00:00')), 'valid': (Timestamp('2017-05-25 00:00:00'), Timestamp('2019-04-08 00:00:00')), 'test': (Timestamp('2019-04-09 00:00:00'), Timestamp('2021-07-12 00:00:00'))}}}, 'record': [{'class': 'SignalRecord', 'module_path': 'qlib.workflow.record_temp'}, {'class': 'SigAnaRecord', 'module_path': 'qlib.workflow.record_temp'}], 'task_key': 'task_lgb'}\n", + "[8348:MainThread](2021-03-09 15:05:16,072) INFO - qlib.timer - [log.py:81] - Time cost: 54.958s | Loading data Done\n", + "[8348:MainThread](2021-03-09 15:05:16,466) INFO - qlib.timer - [log.py:81] - Time cost: 0.334s | DropnaLabel Done\n", + "[8348:MainThread](2021-03-09 15:05:22,281) INFO - qlib.timer - [log.py:81] - Time cost: 5.812s | CSZScoreNorm Done\n", + "[8348:MainThread](2021-03-09 15:05:22,283) INFO - qlib.timer - [log.py:81] - Time cost: 6.209s | fit & process data Done\n", + "[8348:MainThread](2021-03-09 15:05:22,286) INFO - qlib.timer - [log.py:81] - Time cost: 61.172s | Init data Done\n", + "[8348:MainThread](2021-03-09 15:05:22,291) INFO - qlib.workflow - [expm.py:245] - No tracking URI is provided. Use the default tracking URI.\n", + "[8348:MainThread](2021-03-09 15:05:22,317) INFO - qlib.workflow - [exp.py:181] - Experiment 2 starts running ...\n", + "[8348:MainThread](2021-03-09 15:05:22,386) INFO - qlib.workflow - [recorder.py:233] - Recorder 0c814539f55842b9b6310843fc5ec708 starts running under Experiment 2 ...\n", + "Training until validation scores don't improve for 50 rounds\n", + "[20]\ttrain's l2: 0.969033\tvalid's l2: 0.98571\n", + "[40]\ttrain's l2: 0.955399\tvalid's l2: 0.986164\n", + "[60]\ttrain's l2: 0.943514\tvalid's l2: 0.986301\n", + "Early stopping, best iteration is:\n", + "[26]\ttrain's l2: 0.964587\tvalid's l2: 0.985356\n", + "[8348:MainThread](2021-03-09 15:05:29,546) INFO - qlib.workflow - [record_temp.py:126] - Signal record 'pred.pkl' has been saved as the artifact of the Experiment 2\n", + "'The following are prediction results of the LGBModel model.'\n", + " score\n", + "datetime instrument \n", + "2019-04-09 SH600000 0.029586\n", + " SH600009 0.004306\n", + " SH600010 -0.004411\n", + " SH600011 0.002707\n", + " SH600015 -0.029124\n", + "{'IC': 0.020784811232504984,\n", + " 'ICIR': 0.11590182186569555,\n", + " 'Rank IC': 0.028925697036767055,\n", + " 'Rank ICIR': 0.16388058980901396}\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 26 + } + ], + "source": [ + "from qlib.workflow.task.manage import run_task\n", + "from qlib.workflow.task.collect import TaskCollector\n", + "from qlib.model.trainer import task_train\n", + "\n", + "run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using \"task_train\" method" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Loading data: 100%|██████████| 4/4 [00:00<00:00, 37.38it/s]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{('task_lgb', '2019-04-08'): datetime instrument\n", + " 2017-01-03 SH600000 -0.013089\n", + " SH600010 -0.006642\n", + " SH600015 -0.035137\n", + " SH600016 -0.034634\n", + " SH600018 -0.029493\n", + " ... \n", + " 2019-04-08 SZ002415 0.049199\n", + " SZ002450 -0.013450\n", + " SZ002594 0.022395\n", + " SZ002736 0.091433\n", + " SZ300059 -0.016237\n", + " Name: score, Length: 55000, dtype: float64}" + ] + }, + "metadata": {}, + "execution_count": 27 + } + ], + "source": [ + "def get_task_key(task):\n", + " task_key = task[\"task_key\"]\n", + " rolling_end_timestamp = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n", + " #rolling_end_datatime = rolling_end_timestamp.to_pydatetime()\n", + " return task_key, rolling_end_timestamp.strftime('%Y-%m-%d')\n", + "\n", + "def my_filter(task):\n", + " # only choose the results of \"task_lgb\" and test segment end in 2019 from all tasks\n", + " task_key, rolling_end = get_task_key(task)\n", + " if task_key==\"task_lgb\" and rolling_end.startswith('2019'):\n", + " return True\n", + " return False\n", + "\n", + "# name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n", + "pred_rolling = TaskCollector.collect(exp_name, get_task_key, my_filter) \n", + "pred_rolling" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "3.6.5-final" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/taskmanager/task_manager_rolling.py b/examples/taskmanager/task_manager_rolling.py new file mode 100644 index 000000000..7519bc4be --- /dev/null +++ b/examples/taskmanager/task_manager_rolling.py @@ -0,0 +1,108 @@ +import qlib +from qlib.config import REG_CN +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager +from qlib.config import C + +data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": 'csi100', +} + +dataset_config = { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + } + +record_config = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + } +] + +# use lgb +task_lgb_config = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": dataset_config, + "record": record_config, +} + +# use xgboost +task_xgboost_config = { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": dataset_config, + "record": record_config, +} + +provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir +qlib.init(provider_uri=provider_uri, region=REG_CN) + +C["mongo"] = { + "task_url" : "mongodb://localhost:27017/", # maybe you need to change it to your url + "task_db_name" : "rolling_db" +} + +exp_name = 'rolling_exp' # experiment name, will be used as the experiment in MLflow +task_pool = 'rolling_task' # task pool name, will be used as the document in MongoDB + +tasks = task_generator( + task_xgboost_config, # default task name + RollingGen(step=550,rtype=RollingGen.ROLL_SD), # generate different date segment + task_lgb=task_lgb_config # use "task_lgb" as the task name +) + +# Uncomment next two lines to see the generated tasks +# from pprint import pprint +# pprint(tasks) + +tm = TaskManager(task_pool=task_pool) +tm.create_task(tasks) # all tasks will be saved to MongoDB + +from qlib.workflow.task.manage import run_task +from qlib.workflow.task.collect import RollingCollector +from qlib.model.trainer import task_train + +run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method + +def get_task_key(task_config): + task_key = task_config["task_key"] + rolling_end_timestamp = task_config["dataset"]["kwargs"]["segments"]["test"][1] + #rolling_end_datatime = rolling_end_timestamp.to_pydatetime() + return task_key, rolling_end_timestamp.strftime('%Y-%m-%d') + +def my_filter(task_config): + # only choose the results of "task_lgb" and test in 2019 from all tasks + task_key, rolling_end = get_task_key(task_config) + if task_key=="task_lgb" and rolling_end.startswith('2019'): + return True + return False + +collector = RollingCollector(get_task_key, my_filter) +pred_rolling = collector(exp_name) # name tasks by "get_task_key" and filter tasks by "my_filter" +print(pred_rolling) \ No newline at end of file diff --git a/examples/workflow_task_rolling.ipynb b/examples/workflow_task_rolling.ipynb deleted file mode 100644 index c2d399be0..000000000 --- a/examples/workflow_task_rolling.ipynb +++ /dev/null @@ -1,177 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [ - "import qlib\n", - "from qlib.config import REG_CN\n", - "from qlib.workflow.task.gen import RollingGen, task_generator\n", - "from qlib.workflow.task.manage import TaskManager\n", - "from qlib.config import C\n", - "\n", - "data_handler_config = {\n", - " \"start_time\": \"2008-01-01\",\n", - " \"end_time\": \"2020-08-01\",\n", - " \"fit_start_time\": \"2008-01-01\",\n", - " \"fit_end_time\": \"2014-12-31\",\n", - " \"instruments\": 'csi100',\n", - "}\n", - "\n", - "dataset_config = {\n", - " \"class\": \"DatasetH\",\n", - " \"module_path\": \"qlib.data.dataset\",\n", - " \"kwargs\": {\n", - " \"handler\": {\n", - " \"class\": \"Alpha158\",\n", - " \"module_path\": \"qlib.contrib.data.handler\",\n", - " \"kwargs\": data_handler_config,\n", - " },\n", - " \"segments\": {\n", - " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n", - " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n", - " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n", - " },\n", - " },\n", - " }\n", - "\n", - "record_config = [\n", - " {\n", - " \"class\": \"SignalRecord\",\n", - " \"module_path\": \"qlib.workflow.record_temp\",\n", - " },\n", - " {\n", - " \"class\": \"SigAnaRecord\",\n", - " \"module_path\": \"qlib.workflow.record_temp\",\n", - " }\n", - "]\n", - "\n", - "# use lgb\n", - "task_lgb_config = {\n", - " \"model\": {\n", - " \"class\": \"LGBModel\",\n", - " \"module_path\": \"qlib.contrib.model.gbdt\",\n", - " },\n", - " \"dataset\": dataset_config,\n", - " \"record\": record_config,\n", - "}\n", - "\n", - "# use xgboost\n", - "task_xgboost_config = {\n", - " \"model\": {\n", - " \"class\": \"XGBModel\",\n", - " \"module_path\": \"qlib.contrib.model.xgboost\",\n", - " },\n", - " \"dataset\": dataset_config,\n", - " \"record\": record_config,\n", - "}\n", - "provider_uri = r\"../qlib-main/qlib_data/cn_data\"\n", - "#provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n", - "qlib.init(provider_uri=provider_uri, region=REG_CN)\n", - "\n", - "C[\"mongo\"] = {\n", - " \"task_url\" : \"mongodb://localhost:27017/\", # maybe you need to change it to your url\n", - " \"task_db_name\" : \"rolling_db\"\n", - "}\n", - "\n", - "exp_name = 'rolling_exp' # experiment name, will be used as the experiment in MLflow\n", - "task_pool = 'rolling_task' # task pool name, will be used as the document in MongoDB" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "tasks = task_generator(\n", - " task_xgboost_config, # default task name\n", - " RollingGen(step=550,rtype=RollingGen.ROLL_SD), # generate different date segment\n", - " task_lgb=task_lgb_config # use \"task_lgb\" as the task name\n", - ")\n", - "# Uncomment next two lines to see the generated tasks\n", - "# from pprint import pprint\n", - "# pprint(tasks)\n", - "tm = TaskManager(task_pool=task_pool)\n", - "tm.create_task(tasks) # all tasks will be saved to MongoDB" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "from qlib.workflow.task.manage import run_task\n", - "from qlib.workflow.task.collect import RollingCollector\n", - "from qlib.model.trainer import task_train\n", - "\n", - "run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using \"task_train\" method" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "def get_task_key(task_config):\n", - " task_key = task_config[\"task_key\"]\n", - " rolling_end_timestamp = task_config[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n", - " rolling_end_datatime = rolling_end_timestamp.to_pydatetime()\n", - " return task_key, rolling_end_datatime.strftime('%Y-%m-%d')\n", - "\n", - "def my_filter(task_config):\n", - " # only choose the results of \"task_lgb\" and test in 2019 from all tasks\n", - " task_key, rolling_end = get_task_key(task_config)\n", - " if task_key==\"task_lgb\" and rolling_end.startswith('2019'):\n", - " return True\n", - " return False\n", - "\n", - "collector = RollingCollector(get_task_key, my_filter)\n", - "pred_rolling = collector(exp_name) # name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n", - "pred_rolling" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file