From 98eacf8f88de66aa88ee877004a76c2a60f7c5f5 Mon Sep 17 00:00:00 2001 From: zhupr Date: Fri, 28 May 2021 13:24:47 +0800 Subject: [PATCH] add test/config.py --- examples/highfreq/workflow.py | 17 +-- .../LightGBM/hyperparameter_158.py | 58 +++------- .../LightGBM/hyperparameter_360.py | 57 +++------ examples/model_interpreter.py | 81 ------------- examples/model_interpreter/feature.py | 32 ++++++ .../model_rolling/task_manager_rolling.py | 62 +--------- .../online_srv/online_management_simulate.py | 62 +--------- .../online_srv/rolling_online_management.py | 65 ++--------- examples/online_srv/update_online_pred.py | 49 +------- examples/rolling_process_data/workflow.py | 8 +- examples/run_all_model.py | 11 +- examples/workflow_by_code.py | 76 ++---------- qlib/model/interpret/base.py | 9 +- qlib/tests/__init__.py | 18 ++- qlib/tests/config.py | 108 ++++++++++++++++++ qlib/tests/data.py | 7 ++ tests/dataset_tests/test_datalayer.py | 22 +--- tests/test_all_pipeline.py | 66 ++--------- tests/test_contrib_workflow.py | 65 ++--------- tests/test_get_data.py | 5 +- tests/test_register_ops.py | 5 - 21 files changed, 246 insertions(+), 637 deletions(-) delete mode 100644 examples/model_interpreter.py create mode 100644 examples/model_interpreter/feature.py create mode 100644 qlib/tests/config.py diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 5660ab2e9..856885b25 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -1,24 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import fire -from pathlib import Path import qlib import pickle -import numpy as np -import pandas as pd from qlib.config import REG_CN, HIGH_FREQ_CONFIG -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import init_instance_by_config, exists_qlib_data +from qlib.utils import init_instance_by_config from qlib.data.dataset.handler import DataHandlerLP from qlib.data.ops import Operators from qlib.data.data import Cal @@ -96,9 +85,7 @@ class HighfreqWorkflow: # use yahoo_cn_1min data QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} provider_uri = QLIB_INIT_CONFIG.get("provider_uri") - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) qlib.init(**QLIB_INIT_CONFIG) def _prepare_calender_cache(self): diff --git a/examples/hyperparameter/LightGBM/hyperparameter_158.py b/examples/hyperparameter/LightGBM/hyperparameter_158.py index 5e4887a14..9e4557ed5 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_158.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_158.py @@ -1,46 +1,9 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna - -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.config import CSI300_DATASET_CONFIG +from qlib.tests.data import GetData def objective(trial): @@ -65,12 +28,19 @@ def objective(trial): }, }, } - evals_result = dict() model = init_instance_by_config(task["model"]) model.fit(dataset, evals_result=evals_result) return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region="cn") + + dataset = init_instance_by_config(CSI300_DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/hyperparameter/LightGBM/hyperparameter_360.py b/examples/hyperparameter/LightGBM/hyperparameter_360.py index 8b498e912..a8127014b 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_360.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_360.py @@ -1,46 +1,11 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.data import GetData +from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha360", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS) def objective(trial): @@ -72,5 +37,13 @@ def objective(trial): return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + qlib.init(provider_uri=provider_uri, region=REG_CN) + + dataset = init_instance_by_config(DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/model_interpreter.py b/examples/model_interpreter.py deleted file mode 100644 index 1d9230b8c..000000000 --- a/examples/model_interpreter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import qlib -from qlib.config import REG_CN - -from qlib.utils import exists_qlib_data, init_instance_by_config -from qlib.tests.data import GetData - -market = "csi300" -benchmark = "SH000300" - -################################### -# config -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} - - -if __name__ == "__main__": - - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - - qlib.init(provider_uri=provider_uri, region=REG_CN) - - ################################### - # train model - ################################### - # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) - model.fit(dataset) - - # get model feature importance - feature_importance = model.get_feature_importance() - print("feature importance:") - print(feature_importance) diff --git a/examples/model_interpreter/feature.py b/examples/model_interpreter/feature.py new file mode 100644 index 000000000..1c29fda6e --- /dev/null +++ b/examples/model_interpreter/feature.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import qlib +from qlib.config import REG_CN + +from qlib.utils import init_instance_by_config +from qlib.tests.data import GetData +from qlib.tests.config import CSI300_GBDT_TASK + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + ################################### + # train model + ################################### + # model initialization + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) + model.fit(dataset) + + # get model feature importance + feature_importance = model.get_feature_importance() + print("feature importance:") + print(feature_importance) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 4f3ac04b1..9ef8694bf 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,63 +17,7 @@ from qlib.workflow.task.manage import TaskManager from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerRM - - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class RollingTaskExample: @@ -85,11 +29,13 @@ class RollingTaskExample: task_db_name="rolling_db", experiment_name="rolling_exp", task_pool="rolling_task", - task_config=[task_xgboost_config, task_lgb_config], + task_config=None, rolling_step=550, rolling_type=RollingGen.ROLL_SD, ): # TaskManager config + if task_config is None: + task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 4bb5022ee..8c9e77bf7 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager - - -data_handler_config = { - "start_time": "2018-01-01", - "end_time": "2018-10-31", - "fit_start_time": "2018-01-01", - "fit_end_time": "2018-03-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2018-01-01", "2018-03-31"), - "valid": ("2018-04-01", "2018-05-31"), - "test": ("2018-06-01", "2018-09-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class OnlineSimulationExample: @@ -84,7 +28,7 @@ class OnlineSimulationExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", - tasks=[task_xgboost_config, task_lgb_config], + tasks=None, ): """ Init OnlineManagerExample. @@ -101,6 +45,8 @@ class OnlineSimulationExample: end_time (str, optional): the end time of simulating. Defaults to "2018-10-31". tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 25b8b2a0c..592f1f866 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -17,62 +17,7 @@ from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager - -data_handler_config = { - "start_time": "2013-01-01", - "end_time": "2020-09-25", - "fit_start_time": "2013-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2013-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2015-12-31"), - "test": ("2016-01-01", "2020-07-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG class RollingOnlineExample: @@ -83,9 +28,13 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, - tasks=[task_xgboost_config], - add_tasks=[task_lgb_config], + tasks=None, + add_tasks=None, ): + if add_tasks is None: + add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG] + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG] mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 228bc0dac..8afc66553 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred. Firstly, we will finish the training and set the trained models to the `online` models. Next, we will finish updating online predictions. """ +import copy import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train from qlib.workflow.online.utils import OnlineToolR +from qlib.tests.config import CSI300_GBDT_TASK -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} +task = copy.deepcopy(CSI300_GBDT_TASK) -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - "record": { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, +task["record"] = { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", } diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 5757aaa87..bfa2d1ec4 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -4,13 +4,11 @@ import qlib import fire import pickle -import pandas as pd from datetime import datetime from qlib.config import REG_CN from qlib.data.dataset.handler import DataHandlerLP -from qlib.contrib.data.handler import Alpha158 -from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.utils import init_instance_by_config from qlib.tests.data import GetData @@ -25,9 +23,7 @@ class RollingDataWorkflow: """initialize qlib""" # use yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) def _dump_pre_handler(self, path): diff --git a/examples/run_all_model.py b/examples/run_all_model.py index d587eff15..8875b9aa1 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -5,13 +5,11 @@ import os import sys import fire import time -import venv import glob import shutil import signal import inspect import tempfile -import traceback import functools import statistics import subprocess @@ -23,8 +21,7 @@ from pprint import pprint import qlib from qlib.config import REG_CN from qlib.workflow import R -from qlib.workflow.cli import workflow -from qlib.utils import exists_qlib_data +from qlib.tests.data import GetData # init qlib @@ -39,12 +36,8 @@ exp_manager = { "default_exp_name": "Experiment", }, } -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) +GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager) # decorator to check the arguments diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index d5dab8917..2e84cadc2 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -1,82 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys -from pathlib import Path - import qlib -import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData +from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK + if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) qlib.init(provider_uri=provider_uri, region=REG_CN) - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", @@ -90,7 +30,7 @@ if __name__ == "__main__": "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -100,8 +40,8 @@ if __name__ == "__main__": } # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # NOTE: This line is optional # It demonstrates that the dataset can be used standalone. @@ -110,7 +50,7 @@ if __name__ == "__main__": # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) R.save_objects(**{"params.pkl": model}) diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py index 70d79faca..57cc7929a 100644 --- a/qlib/model/interpret/base.py +++ b/qlib/model/interpret/base.py @@ -14,7 +14,14 @@ class FeatureInt: @abstractmethod def get_feature_importance(self) -> pd.Series: - ... + """get feature importance + + Returns + ------- + The index is the feature name. + + The greater the value, the higher importance. + """ class LightGBMFInt(FeatureInt): diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index 8b53bc53a..e72f000ba 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -1,6 +1,4 @@ -import sys import unittest -from ..utils import exists_qlib_data from .data import GetData from .. import init from ..config import REG_CN @@ -14,14 +12,12 @@ class TestAutoData(unittest.TestCase): @classmethod def setUpClass(cls) -> None: # use default data - if not exists_qlib_data(cls.provider_uri): - print(f"Qlib data is not found in {cls.provider_uri}") - GetData().qlib_data( - name="qlib_data_simple", - region="cn", - interval="1d", - target_dir=cls.provider_uri, - delete_old=False, - ) + GetData().qlib_data( + name="qlib_data_simple", + region=REG_CN, + interval="1d", + target_dir=cls.provider_uri, + delete_old=False, + ) init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/tests/config.py b/qlib/tests/config.py new file mode 100644 index 000000000..80461f6f9 --- /dev/null +++ b/qlib/tests/config.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +CSI300_MARKET = "csi300" +CSI100_MARKET = "csi100" + +CSI300_BENCH = "SH000300" + +DATASET_ALPHA158_CLASS = "Alpha158" +DATASET_ALPHA360_CLASS = "Alpha360" + +################################### +# config +################################### + + +GBDT_MODEL = { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, +} + + +RECORD_CONFIG = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + + +def get_data_handler_config(market=CSI300_MARKET): + return { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + +def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS): + return { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": dataset_class, + "module_path": "qlib.contrib.data.handler", + "kwargs": get_data_handler_config(market), + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + } + + +def get_gbdt_task(market=CSI300_MARKET): + return { + "model": GBDT_MODEL, + "dataset": get_dataset_config(market), + } + + +def get_record_lgb_config(market=CSI300_MARKET): + return { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": get_dataset_config(market), + "record": RECORD_CONFIG, + } + + +def get_record_xgboost_config(market=CSI300_MARKET): + return { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": get_dataset_config(market), + "record": RECORD_CONFIG, + } + + +CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET) +CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET) + +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 3bf6a2c96..0f226c6b1 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -10,6 +10,7 @@ import datetime from tqdm import tqdm from pathlib import Path from loguru import logger +from qlib.utils import exists_qlib_data class GetData: @@ -112,6 +113,7 @@ class GetData: interval="1d", region="cn", delete_old=True, + exists_skip=True, ): """download cn qlib data from remote @@ -129,6 +131,8 @@ class GetData: data region, value from [cn, us], by default cn delete_old: bool delete an existing directory, by default True + exists_skip: bool + exists skip, by default True Examples --------- @@ -140,6 +144,9 @@ class GetData: ------- """ + if exists_skip and exists_qlib_data(target_dir): + return + qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__)) def _get_file_name(v): diff --git a/tests/dataset_tests/test_datalayer.py b/tests/dataset_tests/test_datalayer.py index 9d282b167..bdd0d915b 100644 --- a/tests/dataset_tests/test_datalayer.py +++ b/tests/dataset_tests/test_datalayer.py @@ -1,26 +1,10 @@ -import sys -from pathlib import Path -import qlib -from qlib.data import D -from qlib.config import REG_CN import unittest import numpy as np -from qlib.utils import exists_qlib_data +from qlib.data import D +from qlib.tests import TestAutoData -class TestDataset(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) - from get_data import GetData - - GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri) - qlib.init(provider_uri=provider_uri, region=REG_CN) - +class TestDataset(TestAutoData): def testCSI300(self): close_p = D.features(D.instruments("csi300"), ["$close"]) size = close_p.groupby("datetime").size() diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index d34c1773a..4c20405fa 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH port_analysis_config = { "strategy": { @@ -75,7 +27,7 @@ port_analysis_config = { "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -96,15 +48,15 @@ def train(): """ # model initiaiton - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # To test __repr__ print(dataset) print(R) # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # prediction @@ -137,12 +89,12 @@ def train_with_sigana(): performance: dict model performance """ - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # start exp with R.start(experiment_name="workflow_with_sigana"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # predict and calculate ic and ric @@ -171,7 +123,7 @@ def fake_experiment(): default_uri = R.get_uri() current_uri = "file:./temp-test-exp-mag" with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) current_uri_to_check = R.get_uri() default_uri_to_check = R.get_uri() diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index ccd3c6a90..9b1edbd4e 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -1,73 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import shutil import unittest from pathlib import Path -import qlib -from qlib.config import C from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK def train_multiseg(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = MultiSegRecord(model, dataset, recorder) @@ -77,10 +26,10 @@ def train_multiseg(): def train_mse(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = SignalMseRecord(recorder, model=model, dataset=dataset) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index c511d1b91..55a2c3318 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -1,16 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import shutil import unittest from pathlib import Path -sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) -from get_data import GetData - import qlib from qlib.data import D +from qlib.tests.data import GetData DATA_DIR = Path(__file__).parent.joinpath("test_get_data") SOURCE_DIR = DATA_DIR.joinpath("source") diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py index 7d3322ddc..ac86be59c 100644 --- a/tests/test_register_ops.py +++ b/tests/test_register_ops.py @@ -1,17 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import unittest import numpy as np -import qlib from qlib.data import D from qlib.data.ops import ElemOperator, PairOperator -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data from qlib.tests import TestAutoData -from qlib.tests.data import GetData class Diff(ElemOperator):