1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Merge remote-tracking branch 'microsoft/main' into online_srv

This commit is contained in:
lzh222333
2021-06-01 08:29:02 +00:00
43 changed files with 1685 additions and 889 deletions

View File

@@ -72,12 +72,19 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
Here are some example:
.. code-block:: bash
for daily data:
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
for 1min data:
.. code-block:: bash
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
@@ -145,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
Stock Pool (Market)
--------------------------------
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
.. code-block:: bash
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
Multiple Stock Modes
--------------------------------

View File

@@ -101,7 +101,7 @@ Graphical Result
- Axis Y:
- `ic`
The `Pearson correlation coefficient` series between `label` and `prediction score`.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
- `rank_ic`
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.

View File

@@ -111,8 +111,6 @@ Usage & Example
pred_score, strategy=strategy, **BACKTEST_CONFIG
)
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.

View File

@@ -53,6 +53,34 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
Storage
-------------
.. autoclass:: qlib.data.storage.storage.BaseStorage
:members:
.. autoclass:: qlib.data.storage.storage.CalendarStorage
:members:
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
:members:
.. autoclass:: qlib.data.storage.storage.FeatureStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
:members:
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
:members:
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
:members:
Dataset
---------------

View File

@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
- Override the `finetune` method (Optional)
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
- The parameters must include the parameter `dataset`.
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python

View File

@@ -1,24 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import fire
from pathlib import Path
import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import init_instance_by_config, exists_qlib_data
from qlib.utils import init_instance_by_config
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
@@ -96,9 +85,7 @@ class HighfreqWorkflow:
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
qlib.init(**QLIB_INIT_CONFIG)
def _prepare_calender_cache(self):

View File

@@ -1,46 +1,9 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.config import CSI300_DATASET_CONFIG
from qlib.tests.data import GetData
def objective(trial):
@@ -65,12 +28,19 @@ def objective(trial):
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region="cn")
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -1,46 +1,11 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
def objective(trial):
@@ -72,5 +37,13 @@ def objective(trial):
return min(evals_result["valid"])
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":
provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
dataset = init_instance_by_config(DATASET_CONFIG)
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
###################################
# train model
###################################
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
model.fit(dataset)
# get model feature importance
feature_importance = model.get_feature_importance()
print("feature importance:")
print(feature_importance)

View File

@@ -17,64 +17,8 @@ from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM, task_train
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.model.trainer import TrainerRM
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class RollingTaskExample:
@@ -86,11 +30,13 @@ class RollingTaskExample:
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=[task_xgboost_config, task_lgb_config],
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
if task_config is None:
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,

View File

@@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
data_handler_config = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
class OnlineSimulationExample:
@@ -84,7 +28,7 @@ class OnlineSimulationExample:
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=[task_xgboost_config, task_lgb_config],
tasks=None,
):
"""
Init OnlineManagerExample.
@@ -101,6 +45,8 @@ class OnlineSimulationExample:
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time

View File

@@ -18,63 +18,7 @@ from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.task.manage import TaskManager, run_task
data_handler_config = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
class RollingOnlineExample:
@@ -86,9 +30,13 @@ class RollingOnlineExample:
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=[task_xgboost_config],
add_tasks=[task_lgb_config],
tasks=None,
add_tasks=None,
):
if add_tasks is None:
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name

View File

@@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained models to the `online` models.
Next, we will finish updating online predictions.
"""
import copy
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.utils import OnlineToolR
from qlib.tests.config import CSI300_GBDT_TASK
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
task = copy.deepcopy(CSI300_GBDT_TASK)
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
}

View File

@@ -4,13 +4,11 @@
import qlib
import fire
import pickle
import pandas as pd
from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
@@ -25,9 +23,7 @@ class RollingDataWorkflow:
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):

View File

@@ -5,13 +5,11 @@ import os
import sys
import fire
import time
import venv
import glob
import shutil
import signal
import inspect
import tempfile
import traceback
import functools
import statistics
import subprocess
@@ -23,8 +21,7 @@ from pprint import pprint
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
from qlib.tests.data import GetData
# init qlib
@@ -39,12 +36,8 @@ exp_manager = {
"default_exp_name": "Experiment",
},
}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
# decorator to check the arguments

View File

@@ -1,82 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
@@ -90,7 +30,7 @@ if __name__ == "__main__":
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"benchmark": CSI300_BENCH,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
@@ -100,8 +40,8 @@ if __name__ == "__main__":
}
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# NOTE: This line is optional
# It demonstrates that the dataset can be used standalone.
@@ -110,7 +50,7 @@ if __name__ == "__main__":
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})

View File

@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series(dtype=np.float)
cash = pd.Series(dtype=float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]

View File

@@ -10,9 +10,10 @@ from catboost.utils import get_gpu_device_count
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
class CatBoostModel(Model):
class CatBoostModel(Model, FeatureInt):
"""CatBoost Model"""
def __init__(self, loss="RMSE", **kwargs):
@@ -69,6 +70,18 @@ class CatBoostModel(Model):
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters references:
https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance
"""
return pd.Series(
data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_
).sort_values(ascending=False)
if __name__ == "__main__":
cat = CatBoostModel()

View File

@@ -1,251 +1,265 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import lightgbm as lgb
import numpy as np
import pandas as pd
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...log import get_module_logger
class DEnsembleModel(Model):
"""Double Ensemble Model"""
def __init__(
self,
base_model="gbm",
loss="mse",
num_models=6,
enable_sr=True,
enable_fs=True,
alpha1=1.0,
alpha2=1.0,
bins_sr=10,
bins_fs=5,
decay=None,
sample_ratios=None,
sub_weights=None,
epochs=100,
**kwargs
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models
self.enable_sr = enable_sr
self.enable_fs = enable_fs
self.alpha1 = alpha1
self.alpha2 = alpha2
self.bins_sr = bins_sr
self.bins_fs = bins_fs
self.decay = decay
if sample_ratios is None: # the default values for sample_ratios
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
if sub_weights is None: # the default values for sub_weights
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
if not len(sample_ratios) == bins_fs:
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
self.sample_ratios = sample_ratios
if not len(sub_weights) == num_models:
raise ValueError("The length of sub_weights should be equal to num_models.")
self.sub_weights = sub_weights
self.epochs = epochs
self.logger = get_module_logger("DEnsembleModel")
self.logger.info("Double Ensemble Model...")
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
self.sub_features = [] # the features for each sub model in the form of pandas.Index
self.params = {"objective": loss}
self.params.update(kwargs)
self.loss = loss
def fit(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape
weights = pd.Series(np.ones(N, dtype=float))
# initialize the features
features = x_train.columns
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
# train sub-models
for k in range(self.num_models):
self.sub_features.append(features)
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
model_k = self.train_submodel(df_train, df_valid, weights, features)
self.ensemble.append(model_k)
# no further sample re-weight and feature selection needed for the last sub-model
if k + 1 == self.num_models:
break
self.logger.info("Retrieving loss curve and loss values...")
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
pred_k = self.predict_sub(model_k, df_train, features)
pred_sub.iloc[:, k] = pred_k
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
if self.enable_sr:
self.logger.info("Sample re-weighting...")
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
if self.enable_fs:
self.logger.info("Feature selection...")
features = self.feature_selection(df_train, loss_values)
def train_submodel(self, df_train, df_valid, weights, features):
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
evals_result = dict()
model = lgb.train(
self.params,
dtrain,
num_boost_round=self.epochs,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
verbose_eval=20,
evals_result=evals_result,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
return model
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
return dtrain, dvalid
def sample_reweight(self, loss_curve, loss_values, k_th):
"""
the SR module of Double Ensemble
:param loss_curve: the shape is NxT
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
after the t-th iteration in the training of the previous sub-model.
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:param k_th: the index of the current sub-model, starting from 1
:return: weights
the weights for all the samples.
"""
# normalize loss_curve and loss_values with ranking
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
loss_values_norm = (-loss_values).rank(pct=True)
# calculate l_start and l_end from loss_curve
N, T = loss_curve.shape
part = np.maximum(int(T * 0.1), 1)
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
# calculate h-value for each sample
h1 = loss_values_norm
h2 = (l_end / l_start).rank(pct=True)
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
# calculate weights
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
h_avg = h.groupby("bins")["h_value"].mean()
weights = pd.Series(np.zeros(N, dtype=float))
for i_b, b in enumerate(h_avg.index):
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
return weights
def feature_selection(self, df_train, loss_values):
"""
the FS module of Double Ensemble
:param df_train: the shape is NxF
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:return: res_feat: in the form of pandas.Index
"""
x_train, y_train = df_train["feature"], df_train["label"]
features = x_train.columns
N, F = x_train.shape
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
M = len(self.ensemble)
# shuffle specific columns and calculate g-value for each feature
x_train_tmp = x_train.copy()
for i_f, feat in enumerate(features):
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
for i_s, submodel in enumerate(self.ensemble):
pred += (
pd.Series(
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
)
/ M
)
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
# one column in train features is all-nan # if g['g_value'].isna().any()
g["g_value"].replace(np.nan, 0, inplace=True)
# divide features into bins_fs bins
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
# randomly sample features from bins to construct the new features
res_feat = []
sorted_bins = sorted(g["bins"].unique(), reverse=True)
for i_b, b in enumerate(sorted_bins):
b_feat = features[g["bins"] == b]
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist()
return pd.Index(res_feat)
def get_loss(self, label, pred):
if self.loss == "mse":
return (label - pred) ** 2
else:
raise ValueError("not implemented yet")
def retrieve_loss_curve(self, model, df_train, features):
if self.base_model == "gbm":
num_trees = model.num_trees()
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train = np.squeeze(y_train.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
N = x_train.shape[0]
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
pred_tree = np.zeros(N, dtype=float)
for i_tree in range(num_trees):
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
else:
raise ValueError("not implemented yet")
return loss_curve
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.ensemble is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
for i_sub, submodel in enumerate(self.ensemble):
feat_sub = self.sub_features[i_sub]
pred += (
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
* self.sub_weights[i_sub]
)
return pred
def predict_sub(self, submodel, df_data, features):
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
return pred_sub
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import lightgbm as lgb
import numpy as np
import pandas as pd
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
from ...log import get_module_logger
class DEnsembleModel(Model, FeatureInt):
"""Double Ensemble Model"""
def __init__(
self,
base_model="gbm",
loss="mse",
num_models=6,
enable_sr=True,
enable_fs=True,
alpha1=1.0,
alpha2=1.0,
bins_sr=10,
bins_fs=5,
decay=None,
sample_ratios=None,
sub_weights=None,
epochs=100,
**kwargs
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
self.num_models = num_models # the number of sub-models
self.enable_sr = enable_sr
self.enable_fs = enable_fs
self.alpha1 = alpha1
self.alpha2 = alpha2
self.bins_sr = bins_sr
self.bins_fs = bins_fs
self.decay = decay
if sample_ratios is None: # the default values for sample_ratios
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
if sub_weights is None: # the default values for sub_weights
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
if not len(sample_ratios) == bins_fs:
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
self.sample_ratios = sample_ratios
if not len(sub_weights) == num_models:
raise ValueError("The length of sub_weights should be equal to num_models.")
self.sub_weights = sub_weights
self.epochs = epochs
self.logger = get_module_logger("DEnsembleModel")
self.logger.info("Double Ensemble Model...")
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
self.sub_features = [] # the features for each sub model in the form of pandas.Index
self.params = {"objective": loss}
self.params.update(kwargs)
self.loss = loss
def fit(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
# initialize the sample weights
N, F = x_train.shape
weights = pd.Series(np.ones(N, dtype=float))
# initialize the features
features = x_train.columns
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
# train sub-models
for k in range(self.num_models):
self.sub_features.append(features)
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
model_k = self.train_submodel(df_train, df_valid, weights, features)
self.ensemble.append(model_k)
# no further sample re-weight and feature selection needed for the last sub-model
if k + 1 == self.num_models:
break
self.logger.info("Retrieving loss curve and loss values...")
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
pred_k = self.predict_sub(model_k, df_train, features)
pred_sub.iloc[:, k] = pred_k
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
if self.enable_sr:
self.logger.info("Sample re-weighting...")
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
if self.enable_fs:
self.logger.info("Feature selection...")
features = self.feature_selection(df_train, loss_values)
def train_submodel(self, df_train, df_valid, weights, features):
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
evals_result = dict()
model = lgb.train(
self.params,
dtrain,
num_boost_round=self.epochs,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
verbose_eval=20,
evals_result=evals_result,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
return model
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train, label=y_train, weight=weights)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def sample_reweight(self, loss_curve, loss_values, k_th):
"""
the SR module of Double Ensemble
:param loss_curve: the shape is NxT
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
after the t-th iteration in the training of the previous sub-model.
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:param k_th: the index of the current sub-model, starting from 1
:return: weights
the weights for all the samples.
"""
# normalize loss_curve and loss_values with ranking
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
loss_values_norm = (-loss_values).rank(pct=True)
# calculate l_start and l_end from loss_curve
N, T = loss_curve.shape
part = np.maximum(int(T * 0.1), 1)
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
# calculate h-value for each sample
h1 = loss_values_norm
h2 = (l_end / l_start).rank(pct=True)
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
# calculate weights
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
h_avg = h.groupby("bins")["h_value"].mean()
weights = pd.Series(np.zeros(N, dtype=float))
for i_b, b in enumerate(h_avg.index):
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
return weights
def feature_selection(self, df_train, loss_values):
"""
the FS module of Double Ensemble
:param df_train: the shape is NxF
:param loss_values: the shape is N
the loss of the current ensemble on the i-th sample.
:return: res_feat: in the form of pandas.Index
"""
x_train, y_train = df_train["feature"], df_train["label"]
features = x_train.columns
N, F = x_train.shape
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
M = len(self.ensemble)
# shuffle specific columns and calculate g-value for each feature
x_train_tmp = x_train.copy()
for i_f, feat in enumerate(features):
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
for i_s, submodel in enumerate(self.ensemble):
pred += (
pd.Series(
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
)
/ M
)
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
# one column in train features is all-nan # if g['g_value'].isna().any()
g["g_value"].replace(np.nan, 0, inplace=True)
# divide features into bins_fs bins
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
# randomly sample features from bins to construct the new features
res_feat = []
sorted_bins = sorted(g["bins"].unique(), reverse=True)
for i_b, b in enumerate(sorted_bins):
b_feat = features[g["bins"] == b]
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()
return pd.Index(set(res_feat))
def get_loss(self, label, pred):
if self.loss == "mse":
return (label - pred) ** 2
else:
raise ValueError("not implemented yet")
def retrieve_loss_curve(self, model, df_train, features):
if self.base_model == "gbm":
num_trees = model.num_trees()
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train = np.squeeze(y_train.values)
else:
raise ValueError("LightGBM doesn't support multi-label training")
N = x_train.shape[0]
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
pred_tree = np.zeros(N, dtype=float)
for i_tree in range(num_trees):
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
else:
raise ValueError("not implemented yet")
return loss_curve
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.ensemble is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
for i_sub, submodel in enumerate(self.ensemble):
feat_sub = self.sub_features[i_sub]
pred += (
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
* self.sub_weights[i_sub]
)
return pred
def predict_sub(self, submodel, df_data, features):
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
return pred_sub
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
res = []
for _model, _weight in zip(self.ensemble, self.sub_weights):
res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)
return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)

View File

@@ -8,9 +8,10 @@ from typing import Text, Union
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt
class LGBModel(ModelFT):
class LGBModel(ModelFT, LightGBMFInt):
"""LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
@@ -33,8 +34,8 @@ class LGBModel(ModelFT):
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
dtrain = lgb.Dataset(x_train, label=y_train)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def fit(

View File

@@ -1,17 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import warnings
import numpy as np
import pandas as pd
import lightgbm as lgb
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
import warnings
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt
class HFLGBModel(ModelFT):
class HFLGBModel(ModelFT, LightGBMFInt):
"""LightGBM Model for high frequency prediction"""
def __init__(self, loss="mse", **kwargs):
@@ -97,8 +98,8 @@ class HFLGBModel(ModelFT):
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
dtrain = lgb.Dataset(x_train, label=y_train)
dvalid = lgb.Dataset(x_valid, label=y_valid)
return dtrain, dvalid
def fit(

View File

@@ -8,9 +8,10 @@ from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import FeatureInt
class XGBModel(Model):
class XGBModel(Model, FeatureInt):
"""XGBModel Model"""
def __init__(self, **kwargs):
@@ -42,8 +43,8 @@ class XGBModel(Model):
else:
raise ValueError("XGBoost doesn't support multi-label training")
dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
dtrain = xgb.DMatrix(x_train, label=y_train_1d)
dvalid = xgb.DMatrix(x_valid, label=y_valid_1d)
self.model = xgb.train(
self._params,
dtrain=dtrain,
@@ -62,3 +63,13 @@ class XGBModel(Model):
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-------
parameters reference:
https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score
"""
return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False)

View File

@@ -6,7 +6,9 @@ from __future__ import division
from __future__ import print_function
import os
import re
import abc
import copy
import time
import queue
import bisect
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
class CalendarProvider(abc.ABC):
class ProviderBackendMixin:
def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
# set default storage class
backend.setdefault("class", f"File{provider_name}Storage")
# set default storage module
backend.setdefault("module_path", "qlib.data.storage.file_storage")
return backend
def backend_obj(self, **kwargs):
backend = self.backend if self.backend else self.get_default_backend()
backend = copy.deepcopy(backend)
# set default storage kwargs
backend_kwargs = backend.setdefault("kwargs", {})
# default provider_uri map
if "provider_uri" not in backend_kwargs:
# if the user has no uri configured, use: uri = uri_map[freq]
freq = kwargs.get("freq", "day")
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
backend_kwargs["provider_uri"] = provider_uri_map[freq]
backend.setdefault("kwargs", {}).update(**kwargs)
return init_instance_by_config(backend)
class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Calendar provider base class
Provide calendar data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
return hash_args(start_time, end_time, freq, future)
class InstrumentProvider(abc.ABC):
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
Provide instrument data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@staticmethod
def instruments(market="all", filter_pipe=None):
"""Get the general config dictionary for a base market adding several dynamic filters.
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
raise ValueError(f"Unknown instrument type {inst}")
class FeatureProvider(abc.ABC):
class FeatureProvider(abc.ABC, ProviderBackendMixin):
"""Feature provider class
Provide feature data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
"""
def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -517,21 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
list
list of timestamps
"""
if future:
fname = self._uri_cal.format(freq + "_future")
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
if future:
get_module_logger("data").warning(
f"load calendar error: freq={freq}, future={future}; return current calendar!"
)
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
if not os.path.exists(fname):
raise ValueError("calendar not exists for freq " + freq)
with open(fname) as f:
return [pd.Timestamp(x.strip()) for x in f]
backend_obj = self.backend_obj(freq=freq, future=False).data
else:
raise
return [pd.Timestamp(x) for x in backend_obj]
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
_calendar, _calendar_index = self._get_calendar(freq, future)
@@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
Provide instrument data from local data source.
"""
def __init__(self):
pass
@property
def _uri_inst(self):
"""Instrument file uri."""
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
def _load_instruments(self, market):
fname = self._uri_inst.format(market)
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
df = pd.read_csv(
fname,
sep="\t",
usecols=[0, 1, 2],
names=["inst", "start_datetime", "end_datetime"],
dtype={"inst": str},
parse_dates=["start_datetime", "end_datetime"],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
market = instruments["market"]
if market in H["i"]:
_instruments = H["i"][market]
else:
_instruments = self._load_instruments(market)
_instruments = self._load_instruments(market, freq=freq)
H["i"][market] = _instruments
# strip
# use calendar boundary
@@ -604,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
inst: list(
filter(
lambda x: x[0] <= x[1],
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
)
)
for inst, spans in _instruments.items()
@@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
"""
def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
@property
@@ -641,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
# validate
field = str(field).lower()[1:]
instrument = code_to_fname(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
return series
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
class LocalExpressionProvider(ExpressionProvider):
@@ -1065,7 +1080,8 @@ def register_all_wrappers(C):
register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
register_wrapper(Inst, C.instrument_provider, "qlib.data")
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
register_wrapper(Inst, _instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")
if getattr(C, "feature_provider", None) is not None:

View File

@@ -357,7 +357,7 @@ class TSDataSampler:
# get the previous index of a line given index
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT

View File

@@ -0,0 +1,292 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import struct
from pathlib import Path
from typing import Iterable, Union, Dict, Mapping, Tuple, List
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
logger = get_module_logger("file_storage")
class FileStorageMixin:
@property
def uri(self) -> Path:
_provider_uri = self.kwargs.get("provider_uri", None)
if _provider_uri is None:
raise ValueError(
f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
f'please specify `provider_uri` in the "provider\'s backend"'
)
return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name)
def check(self):
"""check self.uri
Raises
-------
ValueError
"""
if not self.uri.exists():
raise ValueError(f"{self.storage_name} not exists: {self.uri}")
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
def __init__(self, freq: str, future: bool, **kwargs):
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [
str(x)
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
]
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:
np.savetxt(fp, values, fmt="%s", encoding="utf-8")
@property
def data(self) -> List[CalVT]:
self.check()
return self._read_calendar()
def extend(self, values: Iterable[CalVT]) -> None:
self._write_calendar(values, mode="ab")
def clear(self) -> None:
self._write_calendar(values=[])
def index(self, value: CalVT) -> int:
self.check()
calendar = self._read_calendar()
return int(np.argwhere(calendar == value)[0])
def insert(self, index: int, value: CalVT):
calendar = self._read_calendar()
calendar = np.insert(calendar, index, value)
self._write_calendar(values=calendar)
def remove(self, value: CalVT) -> None:
self.check()
index = self.index(value)
calendar = self._read_calendar()
calendar = np.delete(calendar, index)
self._write_calendar(values=calendar)
def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:
calendar = self._read_calendar()
calendar[i] = values
self._write_calendar(values=calendar)
def __delitem__(self, i: Union[int, slice]) -> None:
self.check()
calendar = self._read_calendar()
calendar = np.delete(calendar, i)
self._write_calendar(values=calendar)
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
self.check()
return self._read_calendar()[i]
def __len__(self) -> int:
return len(self.data)
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
INSTRUMENT_SEP = "\t"
INSTRUMENT_START_FIELD = "start_datetime"
INSTRUMENT_END_FIELD = "end_datetime"
SYMBOL_FIELD_NAME = "instrument"
def __init__(self, market: str, **kwargs):
super(FileInstrumentStorage, self).__init__(market, **kwargs)
self.file_name = f"{market.lower()}.txt"
def _read_instrument(self) -> Dict[InstKT, InstVT]:
if not self.uri.exists():
self._write_instrument()
_instruments = dict()
df = pd.read_csv(
self.uri,
sep="\t",
usecols=[0, 1, 2],
names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
dtype={self.SYMBOL_FIELD_NAME: str},
parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
)
for row in df.itertuples(index=False):
_instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
if not data:
with self.uri.open("w") as _:
pass
return
res = []
for inst, v_list in data.items():
_df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
_df[self.SYMBOL_FIELD_NAME] = inst
res.append(_df)
df = pd.concat(res, sort=False)
df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
)
df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False)
def clear(self) -> None:
self._write_instrument(data={})
@property
def data(self) -> Dict[InstKT, InstVT]:
self.check()
return self._read_instrument()
def __setitem__(self, k: InstKT, v: InstVT) -> None:
inst = self._read_instrument()
inst[k] = v
self._write_instrument(inst)
def __delitem__(self, k: InstKT) -> None:
self.check()
inst = self._read_instrument()
del inst[k]
self._write_instrument(inst)
def __getitem__(self, k: InstKT) -> InstVT:
self.check()
return self._read_instrument()[k]
def update(self, *args, **kwargs) -> None:
if len(args) > 1:
raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
inst = self._read_instrument()
if args:
other = args[0] # type: dict
if isinstance(other, Mapping):
for key in other:
inst[key] = other[key]
elif hasattr(other, "keys"):
for key in other.keys():
inst[key] = other[key]
else:
for key, value in other:
inst[key] = value
for key, value in kwargs.items():
inst[key] = value
self._write_instrument(inst)
def __len__(self) -> int:
return len(self.data)
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
def clear(self):
with self.uri.open("wb") as _:
pass
@property
def data(self) -> pd.Series:
return self[:]
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
if len(data_array) == 0:
logger.info(
"len(data_array) == 0, write"
"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if not self.uri.exists():
# write
index = 0 if index is None else index
with self.uri.open("wb") as fp:
np.hstack([index, data_array]).astype("<f").tofile(fp)
else:
if index is None or index > self.end_index:
# append
index = 0 if index is None else index
with self.uri.open("ab+") as fp:
np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
else:
# rewrite
with self.uri.open("rb+") as fp:
_old_data = np.fromfile(fp, dtype="<f")
_old_index = _old_data[0]
_old_df = pd.DataFrame(
_old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
)
fp.seek(0)
_new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
_df = pd.concat([_old_df, _new_df], sort=False, axis=1)
_df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
_df["new"].fillna(_df["old"]).values.astype("<f").tofile(fp)
@property
def start_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
with self.uri.open("rb") as fp:
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
return index
@property
def end_index(self) -> Union[int, None]:
if not self.uri.exists():
return None
# The next data appending index point will be `end_index + 1`
return self.start_index + len(self) - 1
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
if not self.uri.exists():
if isinstance(i, int):
return None, None
elif isinstance(i, slice):
return pd.Series(dtype=np.float32)
else:
raise TypeError(f"type(i) = {type(i)}")
storage_start_index = self.start_index
storage_end_index = self.end_index
with self.uri.open("rb") as fp:
if isinstance(i, int):
if storage_start_index > i:
raise IndexError(f"{i}: start index is {storage_start_index}")
fp.seek(4 * (i - storage_start_index) + 4)
return i, struct.unpack("f", fp.read(4))[0]
elif isinstance(i, slice):
start_index = storage_start_index if i.start is None else i.start
end_index = storage_end_index if i.stop is None else i.stop - 1
si = max(start_index, storage_start_index)
if si > end_index:
return pd.Series(dtype=np.float32)
fp.seek(4 * (si - storage_start_index) + 4)
# read n bytes
count = end_index - si + 1
data = np.frombuffer(fp.read(4 * count), dtype="<f")
return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
else:
raise TypeError(f"type(i) = {type(i)}")
def __len__(self) -> int:
self.check()
return self.uri.stat().st_size // 4 - 1

View File

@@ -0,0 +1,501 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
import numpy as np
import pandas as pd
from qlib.log import get_module_logger
# calendar value type
CalVT = str
# instrument value
InstVT = List[Tuple[CalVT, CalVT]]
# instrument key
InstKT = Text
logger = get_module_logger("storage")
"""
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
class UserCalendarStorage(CalendarStorage):
@property
def data(self) -> Iterable[CalVT]:
'''get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
'''
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
class UserInstrumentStorage(InstrumentStorage):
@property
def data(self) -> Dict[InstKT, InstVT]:
'''get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
'''
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
class UserFeatureStorage(FeatureStorage):
def __getitem__(self, s: slice) -> pd.Series:
'''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
Returns
-------
pd.Series(values, index=pd.RangeIndex(start, len(values))
Notes
-------
if data(storage) does not exist:
if isinstance(i, int):
return (None, None)
if isinstance(i, slice):
# return empty pd.Series
return pd.Series(dtype=np.float32)
'''
raise NotImplementedError(
"Subclass of FeatureStorage must implement `__getitem__(s: slice)` method"
)
"""
class BaseStorage:
@property
def storage_name(self) -> str:
return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower()
class CalendarStorage(BaseStorage):
"""
The behavior of CalendarStorage's methods and List's methods of the same name remain consistent
"""
def __init__(self, freq: str, future: bool, **kwargs):
self.freq = freq
self.future = future
self.kwargs = kwargs
@property
def data(self) -> Iterable[CalVT]:
"""get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")
def extend(self, iterable: Iterable[CalVT]) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")
def index(self, value: CalVT) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")
def insert(self, index: int, value: CalVT) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method")
def remove(self, value: CalVT) -> None:
raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method")
@overload
def __setitem__(self, i: int, value: CalVT) -> None:
"""x.__setitem__(i, o) <==> (x[i] = o)"""
...
@overload
def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:
"""x.__setitem__(s, o) <==> (x[s] = o)"""
...
def __setitem__(self, i, value) -> None:
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method"
)
@overload
def __delitem__(self, i: int) -> None:
"""x.__delitem__(i) <==> del x[i]"""
...
@overload
def __delitem__(self, i: slice) -> None:
"""x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]"""
...
def __delitem__(self, i) -> None:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method"
)
@overload
def __getitem__(self, s: slice) -> Iterable[CalVT]:
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]"""
...
@overload
def __getitem__(self, i: int) -> CalVT:
"""x.__getitem__(i) <==> x[i]"""
...
def __getitem__(self, i) -> CalVT:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError(
"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
class InstrumentStorage(BaseStorage):
def __init__(self, market: str, **kwargs):
self.market = market
self.kwargs = kwargs
@property
def data(self) -> Dict[InstKT, InstVT]:
"""get all data
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")
def update(self, *args, **kwargs) -> None:
"""D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
Notes
------
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
In either case, this is followed by: for k, v in F.items(): D[k] = v
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method")
def __setitem__(self, k: InstKT, v: InstVT) -> None:
"""Set self[key] to value."""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method")
def __delitem__(self, k: InstKT) -> None:
"""Delete self[key].
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method")
def __getitem__(self, k: InstKT) -> InstVT:
"""x.__getitem__(k) <==> x[k]"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
class FeatureStorage(BaseStorage):
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
self.instrument = instrument
self.field = field
self.freq = freq
self.kwargs = kwargs
@property
def data(self) -> pd.Series:
"""get all data
Notes
------
if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)`
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
@property
def start_index(self) -> Union[int, None]:
"""get FeatureStorage start index
Notes
-----
If the data(storage) does not exist, return None
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method")
@property
def end_index(self) -> Union[int, None]:
"""get FeatureStorage end index
Notes
-----
The right index of the data range (both sides are closed)
The next data appending point will be `end_index + 1`
If the data(storage) does not exist, return None
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method")
def clear(self) -> None:
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
"""Write data_array to FeatureStorage starting from index.
Notes
------
If index is None, append data_array to feature.
If len(data_array) == 0; return
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
Examples
---------
.. code-block::
feature:
3 4
4 5
5 6
>>> self.write([6, 7], index=6)
feature:
3 4
4 5
5 6
6 6
7 7
>>> self.write([8], index=9)
feature:
3 4
4 5
5 6
6 6
7 7
8 np.nan
9 8
>>> self.write([1, np.nan], index=3)
feature:
3 1
4 np.nan
5 6
6 6
7 7
8 np.nan
9 8
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
def rebase(self, start_index: int = None, end_index: int = None):
"""Rebase the start_index and end_index of the FeatureStorage.
start_index and end_index are closed intervals: [start_index, end_index]
Examples
---------
.. code-block::
feature:
3 4
4 5
5 6
>>> self.rebase(start_index=4)
feature:
4 5
5 6
>>> self.rebase(start_index=3)
feature:
3 np.nan
4 5
5 6
>>> self.write([3], index=3)
feature:
3 3
4 5
5 6
>>> self.rebase(end_index=4)
feature:
3 3
4 5
>>> self.write([6, 7, 8], index=4)
feature:
3 3
4 6
5 7
6 8
>>> self.rebase(start_index=4, end_index=5)
feature:
4 6
5 7
"""
storage_si = self.start_index
storage_ei = self.end_index
if storage_si is None or storage_ei is None:
raise ValueError("storage.start_index or storage.end_index is None, storage may not exist")
start_index = storage_si if start_index is None else start_index
end_index = storage_ei if end_index is None else end_index
if start_index is None or end_index is None:
logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored")
return
if start_index < 0 or end_index < 0:
logger.warning("start_index or end_index cannot be less than 0")
return
if start_index > end_index:
logger.warning(
f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
)
return
if start_index <= storage_si:
self.write([np.nan] * (storage_si - start_index), start_index)
else:
self.rewrite(self[start_index:].values, start_index)
if end_index >= self.end_index:
self.write([np.nan] * (end_index - self.end_index))
else:
self.rewrite(self[: end_index + 1].values, start_index)
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
"""overwrite all data in FeatureStorage with data
Parameters
----------
data: Union[List, np.ndarray, Tuple]
data
index: int
data start index
"""
self.clear()
self.write(data, index)
@overload
def __getitem__(self, s: slice) -> pd.Series:
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
Returns
-------
pd.Series(values, index=pd.RangeIndex(start, len(values))
"""
...
@overload
def __getitem__(self, i: int) -> Tuple[int, float]:
"""x.__getitem__(y) <==> x[y]"""
...
def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
"""x.__getitem__(y) <==> x[y]
Notes
-------
if data(storage) does not exist:
if isinstance(i, int):
return (None, None)
if isinstance(i, slice):
# return empty pd.Series
return pd.Series(dtype=np.float32)
"""
raise NotImplementedError(
"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
)
def __len__(self) -> int:
"""
Raises
------
ValueError
If the data(storage) does not exist, raise ValueError
"""
raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")

View File

View File

@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Interfaces to interpret models
"""
import pandas as pd
from abc import abstractmethod
class FeatureInt:
"""Feature (Int)erpreter"""
@abstractmethod
def get_feature_importance(self) -> pd.Series:
"""get feature importance
Returns
-------
The index is the feature name.
The greater the value, the higher importance.
"""
class LightGBMFInt(FeatureInt):
"""LightGBM (F)eature (Int)erpreter"""
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
Notes
-----
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values(
ascending=False
)

View File

@@ -1,6 +1,4 @@
import sys
import unittest
from ..utils import exists_qlib_data
from .data import GetData
from .. import init
from ..config import REG_CN
@@ -9,19 +7,18 @@ from ..config import REG_CN
class TestAutoData(unittest.TestCase):
_setup_kwargs = {}
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(
name="qlib_data_simple",
region="cn",
interval="1d",
target_dir=provider_uri,
delete_old=False,
)
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
GetData().qlib_data(
name="qlib_data_simple",
region=REG_CN,
interval="1d",
target_dir=cls.provider_uri,
delete_old=False,
exists_skip=True,
)
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)

108
qlib/tests/config.py Normal file
View File

@@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
CSI300_MARKET = "csi300"
CSI100_MARKET = "csi100"
CSI300_BENCH = "SH000300"
DATASET_ALPHA158_CLASS = "Alpha158"
DATASET_ALPHA360_CLASS = "Alpha360"
###################################
# config
###################################
GBDT_MODEL = {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
}
RECORD_CONFIG = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
def get_data_handler_config(market=CSI300_MARKET):
return {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
return {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": dataset_class,
"module_path": "qlib.contrib.data.handler",
"kwargs": get_data_handler_config(market),
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}
def get_gbdt_task(market=CSI300_MARKET):
return {
"model": GBDT_MODEL,
"dataset": get_dataset_config(market),
}
def get_record_lgb_config(market=CSI300_MARKET):
return {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": get_dataset_config(market),
"record": RECORD_CONFIG,
}
def get_record_xgboost_config(market=CSI300_MARKET):
return {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": get_dataset_config(market),
"record": RECORD_CONFIG,
}
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)

View File

@@ -10,6 +10,7 @@ import datetime
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from qlib.utils import exists_qlib_data
class GetData:
@@ -112,6 +113,7 @@ class GetData:
interval="1d",
region="cn",
delete_old=True,
exists_skip=False,
):
"""download cn qlib data from remote
@@ -129,6 +131,8 @@ class GetData:
data region, value from [cn, us], by default cn
delete_old: bool
delete an existing directory, by default True
exists_skip: bool
exists skip, by default False
Examples
---------
@@ -140,6 +144,13 @@ class GetData:
-------
"""
if exists_skip and exists_qlib_data(target_dir):
logger.warning(
f"Data already exists: {target_dir}, the data download will be skipped\n"
f"\tIf downloading is required: `exists_skip=False` or `change target_dir`"
)
return
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
def _get_file_name(v):

View File

@@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir):
return False
# check calendar bin
for _calendar in calendars_dir.iterdir():
if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")):
if ("_future" not in _calendar.name) and (
not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin"))
):
return False
# check instruments

View File

@@ -121,7 +121,7 @@ df = D.features(D.instruments("all"), ["$close"], freq="day")
### Help
```bash
pythono collector.py collector_data --help
python collector.py collector_data --help
```
## Parameters

View File

@@ -191,7 +191,7 @@ class YahooCollector(BaseCollector):
class YahooCollectorCN(YahooCollector, ABC):
def get_instrument_list(self):
logger.info("get HS stock symbos......")
logger.info("get HS stock symbols......")
symbols = get_hs_stock_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
@@ -581,7 +581,6 @@ class Run(BaseRun):
delay=0,
start=None,
end=None,
interval="1d",
check_data_length=False,
limit_nums=None,
):
@@ -593,8 +592,6 @@ class Run(BaseRun):
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1min, 1d], default 1d
start: str
start datetime, default "2000-01-01"
end: str
@@ -611,8 +608,9 @@ class Run(BaseRun):
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
super(Run, self).download_data(
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
)
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
"""normalize data

View File

@@ -5,5 +5,4 @@ numpy
pandas
tqdm
lxml
loguru
yahooquery

View File

@@ -120,7 +120,7 @@ class DumpDataBase:
else:
df = file_or_df
if df.empty or self.date_field_name not in df.columns.tolist():
_calendars = pd.Series()
_calendars = pd.Series(dtype=np.float32)
else:
_calendars = df[self.date_field_name]

View File

@@ -1,26 +1,10 @@
import sys
from pathlib import Path
import qlib
from qlib.data import D
from qlib.config import REG_CN
import unittest
import numpy as np
from qlib.utils import exists_qlib_data
from qlib.data import D
from qlib.tests import TestAutoData
class TestDataset(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)
class TestDataset(TestAutoData):
def testCSI300(self):
close_p = D.features(D.instruments("csi300"), ["$close"])
size = close_p.groupby("datetime").size()

View File

@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from collections.abc import Iterable
import pytest
import numpy as np
from qlib.tests import TestAutoData
from qlib.data.storage.file_storage import (
FileCalendarStorage as CalendarStorage,
FileInstrumentStorage as InstrumentStorage,
FileFeatureStorage as FeatureStorage,
)
_file_name = Path(__file__).name.split(".")[0]
DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data")
QLIB_DIR = DATA_DIR.joinpath("qlib")
QLIB_DIR.mkdir(exist_ok=True, parents=True)
class TestStorage(TestAutoData):
def test_calendar_storage(self):
calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri)
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
print(f"calendar[1: 5]: {calendar[1:5]}")
print(f"calendar[0]: {calendar[0]}")
print(f"calendar[-1]: {calendar[-1]}")
calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
with pytest.raises(ValueError):
print(calendar.data)
with pytest.raises(ValueError):
print(calendar[:])
with pytest.raises(ValueError):
print(calendar[0])
def test_instrument_storage(self):
"""
The meaning of instrument, such as CSI500:
CSI500 composition changes:
date add remove
2005-01-01 SH600000
2005-01-01 SH600001
2005-01-01 SH600002
2005-02-01 SH600003 SH600000
2005-02-15 SH600000 SH600002
Calendar:
pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D")
Instrument:
symbol start_time end_time
SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day)
SH600000 2005-02-15 2005-03-01
SH600001 2005-01-01 2005-03-01
SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day)
SH600003 2005-02-01 2005-03-01
InstrumentStorage:
{
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
"SH600001": [(2005-01-01, 2005-03-01)],
"SH600002": [(2005-01-01, 2005-02-14)],
"SH600003": [(2005-02-01, 2005-03-01)],
}
"""
instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri)
for inst, spans in instrument.data.items():
assert isinstance(inst, str) and isinstance(
spans, Iterable
), f"{instrument.__class__.__name__} value is not Iterable"
for s_e in spans:
assert (
isinstance(s_e, tuple) and len(s_e) == 2
), f"{instrument.__class__.__name__}.__getitem__(k) TypeError"
print(f"instrument['SH600000']: {instrument['SH600000']}")
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
with pytest.raises(ValueError):
print(instrument.data)
with pytest.raises(ValueError):
print(instrument["sSH600000"])
def test_feature_storage(self):
"""
Calendar:
pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D")
Instrument:
{
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
"SH600001": [(2005-01-01, 2005-03-01)],
"SH600002": [(2005-01-01, 2005-02-14)],
"SH600003": [(2005-02-01, 2005-03-01)],
}
Feature:
Stock data(close):
2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01
SH600000 1 ... 3 ... 4 5 6
SH600001 1 ... 4 ... 5 6 7
SH600002 1 ... 5 ... 6 nan nan
SH600003 nan ... 1 ... 2 3 4
FeatureStorage(SH600000, close):
[
(calendar.index("2005-01-01"), 1),
...,
(calendar.index("2005-03-01"), 6)
]
====> [(0, 1), ..., (59, 6)]
FeatureStorage(SH600002, close):
[
(calendar.index("2005-01-01"), 1),
...,
(calendar.index("2005-02-14"), 6)
]
===> [(0, 1), ..., (44, 6)]
FeatureStorage(SH600003, close):
[
(calendar.index("2005-02-01"), 1),
...,
(calendar.index("2005-03-01"), 4)
]
===> [(31, 1), ..., (59, 4)]
"""
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
with pytest.raises(IndexError):
print(feature[0])
assert isinstance(
feature[815][1], (float, np.float32)
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
print(f"feature[815: 818]: \n{feature[815: 818]}")
print(f"feature[:].tail(): \n{feature[:].tail()}")
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")
assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`"
assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`"
assert (
feature.data.empty
), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"

View File

@@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
from qlib.tests import TestAutoData
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH
port_analysis_config = {
"strategy": {
@@ -75,7 +27,7 @@ port_analysis_config = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"benchmark": CSI300_BENCH,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
@@ -96,15 +48,15 @@ def train():
"""
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# To test __repr__
print(dataset)
print(R)
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
# prediction
@@ -137,12 +89,12 @@ def train_with_sigana():
performance: dict
model performance
"""
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
# start exp
with R.start(experiment_name="workflow_with_sigana"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
# predict and calculate ic and ric
@@ -171,7 +123,7 @@ def fake_experiment():
default_uri = R.get_uri()
current_uri = "file:./temp-test-exp-mag"
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
current_uri_to_check = R.get_uri()
default_uri_to_check = R.get_uri()

View File

@@ -1,73 +1,22 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import shutil
import unittest
from pathlib import Path
import qlib
from qlib.config import C
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.tests import TestAutoData
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
from qlib.tests.config import CSI300_GBDT_TASK
def train_multiseg():
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
sr = MultiSegRecord(model, dataset, recorder)
@@ -77,10 +26,10 @@ def train_multiseg():
def train_mse():
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
model.fit(dataset)
recorder = R.get_recorder()
sr = SignalMseRecord(recorder, model=model, dataset=dataset)

View File

@@ -1,16 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import shutil
import unittest
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
import qlib
from qlib.data import D
from qlib.tests.data import GetData
DATA_DIR = Path(__file__).parent.joinpath("test_get_data")
SOURCE_DIR = DATA_DIR.joinpath("source")
@@ -37,7 +34,9 @@ class TestGetData(unittest.TestCase):
def test_0_qlib_data(self):
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
GetData().qlib_data(
name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False, exists_skip=True
)
df = D.features(D.instruments("csi300"), self.FIELDS)
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
self.assertFalse(df.dropna().empty, "get qlib data failed")

View File

@@ -1,17 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import unittest
import numpy as np
import qlib
from qlib.data import D
from qlib.data.ops import ElemOperator, PairOperator
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data
from qlib.tests import TestAutoData
from qlib.tests.data import GetData
class Diff(ElemOperator):