diff --git a/docs/component/data.rst b/docs/component/data.rst index 0a650c523..cd30ee98b 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -72,12 +72,19 @@ Converting CSV Format into Qlib Format ``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format. -Users can download the demo china-stock data in CSV format as follows for reference to the CSV format. +Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format. +Here are some examples: -.. code-block:: bash +for daily data: + .. code-block:: bash python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data +for 1min data: + .. code-block:: bash + + python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10 + Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions: - CSV file is named after a specific stock *or* the CSV file includes a column of the stock name @@ -145,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended. +Stock Pool (Market) +-------------------------------- + +``Qlib`` defines a `stock pool `_ as a list of stocks and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows. + +.. 
code-block:: bash + + python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments + + Multiple Stock Modes -------------------------------- diff --git a/docs/component/report.rst b/docs/component/report.rst index 7d8053c78..6f4bff4f9 100644 --- a/docs/component/report.rst +++ b/docs/component/report.rst @@ -101,7 +101,7 @@ Graphical Result - Axis Y: - `ic` The `Pearson correlation coefficient` series between `label` and `prediction score`. - In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue `_ for more details. + In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature `_ for more details. - `rank_ic` The `Spearman's rank correlation coefficient` series between `label` and `prediction score`. diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst index 0720dcdad..e4a5a94d1 100644 --- a/docs/component/strategy.rst +++ b/docs/component/strategy.rst @@ -111,8 +111,6 @@ Usage & Example pred_score, strategy=strategy, **BACKTEST_CONFIG ) -Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``. - To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction `_. To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing `_. diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 57f61f18b..5e6e50b0b 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -53,6 +53,34 @@ Cache .. autoclass:: qlib.data.cache.DiskDatasetCache :members: + +Storage +------------- +.. autoclass:: qlib.data.storage.storage.BaseStorage + :members: + +.. autoclass:: qlib.data.storage.storage.CalendarStorage + :members: + +.. autoclass:: qlib.data.storage.storage.InstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.storage.FeatureStorage + :members: + +.. 
autoclass:: qlib.data.storage.file_storage.FileStorageMixin + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage + :members: + +.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage + :members: + + Dataset --------------- diff --git a/docs/start/integration.rst b/docs/start/integration.rst index 3ecae1090..3d4043826 100644 --- a/docs/start/integration.rst +++ b/docs/start/integration.rst @@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html# return pd.Series(self.model.predict(x_test.values), index=x_test.index) - Override the `finetune` method (Optional) - - This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`. + - This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`. - The parameters must include the parameter `dataset`. - Code Example: In the following example, users will use `LightGBM` as the model and finetune it. .. code-block:: Python diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 5660ab2e9..7bf5fd09a 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -1,24 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import sys import fire -from pathlib import Path import qlib import pickle -import numpy as np -import pandas as pd from qlib.config import REG_CN, HIGH_FREQ_CONFIG -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import init_instance_by_config, exists_qlib_data +from qlib.utils import init_instance_by_config from qlib.data.dataset.handler import DataHandlerLP from qlib.data.ops import Operators from qlib.data.data import Cal @@ -96,9 +85,7 @@ class HighfreqWorkflow: # use yahoo_cn_1min data QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} provider_uri = QLIB_INIT_CONFIG.get("provider_uri") - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True) qlib.init(**QLIB_INIT_CONFIG) def _prepare_calender_cache(self): diff --git a/examples/hyperparameter/LightGBM/hyperparameter_158.py b/examples/hyperparameter/LightGBM/hyperparameter_158.py index 5e4887a14..89cc10cc6 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_158.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_158.py @@ -1,46 +1,9 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna - -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", 
- "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.config import CSI300_DATASET_CONFIG +from qlib.tests.data import GetData def objective(trial): @@ -65,12 +28,19 @@ def objective(trial): }, }, } - evals_result = dict() model = init_instance_by_config(task["model"]) model.fit(dataset, evals_result=evals_result) return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) + qlib.init(provider_uri=provider_uri, region="cn") + + dataset = init_instance_by_config(CSI300_DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/hyperparameter/LightGBM/hyperparameter_360.py b/examples/hyperparameter/LightGBM/hyperparameter_360.py index 8b498e912..bc0cc245d 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_360.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_360.py @@ -1,46 +1,11 @@ import qlib -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config import optuna +from qlib.config import REG_CN +from qlib.utils import init_instance_by_config +from qlib.tests.data import 
GetData +from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS -provider_uri = "~/.qlib/qlib_data/cn_data" -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(scripts_dir)) - from get_data import GetData - - GetData().qlib_data(target_dir=provider_uri, region="cn") -qlib.init(provider_uri=provider_uri, region="cn") - -market = "csi300" -benchmark = "SH000300" - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} -dataset_task = { - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha360", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} -dataset = init_instance_by_config(dataset_task["dataset"]) +DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS) def objective(trial): @@ -72,5 +37,13 @@ def objective(trial): return min(evals_result["valid"]) -study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") -study.optimize(objective, n_jobs=6) +if __name__ == "__main__": + + provider_uri = "~/.qlib/qlib_data/cn_data" + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) + qlib.init(provider_uri=provider_uri, region=REG_CN) + + dataset = init_instance_by_config(DATASET_CONFIG) + + study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3") + study.optimize(objective, n_jobs=6) diff --git a/examples/model_interpreter/feature.py b/examples/model_interpreter/feature.py new file mode 100644 index 000000000..a1288e07d --- /dev/null +++ b/examples/model_interpreter/feature.py @@ -0,0 +1,32 @@ +# Copyright 
(c) Microsoft Corporation. +# Licensed under the MIT License. + + +import qlib +from qlib.config import REG_CN + +from qlib.utils import init_instance_by_config +from qlib.tests.data import GetData +from qlib.tests.config import CSI300_GBDT_TASK + + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + ################################### + # train model + ################################### + # model initialization + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) + model.fit(dataset) + + # get model feature importance + feature_importance = model.get_feature_importance() + print("feature importance:") + print(feature_importance) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 89233b37b..844f18198 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,64 +17,8 @@ from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM, task_train - - -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - 
}, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.model.trainer import TrainerRM +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class RollingTaskExample: @@ -86,11 +30,13 @@ class RollingTaskExample: task_db_name="rolling_db", experiment_name="rolling_exp", task_pool="rolling_task", - task_config=[task_xgboost_config, task_lgb_config], + task_config=None, rolling_step=550, rolling_type=RollingGen.ROLL_SD, ): # TaskManager config + if task_config is None: + task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index de6dbcb21..5f024192f 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager - - -data_handler_config = { - "start_time": "2018-01-01", - "end_time": "2018-10-31", - "fit_start_time": "2018-01-01", - "fit_end_time": "2018-03-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - 
"module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2018-01-01", "2018-03-31"), - "valid": ("2018-04-01", "2018-05-31"), - "test": ("2018-06-01", "2018-09-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG class OnlineSimulationExample: @@ -84,7 +28,7 @@ class OnlineSimulationExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", - tasks=[task_xgboost_config, task_lgb_config], + tasks=None, ): """ Init OnlineManagerExample. @@ -101,6 +45,8 @@ class OnlineSimulationExample: end_time (str, optional): the end time of simulating. Defaults to "2018-10-31". 
tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 40da30db7..b4f7245b7 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -18,63 +18,7 @@ from qlib.workflow import R from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager -from qlib.workflow.task.manage import TaskManager, run_task - -data_handler_config = { - "start_time": "2013-01-01", - "end_time": "2020-09-25", - "fit_start_time": "2013-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} - -dataset_config = { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2013-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2015-12-31"), - "test": ("2016-01-01", "2020-07-10"), - }, - }, -} - -record_config = [ - { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, - { - "class": "SigAnaRecord", - "module_path": "qlib.workflow.record_temp", - }, -] - -# use lgb model -task_lgb_config = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - }, - "dataset": dataset_config, - "record": record_config, -} - -# use xgboost model -task_xgboost_config = { - "model": { - "class": "XGBModel", - "module_path": "qlib.contrib.model.xgboost", - }, - "dataset": dataset_config, - "record": record_config, -} +from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, 
CSI100_RECORD_LGB_TASK_CONFIG class RollingOnlineExample: @@ -86,9 +30,13 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR rolling_step=550, - tasks=[task_xgboost_config], - add_tasks=[task_lgb_config], + tasks=None, + add_tasks=None, ): + if add_tasks is None: + add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG] + if tasks is None: + tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG] mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index 228bc0dac..8afc66553 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred. Firstly, we will finish the training and set the trained models to the `online` models. Next, we will finish updating online predictions. 
""" +import copy import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train from qlib.workflow.online.utils import OnlineToolR +from qlib.tests.config import CSI300_GBDT_TASK -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": "csi100", -} +task = copy.deepcopy(CSI300_GBDT_TASK) -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - "record": { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - }, +task["record"] = { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", } diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index 5757aaa87..387d5cde7 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -4,13 +4,11 @@ import qlib import fire import pickle -import pandas as pd from datetime import datetime from qlib.config import REG_CN from qlib.data.dataset.handler import DataHandlerLP -from qlib.contrib.data.handler import Alpha158 -from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.utils import init_instance_by_config from qlib.tests.data import GetData @@ -25,9 +23,7 @@ class RollingDataWorkflow: """initialize qlib""" # use 
yahoo_cn_1min data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) def _dump_pre_handler(self, path): diff --git a/examples/run_all_model.py b/examples/run_all_model.py index d587eff15..c79fee004 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -5,13 +5,11 @@ import os import sys import fire import time -import venv import glob import shutil import signal import inspect import tempfile -import traceback import functools import statistics import subprocess @@ -23,8 +21,7 @@ from pprint import pprint import qlib from qlib.config import REG_CN from qlib.workflow import R -from qlib.workflow.cli import workflow -from qlib.utils import exists_qlib_data +from qlib.tests.data import GetData # init qlib @@ -39,12 +36,8 @@ exp_manager = { "default_exp_name": "Experiment", }, } -if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) +GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager) # decorator to check the arguments diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index d5dab8917..1cdf2ac80 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -1,82 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import sys -from pathlib import Path - import qlib -import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData +from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK + if __name__ == "__main__": # use default data provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data(target_dir=provider_uri, region=REG_CN) - + GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": 
("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", @@ -90,7 +30,7 @@ if __name__ == "__main__": "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -100,8 +40,8 @@ if __name__ == "__main__": } # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # NOTE: This line is optional # It demonstrates that the dataset can be used standalone. @@ -110,7 +50,7 @@ if __name__ == "__main__": # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) R.save_objects(**{"params.pkl": model}) diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 09313f933..8a4e137ca 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -166,7 +166,7 @@ class Position: def save_position(self, path, last_trade_date): path = pathlib.Path(path) p = copy.deepcopy(self.position) - cash = pd.Series(dtype=np.float) + cash = pd.Series(dtype=float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] cash["today_account_value"] = p["today_account_value"] diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index 98b9b9c2d..5138e0e6f 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -10,9 +10,10 @@ from catboost.utils import get_gpu_device_count from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt -class 
CatBoostModel(Model): +class CatBoostModel(Model, FeatureInt): """CatBoost Model""" def __init__(self, loss="RMSE", **kwargs): @@ -69,6 +70,18 @@ class CatBoostModel(Model): x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return pd.Series(self.model.predict(x_test.values), index=x_test.index) + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters references: + https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance + """ + return pd.Series( + data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_ + ).sort_values(ascending=False) + if __name__ == "__main__": cat = CatBoostModel() diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index 4b267a2b0..d3ca898f8 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -1,251 +1,265 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import lightgbm as lgb -import numpy as np -import pandas as pd -from typing import Text, Union -from ...model.base import Model -from ...data.dataset import DatasetH -from ...data.dataset.handler import DataHandlerLP -from ...log import get_module_logger - - -class DEnsembleModel(Model): - """Double Ensemble Model""" - - def __init__( - self, - base_model="gbm", - loss="mse", - num_models=6, - enable_sr=True, - enable_fs=True, - alpha1=1.0, - alpha2=1.0, - bins_sr=10, - bins_fs=5, - decay=None, - sample_ratios=None, - sub_weights=None, - epochs=100, - **kwargs - ): - self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" - self.num_models = num_models # the number of sub-models - self.enable_sr = enable_sr - self.enable_fs = enable_fs - self.alpha1 = alpha1 - self.alpha2 = alpha2 - self.bins_sr = bins_sr - self.bins_fs = bins_fs - self.decay = decay - if sample_ratios is None: # the default values for sample_ratios - sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] - if sub_weights is None: # the default values for sub_weights - sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] - if not len(sample_ratios) == bins_fs: - raise ValueError("The length of sample_ratios should be equal to bins_fs.") - self.sample_ratios = sample_ratios - if not len(sub_weights) == num_models: - raise ValueError("The length of sub_weights should be equal to num_models.") - self.sub_weights = sub_weights - self.epochs = epochs - self.logger = get_module_logger("DEnsembleModel") - self.logger.info("Double Ensemble Model...") - self.ensemble = [] # the current ensemble model, a list contains all the sub-models - self.sub_features = [] # the features for each sub model in the form of pandas.Index - self.params = {"objective": loss} - self.params.update(kwargs) - self.loss = loss - - def fit(self, dataset: DatasetH): - df_train, df_valid = dataset.prepare( - ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L - ) - x_train, y_train = 
df_train["feature"], df_train["label"] - # initialize the sample weights - N, F = x_train.shape - weights = pd.Series(np.ones(N, dtype=float)) - # initialize the features - features = x_train.columns - pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index) - # train sub-models - for k in range(self.num_models): - self.sub_features.append(features) - self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models)) - model_k = self.train_submodel(df_train, df_valid, weights, features) - self.ensemble.append(model_k) - # no further sample re-weight and feature selection needed for the last sub-model - if k + 1 == self.num_models: - break - - self.logger.info("Retrieving loss curve and loss values...") - loss_curve = self.retrieve_loss_curve(model_k, df_train, features) - pred_k = self.predict_sub(model_k, df_train, features) - pred_sub.iloc[:, k] = pred_k - pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1) - loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values)) - - if self.enable_sr: - self.logger.info("Sample re-weighting...") - weights = self.sample_reweight(loss_curve, loss_values, k + 1) - - if self.enable_fs: - self.logger.info("Feature selection...") - features = self.feature_selection(df_train, loss_values) - - def train_submodel(self, df_train, df_valid, weights, features): - dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features) - evals_result = dict() - model = lgb.train( - self.params, - dtrain, - num_boost_round=self.epochs, - valid_sets=[dtrain, dvalid], - valid_names=["train", "valid"], - verbose_eval=20, - evals_result=evals_result, - ) - evals_result["train"] = list(evals_result["train"].values())[0] - evals_result["valid"] = list(evals_result["valid"].values())[0] - return model - - def _prepare_data_gbm(self, df_train, df_valid, weights, features): - x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] - x_valid, 
y_valid = df_valid["feature"].loc[:, features], df_valid["label"] - - # Lightgbm need 1D array as its label - if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values) - else: - raise ValueError("LightGBM doesn't support multi-label training") - - dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) - return dtrain, dvalid - - def sample_reweight(self, loss_curve, loss_values, k_th): - """ - the SR module of Double Ensemble - :param loss_curve: the shape is NxT - the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample - after the t-th iteration in the training of the previous sub-model. - :param loss_values: the shape is N - the loss of the current ensemble on the i-th sample. - :param k_th: the index of the current sub-model, starting from 1 - :return: weights - the weights for all the samples. - """ - # normalize loss_curve and loss_values with ranking - loss_curve_norm = loss_curve.rank(axis=0, pct=True) - loss_values_norm = (-loss_values).rank(pct=True) - - # calculate l_start and l_end from loss_curve - N, T = loss_curve.shape - part = np.maximum(int(T * 0.1), 1) - l_start = loss_curve_norm.iloc[:, :part].mean(axis=1) - l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1) - - # calculate h-value for each sample - h1 = loss_values_norm - h2 = (l_end / l_start).rank(pct=True) - h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2}) - - # calculate weights - h["bins"] = pd.cut(h["h_value"], self.bins_sr) - h_avg = h.groupby("bins")["h_value"].mean() - weights = pd.Series(np.zeros(N, dtype=float)) - for i_b, b in enumerate(h_avg.index): - weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1) - return weights - - def feature_selection(self, df_train, loss_values): - """ - the FS module of Double Ensemble - :param df_train: the shape is NxF - 
:param loss_values: the shape is N - the loss of the current ensemble on the i-th sample. - :return: res_feat: in the form of pandas.Index - - """ - x_train, y_train = df_train["feature"], df_train["label"] - features = x_train.columns - N, F = x_train.shape - g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)}) - M = len(self.ensemble) - - # shuffle specific columns and calculate g-value for each feature - x_train_tmp = x_train.copy() - for i_f, feat in enumerate(features): - x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values) - pred = pd.Series(np.zeros(N), index=x_train_tmp.index) - for i_s, submodel in enumerate(self.ensemble): - pred += ( - pd.Series( - submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index - ) - / M - ) - loss_feat = self.get_loss(y_train.values.squeeze(), pred.values) - g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7) - x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy() - - # one column in train features is all-nan # if g['g_value'].isna().any() - g["g_value"].replace(np.nan, 0, inplace=True) - - # divide features into bins_fs bins - g["bins"] = pd.cut(g["g_value"], self.bins_fs) - - # randomly sample features from bins to construct the new features - res_feat = [] - sorted_bins = sorted(g["bins"].unique(), reverse=True) - for i_b, b in enumerate(sorted_bins): - b_feat = features[g["bins"] == b] - num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat))) - res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist() - return pd.Index(res_feat) - - def get_loss(self, label, pred): - if self.loss == "mse": - return (label - pred) ** 2 - else: - raise ValueError("not implemented yet") - - def retrieve_loss_curve(self, model, df_train, features): - if self.base_model == "gbm": - num_trees = model.num_trees() - x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] - # Lightgbm need 1D 
array as its label - if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train = np.squeeze(y_train.values) - else: - raise ValueError("LightGBM doesn't support multi-label training") - - N = x_train.shape[0] - loss_curve = pd.DataFrame(np.zeros((N, num_trees))) - pred_tree = np.zeros(N, dtype=float) - for i_tree in range(num_trees): - pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1) - loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree) - else: - raise ValueError("not implemented yet") - return loss_curve - - def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): - if self.ensemble is None: - raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) - pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) - for i_sub, submodel in enumerate(self.ensemble): - feat_sub = self.sub_features[i_sub] - pred += ( - pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index) - * self.sub_weights[i_sub] - ) - return pred - - def predict_sub(self, submodel, df_data, features): - x_data, y_data = df_data["feature"].loc[:, features], df_data["label"] - pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index) - return pred_sub +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import lightgbm as lgb +import numpy as np +import pandas as pd +from typing import Text, Union +from ...model.base import Model +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt +from ...log import get_module_logger + + +class DEnsembleModel(Model, FeatureInt): + """Double Ensemble Model""" + + def __init__( + self, + base_model="gbm", + loss="mse", + num_models=6, + enable_sr=True, + enable_fs=True, + alpha1=1.0, + alpha2=1.0, + bins_sr=10, + bins_fs=5, + decay=None, + sample_ratios=None, + sub_weights=None, + epochs=100, + **kwargs + ): + self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" + self.num_models = num_models # the number of sub-models + self.enable_sr = enable_sr + self.enable_fs = enable_fs + self.alpha1 = alpha1 + self.alpha2 = alpha2 + self.bins_sr = bins_sr + self.bins_fs = bins_fs + self.decay = decay + if sample_ratios is None: # the default values for sample_ratios + sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4] + if sub_weights is None: # the default values for sub_weights + sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2] + if not len(sample_ratios) == bins_fs: + raise ValueError("The length of sample_ratios should be equal to bins_fs.") + self.sample_ratios = sample_ratios + if not len(sub_weights) == num_models: + raise ValueError("The length of sub_weights should be equal to num_models.") + self.sub_weights = sub_weights + self.epochs = epochs + self.logger = get_module_logger("DEnsembleModel") + self.logger.info("Double Ensemble Model...") + self.ensemble = [] # the current ensemble model, a list contains all the sub-models + self.sub_features = [] # the features for each sub model in the form of pandas.Index + self.params = {"objective": loss} + self.params.update(kwargs) + self.loss = loss + + def fit(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], 
data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + # initialize the sample weights + N, F = x_train.shape + weights = pd.Series(np.ones(N, dtype=float)) + # initialize the features + features = x_train.columns + pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index) + # train sub-models + for k in range(self.num_models): + self.sub_features.append(features) + self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models)) + model_k = self.train_submodel(df_train, df_valid, weights, features) + self.ensemble.append(model_k) + # no further sample re-weight and feature selection needed for the last sub-model + if k + 1 == self.num_models: + break + + self.logger.info("Retrieving loss curve and loss values...") + loss_curve = self.retrieve_loss_curve(model_k, df_train, features) + pred_k = self.predict_sub(model_k, df_train, features) + pred_sub.iloc[:, k] = pred_k + pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1) + loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values)) + + if self.enable_sr: + self.logger.info("Sample re-weighting...") + weights = self.sample_reweight(loss_curve, loss_values, k + 1) + + if self.enable_fs: + self.logger.info("Feature selection...") + features = self.feature_selection(df_train, loss_values) + + def train_submodel(self, df_train, df_valid, weights, features): + dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features) + evals_result = dict() + model = lgb.train( + self.params, + dtrain, + num_boost_round=self.epochs, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + verbose_eval=20, + evals_result=evals_result, + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + return model + + def _prepare_data_gbm(self, df_train, df_valid, weights, features): + x_train, y_train = 
df_train["feature"].loc[:, features], df_train["label"] + x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"] + + # Lightgbm need 1D array as its label + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values) + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train, label=y_train, weight=weights) + dvalid = lgb.Dataset(x_valid, label=y_valid) + return dtrain, dvalid + + def sample_reweight(self, loss_curve, loss_values, k_th): + """ + the SR module of Double Ensemble + :param loss_curve: the shape is NxT + the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample + after the t-th iteration in the training of the previous sub-model. + :param loss_values: the shape is N + the loss of the current ensemble on the i-th sample. + :param k_th: the index of the current sub-model, starting from 1 + :return: weights + the weights for all the samples. 
+ """ + # normalize loss_curve and loss_values with ranking + loss_curve_norm = loss_curve.rank(axis=0, pct=True) + loss_values_norm = (-loss_values).rank(pct=True) + + # calculate l_start and l_end from loss_curve + N, T = loss_curve.shape + part = np.maximum(int(T * 0.1), 1) + l_start = loss_curve_norm.iloc[:, :part].mean(axis=1) + l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1) + + # calculate h-value for each sample + h1 = loss_values_norm + h2 = (l_end / l_start).rank(pct=True) + h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2}) + + # calculate weights + h["bins"] = pd.cut(h["h_value"], self.bins_sr) + h_avg = h.groupby("bins")["h_value"].mean() + weights = pd.Series(np.zeros(N, dtype=float)) + for i_b, b in enumerate(h_avg.index): + weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1) + return weights + + def feature_selection(self, df_train, loss_values): + """ + the FS module of Double Ensemble + :param df_train: the shape is NxF + :param loss_values: the shape is N + the loss of the current ensemble on the i-th sample. 
+ :return: res_feat: in the form of pandas.Index + + """ + x_train, y_train = df_train["feature"], df_train["label"] + features = x_train.columns + N, F = x_train.shape + g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)}) + M = len(self.ensemble) + + # shuffle specific columns and calculate g-value for each feature + x_train_tmp = x_train.copy() + for i_f, feat in enumerate(features): + x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values) + pred = pd.Series(np.zeros(N), index=x_train_tmp.index) + for i_s, submodel in enumerate(self.ensemble): + pred += ( + pd.Series( + submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index + ) + / M + ) + loss_feat = self.get_loss(y_train.values.squeeze(), pred.values) + g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7) + x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy() + + # one column in train features is all-nan # if g['g_value'].isna().any() + g["g_value"].replace(np.nan, 0, inplace=True) + + # divide features into bins_fs bins + g["bins"] = pd.cut(g["g_value"], self.bins_fs) + + # randomly sample features from bins to construct the new features + res_feat = [] + sorted_bins = sorted(g["bins"].unique(), reverse=True) + for i_b, b in enumerate(sorted_bins): + b_feat = features[g["bins"] == b] + num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat))) + res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist() + return pd.Index(set(res_feat)) + + def get_loss(self, label, pred): + if self.loss == "mse": + return (label - pred) ** 2 + else: + raise ValueError("not implemented yet") + + def retrieve_loss_curve(self, model, df_train, features): + if self.base_model == "gbm": + num_trees = model.num_trees() + x_train, y_train = df_train["feature"].loc[:, features], df_train["label"] + # Lightgbm need 1D array as its label + if y_train.values.ndim == 2 and 
y_train.values.shape[1] == 1: + y_train = np.squeeze(y_train.values) + else: + raise ValueError("LightGBM doesn't support multi-label training") + + N = x_train.shape[0] + loss_curve = pd.DataFrame(np.zeros((N, num_trees))) + pred_tree = np.zeros(N, dtype=float) + for i_tree in range(num_trees): + pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1) + loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree) + else: + raise ValueError("not implemented yet") + return loss_curve + + def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): + if self.ensemble is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) + for i_sub, submodel in enumerate(self.ensemble): + feat_sub = self.sub_features[i_sub] + pred += ( + pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index) + * self.sub_weights[i_sub] + ) + return pred + + def predict_sub(self, submodel, df_data, features): + x_data, y_data = df_data["feature"].loc[:, features], df_data["label"] + pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index) + return pred_sub + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters reference: + https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance + """ + res = [] + for _model, _weight in zip(self.ensemble, self.sub_weights): + res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight) + return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index 463cf8f4f..1a7cf7fba 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ 
-8,9 +8,10 @@ from typing import Text, Union from ...model.base import ModelFT from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import LightGBMFInt -class LGBModel(ModelFT): +class LGBModel(ModelFT, LightGBMFInt): """LightGBM Model""" def __init__(self, loss="mse", **kwargs): @@ -33,8 +34,8 @@ class LGBModel(ModelFT): else: raise ValueError("LightGBM doesn't support multi-label training") - dtrain = lgb.Dataset(x_train.values, label=y_train) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) + dtrain = lgb.Dataset(x_train, label=y_train) + dvalid = lgb.Dataset(x_valid, label=y_valid) return dtrain, dvalid def fit( diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index 5a2eeb50a..04d6ab9d5 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -1,17 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+import warnings import numpy as np import pandas as pd import lightgbm as lgb -from qlib.model.base import ModelFT -from qlib.data.dataset import DatasetH -from qlib.data.dataset.handler import DataHandlerLP -import warnings +from ...model.base import ModelFT +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import LightGBMFInt -class HFLGBModel(ModelFT): +class HFLGBModel(ModelFT, LightGBMFInt): """LightGBM Model for high frequency prediction""" def __init__(self, loss="mse", **kwargs): @@ -97,8 +98,8 @@ class HFLGBModel(ModelFT): else: raise ValueError("LightGBM doesn't support multi-label training") - dtrain = lgb.Dataset(x_train.values, label=y_train) - dvalid = lgb.Dataset(x_valid.values, label=y_valid) + dtrain = lgb.Dataset(x_train, label=y_train) + dvalid = lgb.Dataset(x_valid, label=y_valid) return dtrain, dvalid def fit( diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index cbba14678..2a38f4fe1 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -8,9 +8,10 @@ from typing import Text, Union from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP +from ...model.interpret.base import FeatureInt -class XGBModel(Model): +class XGBModel(Model, FeatureInt): """XGBModel Model""" def __init__(self, **kwargs): @@ -42,8 +43,8 @@ class XGBModel(Model): else: raise ValueError("XGBoost doesn't support multi-label training") - dtrain = xgb.DMatrix(x_train.values, label=y_train_1d) - dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d) + dtrain = xgb.DMatrix(x_train, label=y_train_1d) + dvalid = xgb.DMatrix(x_valid, label=y_valid_1d) self.model = xgb.train( self._params, dtrain=dtrain, @@ -62,3 +63,13 @@ class XGBModel(Model): raise ValueError("model is not fitted yet!") x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) return 
pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index) + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ------- + parameters reference: + https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score + """ + return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False) diff --git a/qlib/data/data.py b/qlib/data/data.py index c2638e234..eb7fbe0ea 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -6,7 +6,9 @@ from __future__ import division from __future__ import print_function import os +import re import abc +import copy import time import queue import bisect @@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path -class CalendarProvider(abc.ABC): +class ProviderBackendMixin: + def get_default_backend(self): + backend = {} + provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2] + # set default storage class + backend.setdefault("class", f"File{provider_name}Storage") + # set default storage module + backend.setdefault("module_path", "qlib.data.storage.file_storage") + return backend + + def backend_obj(self, **kwargs): + backend = self.backend if self.backend else self.get_default_backend() + backend = copy.deepcopy(backend) + + # set default storage kwargs + backend_kwargs = backend.setdefault("kwargs", {}) + # default provider_uri map + if "provider_uri" not in backend_kwargs: + # if the user has no uri configured, use: uri = uri_map[freq] + freq = kwargs.get("freq", "day") + provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()}) + backend_kwargs["provider_uri"] = provider_uri_map[freq] + backend.setdefault("kwargs", {}).update(**kwargs) + return init_instance_by_config(backend) + + +class CalendarProvider(abc.ABC, ProviderBackendMixin): """Calendar 
provider base class Provide calendar data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @abc.abstractmethod def calendar(self, start_time=None, end_time=None, freq="day", future=False): """Get calendar of certain market in given time range. @@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC): return hash_args(start_time, end_time, freq, future) -class InstrumentProvider(abc.ABC): +class InstrumentProvider(abc.ABC, ProviderBackendMixin): """Instrument provider base class Provide instrument data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @staticmethod def instruments(market="all", filter_pipe=None): """Get the general config dictionary for a base market adding several dynamic filters. @@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC): raise ValueError(f"Unknown instrument type {inst}") -class FeatureProvider(abc.ABC): +class FeatureProvider(abc.ABC, ProviderBackendMixin): """Feature provider class Provide feature data. """ + def __init__(self, *args, **kwargs): + self.backend = kwargs.get("backend", {}) + @abc.abstractmethod def feature(self, instrument, field, start_time, end_time, freq): """Get feature data. 
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider): """ def __init__(self, **kwargs): + super(LocalCalendarProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -517,21 +555,22 @@ class LocalCalendarProvider(CalendarProvider): list list of timestamps """ - if future: - fname = self._uri_cal.format(freq + "_future") - # if future calendar not exists, return current calendar - if not os.path.exists(fname): - get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + + try: + backend_obj = self.backend_obj(freq=freq, future=future).data + except ValueError: + if future: + get_module_logger("data").warning( + f"load calendar error: freq={freq}, future={future}; return current calendar!" + ) get_module_logger("data").warning( "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" ) - fname = self._uri_cal.format(freq) - else: - fname = self._uri_cal.format(freq) - if not os.path.exists(fname): - raise ValueError("calendar not exists for freq " + freq) - with open(fname) as f: - return [pd.Timestamp(x.strip()) for x in f] + backend_obj = self.backend_obj(freq=freq, future=False).data + else: + raise + + return [pd.Timestamp(x) for x in backend_obj] def calendar(self, start_time=None, end_time=None, freq="day", future=False): _calendar, _calendar_index = self._get_calendar(freq, future) @@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider): Provide instrument data from local data source. 
""" - def __init__(self): - pass - @property def _uri_inst(self): """Instrument file uri.""" return os.path.join(C.get_data_path(), "instruments", "{}.txt") - def _load_instruments(self, market): - fname = self._uri_inst.format(market) - if not os.path.exists(fname): - raise ValueError("instruments not exists for market " + market) - - _instruments = dict() - df = pd.read_csv( - fname, - sep="\t", - usecols=[0, 1, 2], - names=["inst", "start_datetime", "end_datetime"], - dtype={"inst": str}, - parse_dates=["start_datetime", "end_datetime"], - ) - for row in df.itertuples(index=False): - _instruments.setdefault(row[0], []).append((row[1], row[2])) - return _instruments + def _load_instruments(self, market, freq): + return self.backend_obj(market=market, freq=freq).data def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] else: - _instruments = self._load_instruments(market) + _instruments = self._load_instruments(market, freq=freq) H["i"][market] = _instruments # strip # use calendar boundary @@ -604,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider): inst: list( filter( lambda x: x[0] <= x[1], - [(max(start_time, x[0]), min(end_time, x[1])) for x in spans], + [(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans], ) ) for inst, spans in _instruments.items() @@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider): """ def __init__(self, **kwargs): + super(LocalFeatureProvider, self).__init__(**kwargs) self.remote = kwargs.get("remote", False) @property @@ -641,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider): # validate field = str(field).lower()[1:] instrument = code_to_fname(instrument) - uri_data = self._uri_data.format(instrument.lower(), field, freq) - if not os.path.exists(uri_data): - get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, 
field)) - return pd.Series(dtype=np.float32) - # raise ValueError('uri_data not found: ' + uri_data) - # load - series = read_bin(uri_data, start_index, end_index) - return series + return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] class LocalExpressionProvider(ExpressionProvider): @@ -1065,7 +1080,8 @@ def register_all_wrappers(C): register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") - register_wrapper(Inst, C.instrument_provider, "qlib.data") + _instrument_provider = init_instance_by_config(C.instrument_provider, module) + register_wrapper(Inst, _instrument_provider, "qlib.data") logger.debug(f"registering Inst {C.instrument_provider}") if getattr(C, "feature_provider", None) is not None: diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 206561aed..8d7786368 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -357,7 +357,7 @@ class TSDataSampler: # get the previous index of a line given index """ # object incase of pandas converting int to flaot - idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object) + idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object) idx_df = lazy_sort_index(idx_df.unstack()) # NOTE: the correctness of `__getitem__` depends on columns sorted here idx_df = lazy_sort_index(idx_df, axis=1) diff --git a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py new file mode 100644 index 000000000..552e1e3e8 --- /dev/null +++ b/qlib/data/storage/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py new file mode 100644 index 000000000..a2b145c4d --- /dev/null +++ b/qlib/data/storage/file_storage.py @@ -0,0 +1,292 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import struct +from pathlib import Path +from typing import Iterable, Union, Dict, Mapping, Tuple, List + +import numpy as np +import pandas as pd + +from qlib.log import get_module_logger +from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT + +logger = get_module_logger("file_storage") + + +class FileStorageMixin: + @property + def uri(self) -> Path: + _provider_uri = self.kwargs.get("provider_uri", None) + if _provider_uri is None: + raise ValueError( + f"The `provider_uri` parameter is not found in {self.__class__.__name__}, " + f'please specify `provider_uri` in the "provider\'s backend"' + ) + return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name) + + def check(self): + """check self.uri + + Raises + ------- + ValueError + """ + if not self.uri.exists(): + raise ValueError(f"{self.storage_name} not exists: {self.uri}") + + +class FileCalendarStorage(FileStorageMixin, CalendarStorage): + def __init__(self, freq: str, future: bool, **kwargs): + super(FileCalendarStorage, self).__init__(freq, future, **kwargs) + self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower() + + def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]: + if not self.uri.exists(): + self._write_calendar(values=[]) + with self.uri.open("rb") as fp: + return [ + str(x) + for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8") + ] + + def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): + with self.uri.open(mode=mode) as fp: + 
np.savetxt(fp, values, fmt="%s", encoding="utf-8") + + @property + def data(self) -> List[CalVT]: + self.check() + return self._read_calendar() + + def extend(self, values: Iterable[CalVT]) -> None: + self._write_calendar(values, mode="ab") + + def clear(self) -> None: + self._write_calendar(values=[]) + + def index(self, value: CalVT) -> int: + self.check() + calendar = self._read_calendar() + return int(np.argwhere(calendar == value)[0]) + + def insert(self, index: int, value: CalVT): + calendar = self._read_calendar() + calendar = np.insert(calendar, index, value) + self._write_calendar(values=calendar) + + def remove(self, value: CalVT) -> None: + self.check() + index = self.index(value) + calendar = self._read_calendar() + calendar = np.delete(calendar, index) + self._write_calendar(values=calendar) + + def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None: + calendar = self._read_calendar() + calendar[i] = values + self._write_calendar(values=calendar) + + def __delitem__(self, i: Union[int, slice]) -> None: + self.check() + calendar = self._read_calendar() + calendar = np.delete(calendar, i) + self._write_calendar(values=calendar) + + def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]: + self.check() + return self._read_calendar()[i] + + def __len__(self) -> int: + return len(self.data) + + +class FileInstrumentStorage(FileStorageMixin, InstrumentStorage): + + INSTRUMENT_SEP = "\t" + INSTRUMENT_START_FIELD = "start_datetime" + INSTRUMENT_END_FIELD = "end_datetime" + SYMBOL_FIELD_NAME = "instrument" + + def __init__(self, market: str, **kwargs): + super(FileInstrumentStorage, self).__init__(market, **kwargs) + self.file_name = f"{market.lower()}.txt" + + def _read_instrument(self) -> Dict[InstKT, InstVT]: + if not self.uri.exists(): + self._write_instrument() + + _instruments = dict() + df = pd.read_csv( + self.uri, + sep="\t", + usecols=[0, 1, 2], + names=[self.SYMBOL_FIELD_NAME, 
self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], + dtype={self.SYMBOL_FIELD_NAME: str}, + parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], + ) + for row in df.itertuples(index=False): + _instruments.setdefault(row[0], []).append((row[1], row[2])) + return _instruments + + def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None: + if not data: + with self.uri.open("w") as _: + pass + return + + res = [] + for inst, v_list in data.items(): + _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]) + _df[self.SYMBOL_FIELD_NAME] = inst + res.append(_df) + + df = pd.concat(res, sort=False) + df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv( + self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False + ) + df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False) + + def clear(self) -> None: + self._write_instrument(data={}) + + @property + def data(self) -> Dict[InstKT, InstVT]: + self.check() + return self._read_instrument() + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + inst = self._read_instrument() + inst[k] = v + self._write_instrument(inst) + + def __delitem__(self, k: InstKT) -> None: + self.check() + inst = self._read_instrument() + del inst[k] + self._write_instrument(inst) + + def __getitem__(self, k: InstKT) -> InstVT: + self.check() + return self._read_instrument()[k] + + def update(self, *args, **kwargs) -> None: + + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + inst = self._read_instrument() + if args: + other = args[0] # type: dict + if isinstance(other, Mapping): + for key in other: + inst[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + inst[key] = other[key] + else: + for key, value in other: + inst[key] = value + for key, value in kwargs.items(): + inst[key] = value + + self._write_instrument(inst) + + def 
__len__(self) -> int: + return len(self.data) + + +class FileFeatureStorage(FileStorageMixin, FeatureStorage): + def __init__(self, instrument: str, field: str, freq: str, **kwargs): + super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) + self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" + + def clear(self): + with self.uri.open("wb") as _: + pass + + @property + def data(self) -> pd.Series: + return self[:] + + def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None: + if len(data_array) == 0: + logger.info( + "len(data_array) == 0, write" + "if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" + ) + return + if not self.uri.exists(): + # write + index = 0 if index is None else index + with self.uri.open("wb") as fp: + np.hstack([index, data_array]).astype(" self.end_index: + # append + index = 0 if index is None else index + with self.uri.open("ab+") as fp: + np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype(" Union[int, None]: + if not self.uri.exists(): + return None + with self.uri.open("rb") as fp: + index = int(np.frombuffer(fp.read(4), dtype=" Union[int, None]: + if not self.uri.exists(): + return None + # The next data appending index point will be `end_index + 1` + return self.start_index + len(self) - 1 + + def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]: + if not self.uri.exists(): + if isinstance(i, int): + return None, None + elif isinstance(i, slice): + return pd.Series(dtype=np.float32) + else: + raise TypeError(f"type(i) = {type(i)}") + + storage_start_index = self.start_index + storage_end_index = self.end_index + with self.uri.open("rb") as fp: + if isinstance(i, int): + + if storage_start_index > i: + raise IndexError(f"{i}: start index is {storage_start_index}") + fp.seek(4 * (i - storage_start_index) + 4) + return i, struct.unpack("f", fp.read(4))[0] + elif isinstance(i, slice): + 
start_index = storage_start_index if i.start is None else i.start + end_index = storage_end_index if i.stop is None else i.stop - 1 + si = max(start_index, storage_start_index) + if si > end_index: + return pd.Series(dtype=np.float32) + fp.seek(4 * (si - storage_start_index) + 4) + # read n bytes + count = end_index - si + 1 + data = np.frombuffer(fp.read(4 * count), dtype=" int: + self.check() + return self.uri.stat().st_size // 4 - 1 diff --git a/qlib/data/storage/storage.py b/qlib/data/storage/storage.py new file mode 100644 index 000000000..8426ebe66 --- /dev/null +++ b/qlib/data/storage/storage.py @@ -0,0 +1,501 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import re +from typing import Iterable, overload, Tuple, List, Text, Union, Dict + +import numpy as np +import pandas as pd +from qlib.log import get_module_logger + +# calendar value type +CalVT = str + +# instrument value +InstVT = List[Tuple[CalVT, CalVT]] +# instrument key +InstKT = Text + +logger = get_module_logger("storage") + +""" +If the user is only using it in `qlib`, you can customize Storage to implement only the following methods: + +class UserCalendarStorage(CalendarStorage): + + @property + def data(self) -> Iterable[CalVT]: + '''get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + + +class UserInstrumentStorage(InstrumentStorage): + + @property + def data(self) -> Dict[InstKT, InstVT]: + '''get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + ''' + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + + +class UserFeatureStorage(FeatureStorage): + + def __getitem__(self, s: slice) -> pd.Series: + '''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, 
index=pd.RangeIndex(start, len(values)) + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) + ''' + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(s: slice)` method" + ) + + +""" + + +class BaseStorage: + @property + def storage_name(self) -> str: + return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower() + + +class CalendarStorage(BaseStorage): + """ + The behavior of CalendarStorage's methods and List's methods of the same name remain consistent + """ + + def __init__(self, freq: str, future: bool, **kwargs): + self.freq = freq + self.future = future + self.kwargs = kwargs + + @property + def data(self) -> Iterable[CalVT]: + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + + def extend(self, iterable: Iterable[CalVT]) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") + + def index(self, value: CalVT) -> int: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") + + def insert(self, index: int, value: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method") + + def remove(self, value: CalVT) -> None: + raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method") + + @overload + def __setitem__(self, i: int, value: CalVT) -> None: + """x.__setitem__(i, o) <==> (x[i] = o)""" + ... 
+ + @overload + def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None: + """x.__setitem__(s, o) <==> (x[s] = o)""" + ... + + def __setitem__(self, i, value) -> None: + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method" + ) + + @overload + def __delitem__(self, i: int) -> None: + """x.__delitem__(i) <==> del x[i]""" + ... + + @overload + def __delitem__(self, i: slice) -> None: + """x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]""" + ... + + def __delitem__(self, i) -> None: + """ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method" + ) + + @overload + def __getitem__(self, s: slice) -> Iterable[CalVT]: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]""" + ... + + @overload + def __getitem__(self, i: int) -> CalVT: + """x.__getitem__(i) <==> x[i]""" + ... 
+ + def __getitem__(self, i) -> CalVT: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + + +class InstrumentStorage(BaseStorage): + def __init__(self, market: str, **kwargs): + self.market = market + self.kwargs = kwargs + + @property + def data(self) -> Dict[InstKT, InstVT]: + """get all data + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + + def update(self, *args, **kwargs) -> None: + """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. + + Notes + ------ + If E present and has a .keys() method, does: for k in E: D[k] = E[k] + + If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v + + In either case, this is followed by: for k, v in F.items(): D[k] = v + + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") + + def __setitem__(self, k: InstKT, v: InstVT) -> None: + """Set self[key] to value.""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") + + def __delitem__(self, k: InstKT) -> None: + """Delete self[key]. 
+ + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") + + def __getitem__(self, k: InstKT) -> InstVT: + """x.__getitem__(k) <==> x[k]""" + raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + + +class FeatureStorage(BaseStorage): + def __init__(self, instrument: str, field: str, freq: str, **kwargs): + self.instrument = instrument + self.field = field + self.freq = freq + self.kwargs = kwargs + + @property + def data(self) -> pd.Series: + """get all data + + Notes + ------ + if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)` + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") + + @property + def start_index(self) -> Union[int, None]: + """get FeatureStorage start index + + Notes + ----- + If the data(storage) does not exist, return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method") + + @property + def end_index(self) -> Union[int, None]: + """get FeatureStorage end index + + Notes + ----- + The right index of the data range (both sides are closed) + + The next data appending point will be `end_index + 1` + + If the data(storage) does not exist, return None + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method") + + def clear(self) -> None: + raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + + def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): + """Write data_array to FeatureStorage starting from index. 
+ + Notes + ------ + If index is None, append data_array to feature. + + If len(data_array) == 0; return + + If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan + + Examples + --------- + .. code-block:: + + feature: + 3 4 + 4 5 + 5 6 + + + >>> self.write([6, 7], index=6) + + feature: + 3 4 + 4 5 + 5 6 + 6 6 + 7 7 + + >>> self.write([8], index=9) + + feature: + 3 4 + 4 5 + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 + + >>> self.write([1, np.nan], index=3) + + feature: + 3 1 + 4 np.nan + 5 6 + 6 6 + 7 7 + 8 np.nan + 9 8 + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `write` method") + + def rebase(self, start_index: int = None, end_index: int = None): + """Rebase the start_index and end_index of the FeatureStorage. + + start_index and end_index are closed intervals: [start_index, end_index] + + Examples + --------- + + .. code-block:: + + feature: + 3 4 + 4 5 + 5 6 + + + >>> self.rebase(start_index=4) + + feature: + 4 5 + 5 6 + + >>> self.rebase(start_index=3) + + feature: + 3 np.nan + 4 5 + 5 6 + + >>> self.write([3], index=3) + + feature: + 3 3 + 4 5 + 5 6 + + >>> self.rebase(end_index=4) + + feature: + 3 3 + 4 5 + + >>> self.write([6, 7, 8], index=4) + + feature: + 3 3 + 4 6 + 5 7 + 6 8 + + >>> self.rebase(start_index=4, end_index=5) + + feature: + 4 6 + 5 7 + + """ + storage_si = self.start_index + storage_ei = self.end_index + if storage_si is None or storage_ei is None: + raise ValueError("storage.start_index or storage.end_index is None, storage may not exist") + + start_index = storage_si if start_index is None else start_index + end_index = storage_ei if end_index is None else end_index + + if start_index is None or end_index is None: + logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored") + return + + if start_index < 0 or end_index < 0: + logger.warning("start_index or end_index cannot be less than 0") + return + if start_index > end_index: + 
logger.warning( + f"start_index({start_index}) > end_index({end_index}), rebase is ignored; " + f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear" + ) + return + + if start_index <= storage_si: + self.write([np.nan] * (storage_si - start_index), start_index) + else: + self.rewrite(self[start_index:].values, start_index) + + if end_index >= self.end_index: + self.write([np.nan] * (end_index - self.end_index)) + else: + self.rewrite(self[: end_index + 1].values, start_index) + + def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int): + """overwrite all data in FeatureStorage with data + + Parameters + ---------- + data: Union[List, np.ndarray, Tuple] + data + index: int + data start index + """ + self.clear() + self.write(data, index) + + @overload + def __getitem__(self, s: slice) -> pd.Series: + """x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step] + + Returns + ------- + pd.Series(values, index=pd.RangeIndex(start, len(values)) + """ + ... + + @overload + def __getitem__(self, i: int) -> Tuple[int, float]: + """x.__getitem__(y) <==> x[y]""" + ... 
+ + def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]: + """x.__getitem__(y) <==> x[y] + + Notes + ------- + if data(storage) does not exist: + if isinstance(i, int): + return (None, None) + if isinstance(i, slice): + # return empty pd.Series + return pd.Series(dtype=np.float32) + """ + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method" + ) + + def __len__(self) -> int: + """ + + Raises + ------ + ValueError + If the data(storage) does not exist, raise ValueError + + """ + raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") diff --git a/qlib/model/interpret/__init__.py b/qlib/model/interpret/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py new file mode 100644 index 000000000..57cc7929a --- /dev/null +++ b/qlib/model/interpret/base.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Interfaces to interpret models +""" + +import pandas as pd +from abc import abstractmethod + + +class FeatureInt: + """Feature (Int)erpreter""" + + @abstractmethod + def get_feature_importance(self) -> pd.Series: + """get feature importance + + Returns + ------- + The index is the feature name. + + The greater the value, the higher importance. 
+ """ + + +class LightGBMFInt(FeatureInt): + """LightGBM (F)eature (Int)erpreter""" + + def get_feature_importance(self, *args, **kwargs) -> pd.Series: + """get feature importance + + Notes + ----- + parameters reference: + https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance + """ + return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values( + ascending=False + ) diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index f92e72787..7f43cd99a 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -1,6 +1,4 @@ -import sys import unittest -from ..utils import exists_qlib_data from .data import GetData from .. import init from ..config import REG_CN @@ -9,19 +7,18 @@ from ..config import REG_CN class TestAutoData(unittest.TestCase): _setup_kwargs = {} + provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir @classmethod def setUpClass(cls) -> None: # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - GetData().qlib_data( - name="qlib_data_simple", - region="cn", - interval="1d", - target_dir=provider_uri, - delete_old=False, - ) - init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs) + GetData().qlib_data( + name="qlib_data_simple", + region=REG_CN, + interval="1d", + target_dir=cls.provider_uri, + delete_old=False, + exists_skip=True, + ) + init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs) diff --git a/qlib/tests/config.py b/qlib/tests/config.py new file mode 100644 index 000000000..80461f6f9 --- /dev/null +++ b/qlib/tests/config.py @@ -0,0 +1,108 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +CSI300_MARKET = "csi300" +CSI100_MARKET = "csi100" + +CSI300_BENCH = "SH000300" + +DATASET_ALPHA158_CLASS = "Alpha158" +DATASET_ALPHA360_CLASS = "Alpha360" + +################################### +# config +################################### + + +GBDT_MODEL = { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, +} + + +RECORD_CONFIG = [ + { + "class": "SignalRecord", + "module_path": "qlib.workflow.record_temp", + }, + { + "class": "SigAnaRecord", + "module_path": "qlib.workflow.record_temp", + }, +] + + +def get_data_handler_config(market=CSI300_MARKET): + return { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + +def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS): + return { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": dataset_class, + "module_path": "qlib.contrib.data.handler", + "kwargs": get_data_handler_config(market), + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + } + + +def get_gbdt_task(market=CSI300_MARKET): + return { + "model": GBDT_MODEL, + "dataset": get_dataset_config(market), + } + + +def get_record_lgb_config(market=CSI300_MARKET): + return { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + }, + "dataset": get_dataset_config(market), + "record": RECORD_CONFIG, + } + + +def get_record_xgboost_config(market=CSI300_MARKET): + return { + "model": { + "class": "XGBModel", + "module_path": "qlib.contrib.model.xgboost", + }, + "dataset": get_dataset_config(market), + "record": 
RECORD_CONFIG, + } + + +CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET) +CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET) + +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET) diff --git a/qlib/tests/data.py b/qlib/tests/data.py index 3bf6a2c96..2bfe43590 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -10,6 +10,7 @@ import datetime from tqdm import tqdm from pathlib import Path from loguru import logger +from qlib.utils import exists_qlib_data class GetData: @@ -112,6 +113,7 @@ class GetData: interval="1d", region="cn", delete_old=True, + exists_skip=False, ): """download cn qlib data from remote @@ -129,6 +131,8 @@ class GetData: data region, value from [cn, us], by default cn delete_old: bool delete an existing directory, by default True + exists_skip: bool + exists skip, by default False Examples --------- @@ -140,6 +144,13 @@ class GetData: ------- """ + if exists_skip and exists_qlib_data(target_dir): + logger.warning( + f"Data already exists: {target_dir}, the data download will be skipped\n" + f"\tIf downloading is required: `exists_skip=False` or `change target_dir`" + ) + return + qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__)) def _get_file_name(v): diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index dbbe69d43..1e8ee2e48 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir): return False # check calendar bin for _calendar in calendars_dir.iterdir(): - if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")): + + if ("_future" not in _calendar.name) and ( + not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")) + ): return False # check instruments diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index b9fd9123c..0413f32b6 100644 --- 
a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -121,7 +121,7 @@ df = D.features(D.instruments("all"), ["$close"], freq="day") ### Help ```bash -pythono collector.py collector_data --help +python collector.py collector_data --help ``` ## Parameters diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a6e06613e..d42ce0e4c 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -191,7 +191,7 @@ class YahooCollector(BaseCollector): class YahooCollectorCN(YahooCollector, ABC): def get_instrument_list(self): - logger.info("get HS stock symbos......") + logger.info("get HS stock symbols......") symbols = get_hs_stock_symbols() logger.info(f"get {len(symbols)} symbols.") return symbols @@ -581,7 +581,6 @@ class Run(BaseRun): delay=0, start=None, end=None, - interval="1d", check_data_length=False, limit_nums=None, ): @@ -593,8 +592,6 @@ class Run(BaseRun): default 2 delay: float time.sleep(delay), default 0 - interval: str - freq, value from [1min, 1d], default 1d start: str start datetime, default "2000-01-01" end: str @@ -611,8 +608,9 @@ class Run(BaseRun): # get 1m data $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - - super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums + ) def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data diff --git a/scripts/data_collector/yahoo/requirements.txt b/scripts/data_collector/yahoo/requirements.txt index 3e3e0d1e0..5f08026e5 100644 --- a/scripts/data_collector/yahoo/requirements.txt +++ b/scripts/data_collector/yahoo/requirements.txt @@ -5,5 +5,4 @@ numpy 
pandas tqdm lxml -loguru yahooquery diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 0b063fdda..b3a18cc90 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -120,7 +120,7 @@ class DumpDataBase: else: df = file_or_df if df.empty or self.date_field_name not in df.columns.tolist(): - _calendars = pd.Series() + _calendars = pd.Series(dtype=np.float32) else: _calendars = df[self.date_field_name] diff --git a/tests/dataset_tests/test_datalayer.py b/tests/dataset_tests/test_datalayer.py index 9d282b167..bdd0d915b 100644 --- a/tests/dataset_tests/test_datalayer.py +++ b/tests/dataset_tests/test_datalayer.py @@ -1,26 +1,10 @@ -import sys -from pathlib import Path -import qlib -from qlib.data import D -from qlib.config import REG_CN import unittest import numpy as np -from qlib.utils import exists_qlib_data +from qlib.data import D +from qlib.tests import TestAutoData -class TestDataset(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) - from get_data import GetData - - GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri) - qlib.init(provider_uri=provider_uri, region=REG_CN) - +class TestDataset(TestAutoData): def testCSI300(self): close_p = D.features(D.instruments("csi300"), ["$close"]) size = close_p.groupby("datetime").size() diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py new file mode 100644 index 000000000..aad8d11e4 --- /dev/null +++ b/tests/storage_tests/test_storage.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +from pathlib import Path +from collections.abc import Iterable + +import pytest +import numpy as np +from qlib.tests import TestAutoData + +from qlib.data.storage.file_storage import ( + FileCalendarStorage as CalendarStorage, + FileInstrumentStorage as InstrumentStorage, + FileFeatureStorage as FeatureStorage, +) + +_file_name = Path(__file__).name.split(".")[0] +DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data") +QLIB_DIR = DATA_DIR.joinpath("qlib") +QLIB_DIR.mkdir(exist_ok=True, parents=True) + + +class TestStorage(TestAutoData): + def test_calendar_storage(self): + + calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri) + assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable" + assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable" + + print(f"calendar[1: 5]: {calendar[1:5]}") + print(f"calendar[0]: {calendar[0]}") + print(f"calendar[-1]: {calendar[-1]}") + + calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found") + with pytest.raises(ValueError): + print(calendar.data) + + with pytest.raises(ValueError): + print(calendar[:]) + + with pytest.raises(ValueError): + print(calendar[0]) + + def test_instrument_storage(self): + """ + The meaning of instrument, such as CSI500: + + CSI500 composition changes: + + date add remove + 2005-01-01 SH600000 + 2005-01-01 SH600001 + 2005-01-01 SH600002 + 2005-02-01 SH600003 SH600000 + 2005-02-15 SH600000 SH600002 + + Calendar: + pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D") + + Instrument: + symbol start_time end_time + SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day) + SH600000 2005-02-15 2005-03-01 + SH600001 2005-01-01 2005-03-01 + SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day) + SH600003 2005-02-01 2005-03-01 + + InstrumentStorage: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + 
"SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + """ + + instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri) + + for inst, spans in instrument.data.items(): + assert isinstance(inst, str) and isinstance( + spans, Iterable + ), f"{instrument.__class__.__name__} value is not Iterable" + for s_e in spans: + assert ( + isinstance(s_e, tuple) and len(s_e) == 2 + ), f"{instrument.__class__.__name__}.__getitem__(k) TypeError" + + print(f"instrument['SH600000']: {instrument['SH600000']}") + + instrument = InstrumentStorage(market="csi300", provider_uri="not_found") + with pytest.raises(ValueError): + print(instrument.data) + + with pytest.raises(ValueError): + print(instrument["sSH600000"]) + + def test_feature_storage(self): + """ + Calendar: + pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D") + + Instrument: + { + "SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)], + "SH600001": [(2005-01-01, 2005-03-01)], + "SH600002": [(2005-01-01, 2005-02-14)], + "SH600003": [(2005-02-01, 2005-03-01)], + } + + Feature: + Stock data(close): + 2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01 + SH600000 1 ... 3 ... 4 5 6 + SH600001 1 ... 4 ... 5 6 7 + SH600002 1 ... 5 ... 6 nan nan + SH600003 nan ... 1 ... 
2 3 4 + + FeatureStorage(SH600000, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-03-01"), 6) + ] + + ====> [(0, 1), ..., (59, 6)] + + + FeatureStorage(SH600002, close): + + [ + (calendar.index("2005-01-01"), 1), + ..., + (calendar.index("2005-02-14"), 6) + ] + + ===> [(0, 1), ..., (44, 6)] + + FeatureStorage(SH600003, close): + + [ + (calendar.index("2005-02-01"), 1), + ..., + (calendar.index("2005-03-01"), 4) + ] + + ===> [(31, 1), ..., (59, 4)] + + """ + + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri) + + with pytest.raises(IndexError): + print(feature[0]) + assert isinstance( + feature[815][1], (float, np.float32) + ), f"{feature.__class__.__name__}.__getitem__(i: int) error" + assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error" + print(f"feature[815: 818]: \n{feature[815: 818]}") + + print(f"feature[:].tail(): \n{feature[:].tail()}") + + feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount") + + assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`" + assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`" + assert ( + feature.data.empty + ), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`" diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index d34c1773a..4c20405fa 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### 
-data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH port_analysis_config = { "strategy": { @@ -75,7 +27,7 @@ port_analysis_config = { "verbose": False, "limit_threshold": 0.095, "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "deal_price": "close", "open_cost": 0.0005, "close_cost": 0.0015, @@ -96,15 +48,15 @@ def train(): """ # model initiaiton - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # To test __repr__ print(dataset) print(R) # start exp with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # prediction @@ -137,12 +89,12 @@ def train_with_sigana(): performance: dict model performance """ - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = 
init_instance_by_config(CSI300_GBDT_TASK["dataset"]) # start exp with R.start(experiment_name="workflow_with_sigana"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) # predict and calculate ic and ric @@ -171,7 +123,7 @@ def fake_experiment(): default_uri = R.get_uri() current_uri = "file:./temp-test-exp-mag" with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) current_uri_to_check = R.get_uri() default_uri_to_check = R.get_uri() diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py index ccd3c6a90..9b1edbd4e 100644 --- a/tests/test_contrib_workflow.py +++ b/tests/test_contrib_workflow.py @@ -1,73 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import shutil import unittest from pathlib import Path -import qlib -from qlib.config import C from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.tests import TestAutoData - - -market = "csi300" -benchmark = "SH000300" - -################################### -# train model -################################### -data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, -} - -task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": 
data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, -} +from qlib.tests.config import CSI300_GBDT_TASK def train_multiseg(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = MultiSegRecord(model, dataset, recorder) @@ -77,10 +26,10 @@ def train_multiseg(): def train_mse(): - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) with R.start(experiment_name="workflow"): - R.log_params(**flatten_dict(task)) + R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) recorder = R.get_recorder() sr = SignalMseRecord(recorder, model=model, dataset=dataset) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index c511d1b91..93a852f55 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -1,16 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import sys import shutil import unittest from pathlib import Path -sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) -from get_data import GetData - import qlib from qlib.data import D +from qlib.tests.data import GetData DATA_DIR = Path(__file__).parent.joinpath("test_get_data") SOURCE_DIR = DATA_DIR.joinpath("source") @@ -37,7 +34,9 @@ class TestGetData(unittest.TestCase): def test_0_qlib_data(self): - GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False) + GetData().qlib_data( + name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False, exists_skip=True + ) df = D.features(D.instruments("csi300"), self.FIELDS) self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") self.assertFalse(df.dropna().empty, "get qlib data failed") diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py index 7d3322ddc..ac86be59c 100644 --- a/tests/test_register_ops.py +++ b/tests/test_register_ops.py @@ -1,17 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys import unittest import numpy as np -import qlib from qlib.data import D from qlib.data.ops import ElemOperator, PairOperator -from qlib.config import REG_CN -from qlib.utils import exists_qlib_data from qlib.tests import TestAutoData -from qlib.tests.data import GetData class Diff(ElemOperator):