mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 01:21:18 +08:00
176 lines
5.3 KiB
Plaintext
176 lines
5.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import qlib\n",
|
|
"from qlib.config import REG_CN\n",
|
|
"from qlib.workflow.task.gen import RollingGen, task_generator\n",
|
|
"from qlib.workflow.task.manage import TaskManager\n",
|
|
"from qlib.config import C\n",
|
|
"\n",
|
|
"data_handler_template = {\n",
|
|
" \"start_time\": \"2008-01-01\",\n",
|
|
" \"end_time\": \"2020-08-01\",\n",
|
|
" \"fit_start_time\": \"2008-01-01\",\n",
|
|
" \"fit_end_time\": \"2014-12-31\",\n",
|
|
" \"instruments\": 'csi100',\n",
|
|
"}\n",
|
|
"\n",
|
|
"dataset_template = {\n",
|
|
" \"class\": \"DatasetH\",\n",
|
|
" \"module_path\": \"qlib.data.dataset\",\n",
|
|
" \"kwargs\": {\n",
|
|
" \"handler\": {\n",
|
|
" \"class\": \"Alpha158\",\n",
|
|
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
|
" \"kwargs\": data_handler_template,\n",
|
|
" },\n",
|
|
" \"segments\": {\n",
|
|
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
|
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
|
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
|
" },\n",
|
|
" },\n",
|
|
" }\n",
|
|
"\n",
|
|
"record_template = [\n",
|
|
" {\n",
|
|
" \"class\": \"SignalRecord\",\n",
|
|
" \"module_path\": \"qlib.workflow.record_temp\",\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"class\": \"SigAnaRecord\",\n",
|
|
" \"module_path\": \"qlib.workflow.record_temp\",\n",
|
|
" }\n",
|
|
"]\n",
|
|
"\n",
|
|
"# use lgb\n",
|
|
"lgb_task_template = {\n",
|
|
" \"model\": {\n",
|
|
" \"class\": \"LGBModel\",\n",
|
|
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
|
" },\n",
|
|
" \"dataset\": dataset_template,\n",
|
|
" \"record\": record_template,\n",
|
|
"}\n",
|
|
"\n",
|
|
"# use xgboost\n",
|
|
"xgboost_task_template = {\n",
|
|
" \"model\": {\n",
|
|
" \"class\": \"XGBModel\",\n",
|
|
" \"module_path\": \"qlib.contrib.model.xgboost\",\n",
|
|
" },\n",
|
|
" \"dataset\": dataset_template,\n",
|
|
" \"record\": record_template,\n",
|
|
"}\n",
|
|
"\n",
|
|
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
|
"qlib.init(provider_uri=provider_uri, region=REG_CN)\n",
|
|
"\n",
|
|
"C[\"mongo\"] = {\n",
|
|
" \"task_url\" : \"mongodb://localhost:27017/\", # change this to the URL of your own MongoDB server if needed\n",
|
|
" \"task_db_name\" : \"rolling_db\"\n",
|
|
"}\n",
|
|
"\n",
|
|
"exp_name = 'rolling_exp' # experiment name, used as the experiment name in MLflow\n",
|
|
"task_pool = 'rolling_task' # task pool name, used as the name of the MongoDB collection that stores the tasks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"tasks = task_generator(\n",
|
|
" xgboost_task_template, # default task name\n",
|
|
" RollingGen(step=550,rtype=RollingGen.ROLL_SD), # generate tasks over rolling date segments\n",
|
|
" task_lgb=lgb_task_template # use \"task_lgb\" as the task name\n",
|
|
")\n",
|
|
"# Uncomment the next two lines to see the generated tasks\n",
|
|
"# from pprint import pprint\n",
|
|
"# pprint(tasks)\n",
|
|
"tm = TaskManager(task_pool=task_pool)\n",
|
|
"tm.create_task(tasks) # all tasks will be saved to MongoDB"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"from qlib.workflow.task.manage import run_task\n",
|
|
"from qlib.workflow.task.collect import TaskCollector\n",
|
|
"from qlib.model.trainer import task_train\n",
|
|
"\n",
|
|
"run_task(task_train, task_pool, experiment_name=exp_name) # train every task in the pool with the \"task_train\" method"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_task_key(task):\n",
|
|
" task_key = task[\"task_key\"]\n",
|
|
" rolling_end_timestamp = task[\"dataset\"][\"kwargs\"][\"segments\"][\"test\"][1]\n",
|
|
" return task_key, rolling_end_timestamp.strftime('%Y-%m-%d')\n",
|
|
"\n",
|
|
"def my_filter(task):\n",
|
|
" # keep only the \"task_lgb\" tasks whose test segment ends in 2019\n",
|
|
" task_key, rolling_end = get_task_key(task)\n",
|
|
" if task_key==\"task_lgb\" and rolling_end.startswith('2019'):\n",
|
|
" return True\n",
|
|
" return False\n",
|
|
"\n",
|
|
"# name tasks by \"get_task_key\" and filter tasks by \"my_filter\"\n",
|
|
"pred_rolling = TaskCollector.collect_predictions(exp_name, get_task_key, my_filter) \n",
|
|
"pred_rolling"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "3.6.5-final"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |