{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rolling tasks with `RollingGen` and `TaskManager`\n",
    "\n",
    "This notebook builds LightGBM and XGBoost task templates on the `Alpha158` handler, expands them into rolling tasks with `RollingGen`, saves them through `TaskManager`, trains them with `run_task`, and collects the predictions with `TaskCollector`.\n",
    "\n",
    "The code below expects qlib CN data under `~/.qlib/qlib_data/cn_data` and a MongoDB instance reachable at the `task_url` configured in the setup cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qlib\n",
    "from qlib.config import REG_CN, C\n",
    "from qlib.model.trainer import task_train\n",
    "from qlib.workflow.task.collect import TaskCollector\n",
    "from qlib.workflow.task.gen import RollingGen, task_generator\n",
    "from qlib.workflow.task.manage import TaskManager, run_task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Task templates and setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_handler_template = {\n",
    "    \"start_time\": \"2008-01-01\",\n",
    "    \"end_time\": \"2020-08-01\",\n",
    "    \"fit_start_time\": \"2008-01-01\",\n",
    "    \"fit_end_time\": \"2014-12-31\",\n",
    "    \"instruments\": \"csi100\",\n",
    "}\n",
    "\n",
    "dataset_template = {\n",
    "    \"class\": \"DatasetH\",\n",
    "    \"module_path\": \"qlib.data.dataset\",\n",
    "    \"kwargs\": {\n",
    "        \"handler\": {\n",
    "            \"class\": \"Alpha158\",\n",
    "            \"module_path\": \"qlib.contrib.data.handler\",\n",
    "            \"kwargs\": data_handler_template,\n",
    "        },\n",
    "        \"segments\": {\n",
    "            \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
    "            \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
    "            \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
    "        },\n",
    "    },\n",
    "}\n",
    "\n",
    "record_template = [\n",
    "    {\n",
    "        \"class\": \"SignalRecord\",\n",
    "        \"module_path\": \"qlib.workflow.record_temp\",\n",
    "    },\n",
    "    {\n",
    "        \"class\": \"SigAnaRecord\",\n",
    "        \"module_path\": \"qlib.workflow.record_temp\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# use lgb\n",
    "lgb_task_template = {\n",
    "    \"model\": {\n",
    "        \"class\": \"LGBModel\",\n",
    "        \"module_path\": \"qlib.contrib.model.gbdt\",\n",
    "    },\n",
    "    \"dataset\": dataset_template,\n",
    "    \"record\": record_template,\n",
    "}\n",
    "\n",
    "# use xgboost\n",
    "xgboost_task_template = {\n",
    "    \"model\": {\n",
    "        \"class\": \"XGBModel\",\n",
    "        \"module_path\": \"qlib.contrib.model.xgboost\",\n",
    "    },\n",
    "    \"dataset\": dataset_template,\n",
    "    \"record\": record_template,\n",
    "}\n",
    "\n",
    "provider_uri = \"~/.qlib/qlib_data/cn_data\"  # target_dir\n",
    "qlib.init(provider_uri=provider_uri, region=REG_CN)\n",
    "\n",
    "C[\"mongo\"] = {\n",
    "    \"task_url\": \"mongodb://localhost:27017/\",  # change this to your own MongoDB URL if needed\n",
    "    \"task_db_name\": \"rolling_db\",\n",
    "}\n",
    "\n",
    "exp_name = \"rolling_exp\"  # experiment name, will be used as the experiment in MLflow\n",
    "task_pool = \"rolling_task\"  # task pool name, will be used as the document in MongoDB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate the rolling tasks and save them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "tasks = task_generator(\n",
    "    xgboost_task_template,  # default task name\n",
    "    RollingGen(step=550, rtype=RollingGen.ROLL_SD),  # generate different date segment\n",
    "    task_lgb=lgb_task_template,  # use \"task_lgb\" as the task name\n",
    ")\n",
    "# Uncomment next two lines to see the generated tasks\n",
    "# from pprint import pprint\n",
    "# pprint(tasks)\n",
    "tm = TaskManager(task_pool=task_pool)\n",
    "tm.create_task(tasks)  # all tasks will be saved to MongoDB"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train every task in the pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "run_task(task_train, task_pool, experiment_name=exp_name)  # all tasks will be trained using \"task_train\" method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect the predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_task_key(task):\n",
    "    \"\"\"Return (task_key, test_end) for a task, with the test segment end formatted as YYYY-MM-DD.\"\"\"\n",
    "    task_key = task[\"task_key\"]\n",
    "    rolling_end_timestamp = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n",
    "    return task_key, rolling_end_timestamp.strftime(\"%Y-%m-%d\")\n",
    "\n",
    "\n",
    "def my_filter(task):\n",
    "    \"\"\"Return True only for \"task_lgb\" tasks whose test segment ends in 2019.\"\"\"\n",
    "    task_key, rolling_end = get_task_key(task)\n",
    "    return task_key == \"task_lgb\" and rolling_end.startswith(\"2019\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n",
    "pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter)\n",
    "pred_rolling"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}