diff --git a/examples/workflow_task_rolling.ipynb b/examples/workflow_task_rolling.ipynb
new file mode 100644
index 000000000..c2d399be0
--- /dev/null
+++ b/examples/workflow_task_rolling.ipynb
@@ -0,0 +1,179 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import qlib\n",
+    "from qlib.config import REG_CN\n",
+    "from qlib.workflow.task.gen import RollingGen, task_generator\n",
+    "from qlib.workflow.task.manage import TaskManager\n",
+    "from qlib.config import C\n",
+    "\n",
+    "data_handler_config = {\n",
+    "    \"start_time\": \"2008-01-01\",\n",
+    "    \"end_time\": \"2020-08-01\",\n",
+    "    \"fit_start_time\": \"2008-01-01\",\n",
+    "    \"fit_end_time\": \"2014-12-31\",\n",
+    "    \"instruments\": 'csi100',\n",
+    "}\n",
+    "\n",
+    "dataset_config = {\n",
+    "    \"class\": \"DatasetH\",\n",
+    "    \"module_path\": \"qlib.data.dataset\",\n",
+    "    \"kwargs\": {\n",
+    "        \"handler\": {\n",
+    "            \"class\": \"Alpha158\",\n",
+    "            \"module_path\": \"qlib.contrib.data.handler\",\n",
+    "            \"kwargs\": data_handler_config,\n",
+    "        },\n",
+    "        \"segments\": {\n",
+    "            \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
+    "            \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
+    "            \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
+    "        },\n",
+    "    },\n",
+    "}\n",
+    "\n",
+    "record_config = [\n",
+    "    {\n",
+    "        \"class\": \"SignalRecord\",\n",
+    "        \"module_path\": \"qlib.workflow.record_temp\",\n",
+    "    },\n",
+    "    {\n",
+    "        \"class\": \"SigAnaRecord\",\n",
+    "        \"module_path\": \"qlib.workflow.record_temp\",\n",
+    "    },\n",
+    "]\n",
+    "\n",
+    "# use lgb\n",
+    "task_lgb_config = {\n",
+    "    \"model\": {\n",
+    "        \"class\": \"LGBModel\",\n",
+    "        \"module_path\": \"qlib.contrib.model.gbdt\",\n",
+    "    },\n",
+    "    \"dataset\": dataset_config,\n",
+    "    \"record\": record_config,\n",
+    "}\n",
+    "\n",
+    "# use xgboost\n",
+    "task_xgboost_config = {\n",
+    "    \"model\": {\n",
+    "        \"class\": \"XGBModel\",\n",
+    "        \"module_path\": \"qlib.contrib.model.xgboost\",\n",
+    "    },\n",
+    "    \"dataset\": dataset_config,\n",
+    "    \"record\": record_config,\n",
+    "}\n",
+    "\n",
+    "provider_uri = r\"../qlib-main/qlib_data/cn_data\"\n",
+    "#provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
+    "qlib.init(provider_uri=provider_uri, region=REG_CN)\n",
+    "\n",
+    "C[\"mongo\"] = {\n",
+    "    \"task_url\": \"mongodb://localhost:27017/\",  # change this to your MongoDB URL if needed\n",
+    "    \"task_db_name\": \"rolling_db\",\n",
+    "}\n",
+    "\n",
+    "exp_name = 'rolling_exp'  # experiment name, will be used as the experiment in MLflow\n",
+    "task_pool = 'rolling_task'  # task pool name, will be used as the collection name in MongoDB"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "tasks = task_generator(\n",
+    "    task_xgboost_config,  # default task name\n",
+    "    RollingGen(step=550, rtype=RollingGen.ROLL_SD),  # generate different date segments\n",
+    "    task_lgb=task_lgb_config  # use \"task_lgb\" as the task name\n",
+    ")\n",
+    "# Uncomment next two lines to see the generated tasks\n",
+    "# from pprint import pprint\n",
+    "# pprint(tasks)\n",
+    "tm = TaskManager(task_pool=task_pool)\n",
+    "tm.create_task(tasks)  # all tasks will be saved to MongoDB"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from qlib.workflow.task.manage import run_task\n",
+    "from qlib.workflow.task.collect import RollingCollector\n",
+    "from qlib.model.trainer import task_train\n",
+    "\n",
+    "run_task(task_train, task_pool, experiment_name=exp_name)  # all tasks will be trained using \"task_train\" method"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "def get_task_key(task_config):\n",
+    "    \"\"\"Name a task by its task key and the end date of its rolling test segment.\"\"\"\n",
+    "    task_key = task_config[\"task_key\"]\n",
+    "    # assumes the test segment end is datetime-like (e.g. a pandas Timestamp) - TODO confirm\n",
+    "    rolling_end = task_config[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n",
+    "    return task_key, rolling_end.strftime('%Y-%m-%d')\n",
+    "\n",
+    "\n",
+    "def my_filter(task_config):\n",
+    "    \"\"\"Keep only the task_lgb tasks whose rolling test segment ends in 2019.\"\"\"\n",
+    "    task_key, rolling_end = get_task_key(task_config)\n",
+    "    return task_key == \"task_lgb\" and rolling_end.startswith('2019')\n",
+    "\n",
+    "\n",
+    "collector = RollingCollector(get_task_key, my_filter)\n",
+    "pred_rolling = collector(exp_name)  # name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n",
+    "pred_rolling"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file