mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Merge remote-tracking branch 'microsoft/main' into online_srv
This commit is contained in:
@@ -72,12 +72,19 @@ Converting CSV Format into Qlib Format
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
|
||||
.. code-block:: bash
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
@@ -145,6 +152,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -111,8 +111,6 @@ Usage & Example
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
@@ -53,6 +53,34 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.CalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.FeatureStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
|
||||
:members:
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
@@ -1,24 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -96,9 +85,7 @@ class HighfreqWorkflow:
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
|
||||
@@ -1,46 +1,9 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(scripts_dir))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
dataset_task = {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.config import CSI300_DATASET_CONFIG
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
def objective(trial):
|
||||
@@ -65,12 +28,19 @@ def objective(trial):
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
|
||||
@@ -1,46 +1,11 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(scripts_dir))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
dataset_task = {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha360",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
@@ -72,5 +37,13 @@ def objective(trial):
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
dataset = init_instance_by_config(DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
|
||||
32
examples/model_interpreter/feature.py
Normal file
32
examples/model_interpreter/feature.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
@@ -17,64 +17,8 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
@@ -86,11 +30,13 @@ class RollingTaskExample:
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=[task_xgboost_config, task_lgb_config],
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
if task_config is None:
|
||||
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
|
||||
@@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -84,7 +28,7 @@ class OnlineSimulationExample:
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
tasks=None,
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
@@ -101,6 +45,8 @@ class OnlineSimulationExample:
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
|
||||
@@ -18,63 +18,7 @@ from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -86,9 +30,13 @@ class RollingOnlineExample:
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config],
|
||||
add_tasks=[task_lgb_config],
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
|
||||
@@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"record": {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,13 +4,11 @@
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
@@ -25,9 +23,7 @@ class RollingDataWorkflow:
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
|
||||
@@ -5,13 +5,11 @@ import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
@@ -23,8 +21,7 @@ from pprint import pprint
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
@@ -39,12 +36,8 @@ exp_manager = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
# decorator to check the arguments
|
||||
|
||||
@@ -1,82 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
@@ -90,7 +30,7 @@ if __name__ == "__main__":
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -100,8 +40,8 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
@@ -110,7 +50,7 @@ if __name__ == "__main__":
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class Position:
|
||||
def save_position(self, path, last_trade_date):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash = pd.Series(dtype=float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
|
||||
@@ -10,9 +10,10 @@ from catboost.utils import get_gpu_device_count
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
|
||||
|
||||
class CatBoostModel(Model):
|
||||
class CatBoostModel(Model, FeatureInt):
|
||||
"""CatBoost Model"""
|
||||
|
||||
def __init__(self, loss="RMSE", **kwargs):
|
||||
@@ -69,6 +70,18 @@ class CatBoostModel(Model):
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-----
|
||||
parameters references:
|
||||
https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance
|
||||
"""
|
||||
return pd.Series(
|
||||
data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_
|
||||
).sort_values(ascending=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cat = CatBoostModel()
|
||||
|
||||
@@ -1,251 +1,265 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
class DEnsembleModel(Model):
|
||||
"""Double Ensemble Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model="gbm",
|
||||
loss="mse",
|
||||
num_models=6,
|
||||
enable_sr=True,
|
||||
enable_fs=True,
|
||||
alpha1=1.0,
|
||||
alpha2=1.0,
|
||||
bins_sr=10,
|
||||
bins_fs=5,
|
||||
decay=None,
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
self.enable_sr = enable_sr
|
||||
self.enable_fs = enable_fs
|
||||
self.alpha1 = alpha1
|
||||
self.alpha2 = alpha2
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
if not len(sub_weights) == num_models:
|
||||
raise ValueError("The length of sub_weights should be equal to num_models.")
|
||||
self.sub_weights = sub_weights
|
||||
self.epochs = epochs
|
||||
self.logger = get_module_logger("DEnsembleModel")
|
||||
self.logger.info("Double Ensemble Model...")
|
||||
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
|
||||
self.sub_features = [] # the features for each sub model in the form of pandas.Index
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
weights = pd.Series(np.ones(N, dtype=float))
|
||||
# initialize the features
|
||||
features = x_train.columns
|
||||
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
|
||||
# train sub-models
|
||||
for k in range(self.num_models):
|
||||
self.sub_features.append(features)
|
||||
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
|
||||
model_k = self.train_submodel(df_train, df_valid, weights, features)
|
||||
self.ensemble.append(model_k)
|
||||
# no further sample re-weight and feature selection needed for the last sub-model
|
||||
if k + 1 == self.num_models:
|
||||
break
|
||||
|
||||
self.logger.info("Retrieving loss curve and loss values...")
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
self.logger.info("Sample re-weighting...")
|
||||
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
|
||||
|
||||
if self.enable_fs:
|
||||
self.logger.info("Feature selection...")
|
||||
features = self.feature_selection(df_train, loss_values)
|
||||
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
return model
|
||||
|
||||
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def sample_reweight(self, loss_curve, loss_values, k_th):
|
||||
"""
|
||||
the SR module of Double Ensemble
|
||||
:param loss_curve: the shape is NxT
|
||||
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
|
||||
after the t-th iteration in the training of the previous sub-model.
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:param k_th: the index of the current sub-model, starting from 1
|
||||
:return: weights
|
||||
the weights for all the samples.
|
||||
"""
|
||||
# normalize loss_curve and loss_values with ranking
|
||||
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
|
||||
loss_values_norm = (-loss_values).rank(pct=True)
|
||||
|
||||
# calculate l_start and l_end from loss_curve
|
||||
N, T = loss_curve.shape
|
||||
part = np.maximum(int(T * 0.1), 1)
|
||||
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
|
||||
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
|
||||
|
||||
# calculate h-value for each sample
|
||||
h1 = loss_values_norm
|
||||
h2 = (l_end / l_start).rank(pct=True)
|
||||
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
"""
|
||||
the FS module of Double Ensemble
|
||||
:param df_train: the shape is NxF
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:return: res_feat: in the form of pandas.Index
|
||||
|
||||
"""
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
features = x_train.columns
|
||||
N, F = x_train.shape
|
||||
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
|
||||
M = len(self.ensemble)
|
||||
|
||||
# shuffle specific columns and calculate g-value for each feature
|
||||
x_train_tmp = x_train.copy()
|
||||
for i_f, feat in enumerate(features):
|
||||
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
|
||||
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
|
||||
for i_s, submodel in enumerate(self.ensemble):
|
||||
pred += (
|
||||
pd.Series(
|
||||
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
|
||||
)
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
g["g_value"].replace(np.nan, 0, inplace=True)
|
||||
|
||||
# divide features into bins_fs bins
|
||||
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
|
||||
|
||||
# randomly sample features from bins to construct the new features
|
||||
res_feat = []
|
||||
sorted_bins = sorted(g["bins"].unique(), reverse=True)
|
||||
for i_b, b in enumerate(sorted_bins):
|
||||
b_feat = features[g["bins"] == b]
|
||||
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
|
||||
res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist()
|
||||
return pd.Index(res_feat)
|
||||
|
||||
def get_loss(self, label, pred):
|
||||
if self.loss == "mse":
|
||||
return (label - pred) ** 2
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
|
||||
def retrieve_loss_curve(self, model, df_train, features):
|
||||
if self.base_model == "gbm":
|
||||
num_trees = model.num_trees()
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train = np.squeeze(y_train.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
N = x_train.shape[0]
|
||||
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
|
||||
pred_tree = np.zeros(N, dtype=float)
|
||||
for i_tree in range(num_trees):
|
||||
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
|
||||
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
pred += (
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
class DEnsembleModel(Model, FeatureInt):
|
||||
"""Double Ensemble Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model="gbm",
|
||||
loss="mse",
|
||||
num_models=6,
|
||||
enable_sr=True,
|
||||
enable_fs=True,
|
||||
alpha1=1.0,
|
||||
alpha2=1.0,
|
||||
bins_sr=10,
|
||||
bins_fs=5,
|
||||
decay=None,
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
self.enable_sr = enable_sr
|
||||
self.enable_fs = enable_fs
|
||||
self.alpha1 = alpha1
|
||||
self.alpha2 = alpha2
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
if not len(sub_weights) == num_models:
|
||||
raise ValueError("The length of sub_weights should be equal to num_models.")
|
||||
self.sub_weights = sub_weights
|
||||
self.epochs = epochs
|
||||
self.logger = get_module_logger("DEnsembleModel")
|
||||
self.logger.info("Double Ensemble Model...")
|
||||
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
|
||||
self.sub_features = [] # the features for each sub model in the form of pandas.Index
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
weights = pd.Series(np.ones(N, dtype=float))
|
||||
# initialize the features
|
||||
features = x_train.columns
|
||||
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
|
||||
# train sub-models
|
||||
for k in range(self.num_models):
|
||||
self.sub_features.append(features)
|
||||
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
|
||||
model_k = self.train_submodel(df_train, df_valid, weights, features)
|
||||
self.ensemble.append(model_k)
|
||||
# no further sample re-weight and feature selection needed for the last sub-model
|
||||
if k + 1 == self.num_models:
|
||||
break
|
||||
|
||||
self.logger.info("Retrieving loss curve and loss values...")
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
self.logger.info("Sample re-weighting...")
|
||||
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
|
||||
|
||||
if self.enable_fs:
|
||||
self.logger.info("Feature selection...")
|
||||
features = self.feature_selection(df_train, loss_values)
|
||||
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
return model
|
||||
|
||||
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train, label=y_train, weight=weights)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def sample_reweight(self, loss_curve, loss_values, k_th):
|
||||
"""
|
||||
the SR module of Double Ensemble
|
||||
:param loss_curve: the shape is NxT
|
||||
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
|
||||
after the t-th iteration in the training of the previous sub-model.
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:param k_th: the index of the current sub-model, starting from 1
|
||||
:return: weights
|
||||
the weights for all the samples.
|
||||
"""
|
||||
# normalize loss_curve and loss_values with ranking
|
||||
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
|
||||
loss_values_norm = (-loss_values).rank(pct=True)
|
||||
|
||||
# calculate l_start and l_end from loss_curve
|
||||
N, T = loss_curve.shape
|
||||
part = np.maximum(int(T * 0.1), 1)
|
||||
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
|
||||
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
|
||||
|
||||
# calculate h-value for each sample
|
||||
h1 = loss_values_norm
|
||||
h2 = (l_end / l_start).rank(pct=True)
|
||||
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
"""
|
||||
the FS module of Double Ensemble
|
||||
:param df_train: the shape is NxF
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:return: res_feat: in the form of pandas.Index
|
||||
|
||||
"""
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
features = x_train.columns
|
||||
N, F = x_train.shape
|
||||
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
|
||||
M = len(self.ensemble)
|
||||
|
||||
# shuffle specific columns and calculate g-value for each feature
|
||||
x_train_tmp = x_train.copy()
|
||||
for i_f, feat in enumerate(features):
|
||||
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
|
||||
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
|
||||
for i_s, submodel in enumerate(self.ensemble):
|
||||
pred += (
|
||||
pd.Series(
|
||||
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
|
||||
)
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
g["g_value"].replace(np.nan, 0, inplace=True)
|
||||
|
||||
# divide features into bins_fs bins
|
||||
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
|
||||
|
||||
# randomly sample features from bins to construct the new features
|
||||
res_feat = []
|
||||
sorted_bins = sorted(g["bins"].unique(), reverse=True)
|
||||
for i_b, b in enumerate(sorted_bins):
|
||||
b_feat = features[g["bins"] == b]
|
||||
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
|
||||
res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist()
|
||||
return pd.Index(set(res_feat))
|
||||
|
||||
def get_loss(self, label, pred):
|
||||
if self.loss == "mse":
|
||||
return (label - pred) ** 2
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
|
||||
def retrieve_loss_curve(self, model, df_train, features):
|
||||
if self.base_model == "gbm":
|
||||
num_trees = model.num_trees()
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train = np.squeeze(y_train.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
N = x_train.shape[0]
|
||||
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
|
||||
pred_tree = np.zeros(N, dtype=float)
|
||||
for i_tree in range(num_trees):
|
||||
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
|
||||
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
pred += (
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-----
|
||||
parameters reference:
|
||||
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
|
||||
"""
|
||||
res = []
|
||||
for _model, _weight in zip(self.ensemble, self.sub_weights):
|
||||
res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight)
|
||||
return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False)
|
||||
|
||||
@@ -8,9 +8,10 @@ from typing import Text, Union
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
|
||||
|
||||
class LGBModel(ModelFT):
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
@@ -33,8 +34,8 @@ class LGBModel(ModelFT):
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
dtrain = lgb.Dataset(x_train, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
import warnings
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT):
|
||||
class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model for high frequency prediction"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
@@ -97,8 +98,8 @@ class HFLGBModel(ModelFT):
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
dtrain = lgb.Dataset(x_train, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
|
||||
@@ -8,9 +8,10 @@ from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import FeatureInt
|
||||
|
||||
|
||||
class XGBModel(Model):
|
||||
class XGBModel(Model, FeatureInt):
|
||||
"""XGBModel Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -42,8 +43,8 @@ class XGBModel(Model):
|
||||
else:
|
||||
raise ValueError("XGBoost doesn't support multi-label training")
|
||||
|
||||
dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
|
||||
dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
|
||||
dtrain = xgb.DMatrix(x_train, label=y_train_1d)
|
||||
dvalid = xgb.DMatrix(x_valid, label=y_valid_1d)
|
||||
self.model = xgb.train(
|
||||
self._params,
|
||||
dtrain=dtrain,
|
||||
@@ -62,3 +63,13 @@ class XGBModel(Model):
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-------
|
||||
parameters reference:
|
||||
https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score
|
||||
"""
|
||||
return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False)
|
||||
|
||||
@@ -6,7 +6,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
@@ -27,12 +29,41 @@ from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
class ProviderBackendMixin:
|
||||
def get_default_backend(self):
|
||||
backend = {}
|
||||
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
|
||||
# set default storage class
|
||||
backend.setdefault("class", f"File{provider_name}Storage")
|
||||
# set default storage module
|
||||
backend.setdefault("module_path", "qlib.data.storage.file_storage")
|
||||
return backend
|
||||
|
||||
def backend_obj(self, **kwargs):
|
||||
backend = self.backend if self.backend else self.get_default_backend()
|
||||
backend = copy.deepcopy(backend)
|
||||
|
||||
# set default storage kwargs
|
||||
backend_kwargs = backend.setdefault("kwargs", {})
|
||||
# default provider_uri map
|
||||
if "provider_uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
freq = kwargs.get("freq", "day")
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {freq: C.get_data_path()})
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Calendar provider base class
|
||||
|
||||
Provide calendar data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
@@ -127,12 +158,15 @@ class CalendarProvider(abc.ABC):
|
||||
return hash_args(start_time, end_time, freq, future)
|
||||
|
||||
|
||||
class InstrumentProvider(abc.ABC):
|
||||
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Instrument provider base class
|
||||
|
||||
Provide instrument data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@staticmethod
|
||||
def instruments(market="all", filter_pipe=None):
|
||||
"""Get the general config dictionary for a base market adding several dynamic filters.
|
||||
@@ -215,12 +249,15 @@ class InstrumentProvider(abc.ABC):
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC):
|
||||
class FeatureProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Feature provider class
|
||||
|
||||
Provide feature data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature(self, instrument, field, start_time, end_time, freq):
|
||||
"""Get feature data.
|
||||
@@ -497,6 +534,7 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -517,21 +555,22 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
if future:
|
||||
fname = self._uri_cal.format(freq + "_future")
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
|
||||
try:
|
||||
backend_obj = self.backend_obj(freq=freq, future=future).data
|
||||
except ValueError:
|
||||
if future:
|
||||
get_module_logger("data").warning(
|
||||
f"load calendar error: freq={freq}, future={future}; return current calendar!"
|
||||
)
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("calendar not exists for freq " + freq)
|
||||
with open(fname) as f:
|
||||
return [pd.Timestamp(x.strip()) for x in f]
|
||||
backend_obj = self.backend_obj(freq=freq, future=False).data
|
||||
else:
|
||||
raise
|
||||
|
||||
return [pd.Timestamp(x) for x in backend_obj]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
_calendar, _calendar_index = self._get_calendar(freq, future)
|
||||
@@ -562,38 +601,20 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def _uri_inst(self):
|
||||
"""Instrument file uri."""
|
||||
return os.path.join(C.get_data_path(), "instruments", "{}.txt")
|
||||
|
||||
def _load_instruments(self, market):
|
||||
fname = self._uri_inst.format(market)
|
||||
if not os.path.exists(fname):
|
||||
raise ValueError("instruments not exists for market " + market)
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(
|
||||
fname,
|
||||
sep="\t",
|
||||
usecols=[0, 1, 2],
|
||||
names=["inst", "start_datetime", "end_datetime"],
|
||||
dtype={"inst": str},
|
||||
parse_dates=["start_datetime", "end_datetime"],
|
||||
)
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
market = instruments["market"]
|
||||
if market in H["i"]:
|
||||
_instruments = H["i"][market]
|
||||
else:
|
||||
_instruments = self._load_instruments(market)
|
||||
_instruments = self._load_instruments(market, freq=freq)
|
||||
H["i"][market] = _instruments
|
||||
# strip
|
||||
# use calendar boundary
|
||||
@@ -604,7 +625,7 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
inst: list(
|
||||
filter(
|
||||
lambda x: x[0] <= x[1],
|
||||
[(max(start_time, x[0]), min(end_time, x[1])) for x in spans],
|
||||
[(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans],
|
||||
)
|
||||
)
|
||||
for inst, spans in _instruments.items()
|
||||
@@ -630,6 +651,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
|
||||
@property
|
||||
@@ -641,14 +663,7 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
# validate
|
||||
field = str(field).lower()[1:]
|
||||
instrument = code_to_fname(instrument)
|
||||
uri_data = self._uri_data.format(instrument.lower(), field, freq)
|
||||
if not os.path.exists(uri_data):
|
||||
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
|
||||
return pd.Series(dtype=np.float32)
|
||||
# raise ValueError('uri_data not found: ' + uri_data)
|
||||
# load
|
||||
series = read_bin(uri_data, start_index, end_index)
|
||||
return series
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
@@ -1065,7 +1080,8 @@ def register_all_wrappers(C):
|
||||
register_wrapper(Cal, _calendar_provider, "qlib.data")
|
||||
logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}")
|
||||
|
||||
register_wrapper(Inst, C.instrument_provider, "qlib.data")
|
||||
_instrument_provider = init_instance_by_config(C.instrument_provider, module)
|
||||
register_wrapper(Inst, _instrument_provider, "qlib.data")
|
||||
logger.debug(f"registering Inst {C.instrument_provider}")
|
||||
|
||||
if getattr(C, "feature_provider", None) is not None:
|
||||
|
||||
@@ -357,7 +357,7 @@ class TSDataSampler:
|
||||
# get the previous index of a line given index
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
idx_df = lazy_sort_index(idx_df.unstack())
|
||||
# NOTE: the correctness of `__getitem__` depends on columns sorted here
|
||||
idx_df = lazy_sort_index(idx_df, axis=1)
|
||||
|
||||
4
qlib/data/storage/__init__.py
Normal file
4
qlib/data/storage/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
|
||||
292
qlib/data/storage/file_storage.py
Normal file
292
qlib/data/storage/file_storage.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Union, Dict, Mapping, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
|
||||
logger = get_module_logger("file_storage")
|
||||
|
||||
|
||||
class FileStorageMixin:
|
||||
@property
|
||||
def uri(self) -> Path:
|
||||
_provider_uri = self.kwargs.get("provider_uri", None)
|
||||
if _provider_uri is None:
|
||||
raise ValueError(
|
||||
f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
|
||||
f'please specify `provider_uri` in the "provider\'s backend"'
|
||||
)
|
||||
return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name)
|
||||
|
||||
def check(self):
|
||||
"""check self.uri
|
||||
|
||||
Raises
|
||||
-------
|
||||
ValueError
|
||||
"""
|
||||
if not self.uri.exists():
|
||||
raise ValueError(f"{self.storage_name} not exists: {self.uri}")
|
||||
|
||||
|
||||
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
def __init__(self, freq: str, future: bool, **kwargs):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
|
||||
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
if not self.uri.exists():
|
||||
self._write_calendar(values=[])
|
||||
with self.uri.open("rb") as fp:
|
||||
return [
|
||||
str(x)
|
||||
for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, delimiter="\n", encoding="utf-8")
|
||||
]
|
||||
|
||||
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
np.savetxt(fp, values, fmt="%s", encoding="utf-8")
|
||||
|
||||
@property
|
||||
def data(self) -> List[CalVT]:
|
||||
self.check()
|
||||
return self._read_calendar()
|
||||
|
||||
def extend(self, values: Iterable[CalVT]) -> None:
|
||||
self._write_calendar(values, mode="ab")
|
||||
|
||||
def clear(self) -> None:
|
||||
self._write_calendar(values=[])
|
||||
|
||||
def index(self, value: CalVT) -> int:
|
||||
self.check()
|
||||
calendar = self._read_calendar()
|
||||
return int(np.argwhere(calendar == value)[0])
|
||||
|
||||
def insert(self, index: int, value: CalVT):
|
||||
calendar = self._read_calendar()
|
||||
calendar = np.insert(calendar, index, value)
|
||||
self._write_calendar(values=calendar)
|
||||
|
||||
def remove(self, value: CalVT) -> None:
|
||||
self.check()
|
||||
index = self.index(value)
|
||||
calendar = self._read_calendar()
|
||||
calendar = np.delete(calendar, index)
|
||||
self._write_calendar(values=calendar)
|
||||
|
||||
def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None:
|
||||
calendar = self._read_calendar()
|
||||
calendar[i] = values
|
||||
self._write_calendar(values=calendar)
|
||||
|
||||
def __delitem__(self, i: Union[int, slice]) -> None:
|
||||
self.check()
|
||||
calendar = self._read_calendar()
|
||||
calendar = np.delete(calendar, i)
|
||||
self._write_calendar(values=calendar)
|
||||
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
|
||||
self.check()
|
||||
return self._read_calendar()[i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
|
||||
|
||||
INSTRUMENT_SEP = "\t"
|
||||
INSTRUMENT_START_FIELD = "start_datetime"
|
||||
INSTRUMENT_END_FIELD = "end_datetime"
|
||||
SYMBOL_FIELD_NAME = "instrument"
|
||||
|
||||
def __init__(self, market: str, **kwargs):
|
||||
super(FileInstrumentStorage, self).__init__(market, **kwargs)
|
||||
self.file_name = f"{market.lower()}.txt"
|
||||
|
||||
def _read_instrument(self) -> Dict[InstKT, InstVT]:
|
||||
if not self.uri.exists():
|
||||
self._write_instrument()
|
||||
|
||||
_instruments = dict()
|
||||
df = pd.read_csv(
|
||||
self.uri,
|
||||
sep="\t",
|
||||
usecols=[0, 1, 2],
|
||||
names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
|
||||
dtype={self.SYMBOL_FIELD_NAME: str},
|
||||
parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD],
|
||||
)
|
||||
for row in df.itertuples(index=False):
|
||||
_instruments.setdefault(row[0], []).append((row[1], row[2]))
|
||||
return _instruments
|
||||
|
||||
def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None:
|
||||
if not data:
|
||||
with self.uri.open("w") as _:
|
||||
pass
|
||||
return
|
||||
|
||||
res = []
|
||||
for inst, v_list in data.items():
|
||||
_df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD])
|
||||
_df[self.SYMBOL_FIELD_NAME] = inst
|
||||
res.append(_df)
|
||||
|
||||
df = pd.concat(res, sort=False)
|
||||
df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv(
|
||||
self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False
|
||||
)
|
||||
df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._write_instrument(data={})
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[InstKT, InstVT]:
|
||||
self.check()
|
||||
return self._read_instrument()
|
||||
|
||||
def __setitem__(self, k: InstKT, v: InstVT) -> None:
|
||||
inst = self._read_instrument()
|
||||
inst[k] = v
|
||||
self._write_instrument(inst)
|
||||
|
||||
def __delitem__(self, k: InstKT) -> None:
|
||||
self.check()
|
||||
inst = self._read_instrument()
|
||||
del inst[k]
|
||||
self._write_instrument(inst)
|
||||
|
||||
def __getitem__(self, k: InstKT) -> InstVT:
|
||||
self.check()
|
||||
return self._read_instrument()[k]
|
||||
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
|
||||
if len(args) > 1:
|
||||
raise TypeError(f"update expected at most 1 arguments, got {len(args)}")
|
||||
inst = self._read_instrument()
|
||||
if args:
|
||||
other = args[0] # type: dict
|
||||
if isinstance(other, Mapping):
|
||||
for key in other:
|
||||
inst[key] = other[key]
|
||||
elif hasattr(other, "keys"):
|
||||
for key in other.keys():
|
||||
inst[key] = other[key]
|
||||
else:
|
||||
for key, value in other:
|
||||
inst[key] = value
|
||||
for key, value in kwargs.items():
|
||||
inst[key] = value
|
||||
|
||||
self._write_instrument(inst)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
|
||||
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
|
||||
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
|
||||
|
||||
def clear(self):
|
||||
with self.uri.open("wb") as _:
|
||||
pass
|
||||
|
||||
@property
|
||||
def data(self) -> pd.Series:
|
||||
return self[:]
|
||||
|
||||
def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None:
|
||||
if len(data_array) == 0:
|
||||
logger.info(
|
||||
"len(data_array) == 0, write"
|
||||
"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
|
||||
)
|
||||
return
|
||||
if not self.uri.exists():
|
||||
# write
|
||||
index = 0 if index is None else index
|
||||
with self.uri.open("wb") as fp:
|
||||
np.hstack([index, data_array]).astype("<f").tofile(fp)
|
||||
else:
|
||||
if index is None or index > self.end_index:
|
||||
# append
|
||||
index = 0 if index is None else index
|
||||
with self.uri.open("ab+") as fp:
|
||||
np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype("<f").tofile(fp)
|
||||
else:
|
||||
# rewrite
|
||||
with self.uri.open("rb+") as fp:
|
||||
_old_data = np.fromfile(fp, dtype="<f")
|
||||
_old_index = _old_data[0]
|
||||
_old_df = pd.DataFrame(
|
||||
_old_data[1:], index=range(_old_index, _old_index + len(_old_data) - 1), columns=["old"]
|
||||
)
|
||||
fp.seek(0)
|
||||
_new_df = pd.DataFrame(data_array, index=range(index, index + len(data_array)), columns=["new"])
|
||||
_df = pd.concat([_old_df, _new_df], sort=False, axis=1)
|
||||
_df = _df.reindex(range(_df.index.min(), _df.index.max() + 1))
|
||||
_df["new"].fillna(_df["old"]).values.astype("<f").tofile(fp)
|
||||
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
if not self.uri.exists():
|
||||
return None
|
||||
with self.uri.open("rb") as fp:
|
||||
index = int(np.frombuffer(fp.read(4), dtype="<f")[0])
|
||||
return index
|
||||
|
||||
@property
|
||||
def end_index(self) -> Union[int, None]:
|
||||
if not self.uri.exists():
|
||||
return None
|
||||
# The next data appending index point will be `end_index + 1`
|
||||
return self.start_index + len(self) - 1
|
||||
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
|
||||
if not self.uri.exists():
|
||||
if isinstance(i, int):
|
||||
return None, None
|
||||
elif isinstance(i, slice):
|
||||
return pd.Series(dtype=np.float32)
|
||||
else:
|
||||
raise TypeError(f"type(i) = {type(i)}")
|
||||
|
||||
storage_start_index = self.start_index
|
||||
storage_end_index = self.end_index
|
||||
with self.uri.open("rb") as fp:
|
||||
if isinstance(i, int):
|
||||
|
||||
if storage_start_index > i:
|
||||
raise IndexError(f"{i}: start index is {storage_start_index}")
|
||||
fp.seek(4 * (i - storage_start_index) + 4)
|
||||
return i, struct.unpack("f", fp.read(4))[0]
|
||||
elif isinstance(i, slice):
|
||||
start_index = storage_start_index if i.start is None else i.start
|
||||
end_index = storage_end_index if i.stop is None else i.stop - 1
|
||||
si = max(start_index, storage_start_index)
|
||||
if si > end_index:
|
||||
return pd.Series(dtype=np.float32)
|
||||
fp.seek(4 * (si - storage_start_index) + 4)
|
||||
# read n bytes
|
||||
count = end_index - si + 1
|
||||
data = np.frombuffer(fp.read(4 * count), dtype="<f")
|
||||
return pd.Series(data, index=pd.RangeIndex(si, si + len(data)))
|
||||
else:
|
||||
raise TypeError(f"type(i) = {type(i)}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.check()
|
||||
return self.uri.stat().st_size // 4 - 1
|
||||
501
qlib/data/storage/storage.py
Normal file
501
qlib/data/storage/storage.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
from typing import Iterable, overload, Tuple, List, Text, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
# calendar value type
|
||||
CalVT = str
|
||||
|
||||
# instrument value
|
||||
InstVT = List[Tuple[CalVT, CalVT]]
|
||||
# instrument key
|
||||
InstKT = Text
|
||||
|
||||
logger = get_module_logger("storage")
|
||||
|
||||
"""
|
||||
If the user is only using it in `qlib`, you can customize Storage to implement only the following methods:
|
||||
|
||||
class UserCalendarStorage(CalendarStorage):
|
||||
|
||||
@property
|
||||
def data(self) -> Iterable[CalVT]:
|
||||
'''get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
'''
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
|
||||
|
||||
|
||||
class UserInstrumentStorage(InstrumentStorage):
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[InstKT, InstVT]:
|
||||
'''get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
'''
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
|
||||
|
||||
|
||||
class UserFeatureStorage(FeatureStorage):
|
||||
|
||||
def __getitem__(self, s: slice) -> pd.Series:
|
||||
'''x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Series(values, index=pd.RangeIndex(start, len(values))
|
||||
|
||||
Notes
|
||||
-------
|
||||
if data(storage) does not exist:
|
||||
if isinstance(i, int):
|
||||
return (None, None)
|
||||
if isinstance(i, slice):
|
||||
# return empty pd.Series
|
||||
return pd.Series(dtype=np.float32)
|
||||
'''
|
||||
raise NotImplementedError(
|
||||
"Subclass of FeatureStorage must implement `__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class BaseStorage:
|
||||
@property
|
||||
def storage_name(self) -> str:
|
||||
return re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2].lower()
|
||||
|
||||
|
||||
class CalendarStorage(BaseStorage):
|
||||
"""
|
||||
The behavior of CalendarStorage's methods and List's methods of the same name remain consistent
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str, future: bool, **kwargs):
|
||||
self.freq = freq
|
||||
self.future = future
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def data(self) -> Iterable[CalVT]:
|
||||
"""get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `data` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method")
|
||||
|
||||
def extend(self, iterable: Iterable[CalVT]) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method")
|
||||
|
||||
def index(self, value: CalVT) -> int:
|
||||
"""
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `index` method")
|
||||
|
||||
def insert(self, index: int, value: CalVT) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method")
|
||||
|
||||
def remove(self, value: CalVT) -> None:
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method")
|
||||
|
||||
@overload
|
||||
def __setitem__(self, i: int, value: CalVT) -> None:
|
||||
"""x.__setitem__(i, o) <==> (x[i] = o)"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __setitem__(self, s: slice, value: Iterable[CalVT]) -> None:
|
||||
"""x.__setitem__(s, o) <==> (x[s] = o)"""
|
||||
...
|
||||
|
||||
def __setitem__(self, i, value) -> None:
|
||||
raise NotImplementedError(
|
||||
"Subclass of CalendarStorage must implement `__setitem__(i: int, o: CalVT)`/`__setitem__(s: slice, o: Iterable[CalVT])` method"
|
||||
)
|
||||
|
||||
@overload
|
||||
def __delitem__(self, i: int) -> None:
|
||||
"""x.__delitem__(i) <==> del x[i]"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __delitem__(self, i: slice) -> None:
|
||||
"""x.__delitem__(slice(start: int, stop: int, step: int)) <==> del x[start:stop:step]"""
|
||||
...
|
||||
|
||||
def __delitem__(self, i) -> None:
|
||||
"""
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Subclass of CalendarStorage must implement `__delitem__(i: int)`/`__delitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice) -> Iterable[CalVT]:
|
||||
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: int) -> CalVT:
|
||||
"""x.__getitem__(i) <==> x[i]"""
|
||||
...
|
||||
|
||||
def __getitem__(self, i) -> CalVT:
|
||||
"""
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Subclass of CalendarStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method")
|
||||
|
||||
|
||||
class InstrumentStorage(BaseStorage):
|
||||
def __init__(self, market: str, **kwargs):
|
||||
self.market = market
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[InstKT, InstVT]:
|
||||
"""get all data
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method")
|
||||
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.
|
||||
|
||||
Notes
|
||||
------
|
||||
If E present and has a .keys() method, does: for k in E: D[k] = E[k]
|
||||
|
||||
If E present and lacks .keys() method, does: for (k, v) in E: D[k] = v
|
||||
|
||||
In either case, this is followed by: for k, v in F.items(): D[k] = v
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method")
|
||||
|
||||
def __setitem__(self, k: InstKT, v: InstVT) -> None:
|
||||
"""Set self[key] to value."""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method")
|
||||
|
||||
def __delitem__(self, k: InstKT) -> None:
|
||||
"""Delete self[key].
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method")
|
||||
|
||||
def __getitem__(self, k: InstKT) -> InstVT:
|
||||
"""x.__getitem__(k) <==> x[k]"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method")
|
||||
|
||||
|
||||
class FeatureStorage(BaseStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
|
||||
self.instrument = instrument
|
||||
self.field = field
|
||||
self.freq = freq
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
def data(self) -> pd.Series:
|
||||
"""get all data
|
||||
|
||||
Notes
|
||||
------
|
||||
if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)`
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `data` method")
|
||||
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
"""get FeatureStorage start index
|
||||
|
||||
Notes
|
||||
-----
|
||||
If the data(storage) does not exist, return None
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method")
|
||||
|
||||
@property
|
||||
def end_index(self) -> Union[int, None]:
|
||||
"""get FeatureStorage end index
|
||||
|
||||
Notes
|
||||
-----
|
||||
The right index of the data range (both sides are closed)
|
||||
|
||||
The next data appending point will be `end_index + 1`
|
||||
|
||||
If the data(storage) does not exist, return None
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method")
|
||||
|
||||
def clear(self) -> None:
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method")
|
||||
|
||||
def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None):
|
||||
"""Write data_array to FeatureStorage starting from index.
|
||||
|
||||
Notes
|
||||
------
|
||||
If index is None, append data_array to feature.
|
||||
|
||||
If len(data_array) == 0; return
|
||||
|
||||
If (index - self.end_index) >= 1, self[end_index+1: index] will be filled with np.nan
|
||||
|
||||
Examples
|
||||
---------
|
||||
.. code-block::
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
|
||||
|
||||
>>> self.write([6, 7], index=6)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
|
||||
>>> self.write([8], index=9)
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
>>> self.write([1, np.nan], index=3)
|
||||
|
||||
feature:
|
||||
3 1
|
||||
4 np.nan
|
||||
5 6
|
||||
6 6
|
||||
7 7
|
||||
8 np.nan
|
||||
9 8
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `write` method")
|
||||
|
||||
def rebase(self, start_index: int = None, end_index: int = None):
|
||||
"""Rebase the start_index and end_index of the FeatureStorage.
|
||||
|
||||
start_index and end_index are closed intervals: [start_index, end_index]
|
||||
|
||||
Examples
|
||||
---------
|
||||
|
||||
.. code-block::
|
||||
|
||||
feature:
|
||||
3 4
|
||||
4 5
|
||||
5 6
|
||||
|
||||
|
||||
>>> self.rebase(start_index=4)
|
||||
|
||||
feature:
|
||||
4 5
|
||||
5 6
|
||||
|
||||
>>> self.rebase(start_index=3)
|
||||
|
||||
feature:
|
||||
3 np.nan
|
||||
4 5
|
||||
5 6
|
||||
|
||||
>>> self.write([3], index=3)
|
||||
|
||||
feature:
|
||||
3 3
|
||||
4 5
|
||||
5 6
|
||||
|
||||
>>> self.rebase(end_index=4)
|
||||
|
||||
feature:
|
||||
3 3
|
||||
4 5
|
||||
|
||||
>>> self.write([6, 7, 8], index=4)
|
||||
|
||||
feature:
|
||||
3 3
|
||||
4 6
|
||||
5 7
|
||||
6 8
|
||||
|
||||
>>> self.rebase(start_index=4, end_index=5)
|
||||
|
||||
feature:
|
||||
4 6
|
||||
5 7
|
||||
|
||||
"""
|
||||
storage_si = self.start_index
|
||||
storage_ei = self.end_index
|
||||
if storage_si is None or storage_ei is None:
|
||||
raise ValueError("storage.start_index or storage.end_index is None, storage may not exist")
|
||||
|
||||
start_index = storage_si if start_index is None else start_index
|
||||
end_index = storage_ei if end_index is None else end_index
|
||||
|
||||
if start_index is None or end_index is None:
|
||||
logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored")
|
||||
return
|
||||
|
||||
if start_index < 0 or end_index < 0:
|
||||
logger.warning("start_index or end_index cannot be less than 0")
|
||||
return
|
||||
if start_index > end_index:
|
||||
logger.warning(
|
||||
f"start_index({start_index}) > end_index({end_index}), rebase is ignored; "
|
||||
f"if you need to clear the FeatureStorage, please execute: FeatureStorage.clear"
|
||||
)
|
||||
return
|
||||
|
||||
if start_index <= storage_si:
|
||||
self.write([np.nan] * (storage_si - start_index), start_index)
|
||||
else:
|
||||
self.rewrite(self[start_index:].values, start_index)
|
||||
|
||||
if end_index >= self.end_index:
|
||||
self.write([np.nan] * (end_index - self.end_index))
|
||||
else:
|
||||
self.rewrite(self[: end_index + 1].values, start_index)
|
||||
|
||||
def rewrite(self, data: Union[List, np.ndarray, Tuple], index: int):
|
||||
"""overwrite all data in FeatureStorage with data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: Union[List, np.ndarray, Tuple]
|
||||
data
|
||||
index: int
|
||||
data start index
|
||||
"""
|
||||
self.clear()
|
||||
self.write(data, index)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice) -> pd.Series:
|
||||
"""x.__getitem__(slice(start: int, stop: int, step: int)) <==> x[start:stop:step]
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Series(values, index=pd.RangeIndex(start, len(values))
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, i: int) -> Tuple[int, float]:
|
||||
"""x.__getitem__(y) <==> x[y]"""
|
||||
...
|
||||
|
||||
def __getitem__(self, i) -> Union[Tuple[int, float], pd.Series]:
|
||||
"""x.__getitem__(y) <==> x[y]
|
||||
|
||||
Notes
|
||||
-------
|
||||
if data(storage) does not exist:
|
||||
if isinstance(i, int):
|
||||
return (None, None)
|
||||
if isinstance(i, slice):
|
||||
# return empty pd.Series
|
||||
return pd.Series(dtype=np.float32)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Subclass of FeatureStorage must implement `__getitem__(i: int)`/`__getitem__(s: slice)` method"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data(storage) does not exist, raise ValueError
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method")
|
||||
0
qlib/model/interpret/__init__.py
Normal file
0
qlib/model/interpret/__init__.py
Normal file
40
qlib/model/interpret/base.py
Normal file
40
qlib/model/interpret/base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Interfaces to interpret models
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class FeatureInt:
|
||||
"""Feature (Int)erpreter"""
|
||||
|
||||
@abstractmethod
|
||||
def get_feature_importance(self) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Returns
|
||||
-------
|
||||
The index is the feature name.
|
||||
|
||||
The greater the value, the higher importance.
|
||||
"""
|
||||
|
||||
|
||||
class LightGBMFInt(FeatureInt):
|
||||
"""LightGBM (F)eature (Int)erpreter"""
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
Notes
|
||||
-----
|
||||
parameters reference:
|
||||
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
|
||||
"""
|
||||
return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values(
|
||||
ascending=False
|
||||
)
|
||||
@@ -1,6 +1,4 @@
|
||||
import sys
|
||||
import unittest
|
||||
from ..utils import exists_qlib_data
|
||||
from .data import GetData
|
||||
from .. import init
|
||||
from ..config import REG_CN
|
||||
@@ -9,19 +7,18 @@ from ..config import REG_CN
|
||||
class TestAutoData(unittest.TestCase):
|
||||
|
||||
_setup_kwargs = {}
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region=REG_CN,
|
||||
interval="1d",
|
||||
target_dir=cls.provider_uri,
|
||||
delete_old=False,
|
||||
exists_skip=True,
|
||||
)
|
||||
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
108
qlib/tests/config.py
Normal file
108
qlib/tests/config.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
CSI300_MARKET = "csi300"
|
||||
CSI100_MARKET = "csi100"
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
|
||||
DATASET_ALPHA158_CLASS = "Alpha158"
|
||||
DATASET_ALPHA360_CLASS = "Alpha360"
|
||||
|
||||
###################################
|
||||
# config
|
||||
###################################
|
||||
|
||||
|
||||
GBDT_MODEL = {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
RECORD_CONFIG = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
@@ -10,6 +10,7 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
@@ -112,6 +113,7 @@ class GetData:
|
||||
interval="1d",
|
||||
region="cn",
|
||||
delete_old=True,
|
||||
exists_skip=False,
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
@@ -129,6 +131,8 @@ class GetData:
|
||||
data region, value from [cn, us], by default cn
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
exists_skip: bool
|
||||
exists skip, by default False
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -140,6 +144,13 @@ class GetData:
|
||||
-------
|
||||
|
||||
"""
|
||||
if exists_skip and exists_qlib_data(target_dir):
|
||||
logger.warning(
|
||||
f"Data already exists: {target_dir}, the data download will be skipped\n"
|
||||
f"\tIf downloading is required: `exists_skip=False` or `change target_dir`"
|
||||
)
|
||||
return
|
||||
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name(v):
|
||||
|
||||
@@ -668,7 +668,10 @@ def exists_qlib_data(qlib_dir):
|
||||
return False
|
||||
# check calendar bin
|
||||
for _calendar in calendars_dir.iterdir():
|
||||
if not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin")):
|
||||
|
||||
if ("_future" not in _calendar.name) and (
|
||||
not list(features_dir.rglob(f"*.{_calendar.name.split('.')[0]}.bin"))
|
||||
):
|
||||
return False
|
||||
|
||||
# check instruments
|
||||
|
||||
@@ -121,7 +121,7 @@ df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
### Help
|
||||
```bash
|
||||
pythono collector.py collector_data --help
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
@@ -191,7 +191,7 @@ class YahooCollector(BaseCollector):
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get HS stock symbos......")
|
||||
logger.info("get HS stock symbols......")
|
||||
symbols = get_hs_stock_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
@@ -581,7 +581,6 @@ class Run(BaseRun):
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
@@ -593,8 +592,6 @@ class Run(BaseRun):
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
@@ -611,8 +608,9 @@ class Run(BaseRun):
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
super(Run, self).download_data(
|
||||
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
|
||||
)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
@@ -5,5 +5,4 @@ numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
|
||||
@@ -120,7 +120,7 @@ class DumpDataBase:
|
||||
else:
|
||||
df = file_or_df
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
_calendars = pd.Series()
|
||||
_calendars = pd.Series(dtype=np.float32)
|
||||
else:
|
||||
_calendars = df[self.date_field_name]
|
||||
|
||||
|
||||
@@ -1,26 +1,10 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
import unittest
|
||||
import numpy as np
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.data import D
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class TestDataset(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
class TestDataset(TestAutoData):
|
||||
def testCSI300(self):
|
||||
close_p = D.features(D.instruments("csi300"), ["$close"])
|
||||
size = close_p.groupby("datetime").size()
|
||||
|
||||
171
tests/storage_tests/test_storage.py
Normal file
171
tests/storage_tests/test_storage.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from collections.abc import Iterable
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data.storage.file_storage import (
|
||||
FileCalendarStorage as CalendarStorage,
|
||||
FileInstrumentStorage as InstrumentStorage,
|
||||
FileFeatureStorage as FeatureStorage,
|
||||
)
|
||||
|
||||
_file_name = Path(__file__).name.split(".")[0]
|
||||
DATA_DIR = Path(__file__).parent.joinpath(f"{_file_name}_data")
|
||||
QLIB_DIR = DATA_DIR.joinpath("qlib")
|
||||
QLIB_DIR.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
class TestStorage(TestAutoData):
|
||||
def test_calendar_storage(self):
|
||||
|
||||
calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri)
|
||||
assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
|
||||
assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
|
||||
|
||||
print(f"calendar[1: 5]: {calendar[1:5]}")
|
||||
print(f"calendar[0]: {calendar[0]}")
|
||||
print(f"calendar[-1]: {calendar[-1]}")
|
||||
|
||||
calendar = CalendarStorage(freq="1min", future=False, provider_uri="not_found")
|
||||
with pytest.raises(ValueError):
|
||||
print(calendar.data)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
print(calendar[:])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
print(calendar[0])
|
||||
|
||||
def test_instrument_storage(self):
|
||||
"""
|
||||
The meaning of instrument, such as CSI500:
|
||||
|
||||
CSI500 composition changes:
|
||||
|
||||
date add remove
|
||||
2005-01-01 SH600000
|
||||
2005-01-01 SH600001
|
||||
2005-01-01 SH600002
|
||||
2005-02-01 SH600003 SH600000
|
||||
2005-02-15 SH600000 SH600002
|
||||
|
||||
Calendar:
|
||||
pd.date_range(start="2020-01-01", stop="2020-03-01", freq="1D")
|
||||
|
||||
Instrument:
|
||||
symbol start_time end_time
|
||||
SH600000 2005-01-01 2005-01-31 (2005-02-01 Last trading day)
|
||||
SH600000 2005-02-15 2005-03-01
|
||||
SH600001 2005-01-01 2005-03-01
|
||||
SH600002 2005-01-01 2005-02-14 (2005-02-15 Last trading day)
|
||||
SH600003 2005-02-01 2005-03-01
|
||||
|
||||
InstrumentStorage:
|
||||
{
|
||||
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
|
||||
"SH600001": [(2005-01-01, 2005-03-01)],
|
||||
"SH600002": [(2005-01-01, 2005-02-14)],
|
||||
"SH600003": [(2005-02-01, 2005-03-01)],
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri)
|
||||
|
||||
for inst, spans in instrument.data.items():
|
||||
assert isinstance(inst, str) and isinstance(
|
||||
spans, Iterable
|
||||
), f"{instrument.__class__.__name__} value is not Iterable"
|
||||
for s_e in spans:
|
||||
assert (
|
||||
isinstance(s_e, tuple) and len(s_e) == 2
|
||||
), f"{instrument.__class__.__name__}.__getitem__(k) TypeError"
|
||||
|
||||
print(f"instrument['SH600000']: {instrument['SH600000']}")
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
|
||||
with pytest.raises(ValueError):
|
||||
print(instrument.data)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
print(instrument["sSH600000"])
|
||||
|
||||
def test_feature_storage(self):
|
||||
"""
|
||||
Calendar:
|
||||
pd.date_range(start="2005-01-01", stop="2005-03-01", freq="1D")
|
||||
|
||||
Instrument:
|
||||
{
|
||||
"SH600000": [(2005-01-01, 2005-01-31), (2005-02-15, 2005-03-01)],
|
||||
"SH600001": [(2005-01-01, 2005-03-01)],
|
||||
"SH600002": [(2005-01-01, 2005-02-14)],
|
||||
"SH600003": [(2005-02-01, 2005-03-01)],
|
||||
}
|
||||
|
||||
Feature:
|
||||
Stock data(close):
|
||||
2005-01-01 ... 2005-02-01 ... 2005-02-14 2005-02-15 ... 2005-03-01
|
||||
SH600000 1 ... 3 ... 4 5 6
|
||||
SH600001 1 ... 4 ... 5 6 7
|
||||
SH600002 1 ... 5 ... 6 nan nan
|
||||
SH600003 nan ... 1 ... 2 3 4
|
||||
|
||||
FeatureStorage(SH600000, close):
|
||||
|
||||
[
|
||||
(calendar.index("2005-01-01"), 1),
|
||||
...,
|
||||
(calendar.index("2005-03-01"), 6)
|
||||
]
|
||||
|
||||
====> [(0, 1), ..., (59, 6)]
|
||||
|
||||
|
||||
FeatureStorage(SH600002, close):
|
||||
|
||||
[
|
||||
(calendar.index("2005-01-01"), 1),
|
||||
...,
|
||||
(calendar.index("2005-02-14"), 6)
|
||||
]
|
||||
|
||||
===> [(0, 1), ..., (44, 6)]
|
||||
|
||||
FeatureStorage(SH600003, close):
|
||||
|
||||
[
|
||||
(calendar.index("2005-02-01"), 1),
|
||||
...,
|
||||
(calendar.index("2005-03-01"), 4)
|
||||
]
|
||||
|
||||
===> [(31, 1), ..., (59, 4)]
|
||||
|
||||
"""
|
||||
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri=self.provider_uri)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
print(feature[0])
|
||||
assert isinstance(
|
||||
feature[815][1], (float, np.float32)
|
||||
), f"{feature.__class__.__name__}.__getitem__(i: int) error"
|
||||
assert len(feature[815:818]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
|
||||
print(f"feature[815: 818]: \n{feature[815: 818]}")
|
||||
|
||||
print(f"feature[:].tail(): \n{feature[:].tail()}")
|
||||
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")
|
||||
|
||||
assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`"
|
||||
assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`"
|
||||
assert (
|
||||
feature.data.empty
|
||||
), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"
|
||||
@@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
@@ -75,7 +27,7 @@ port_analysis_config = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -96,15 +48,15 @@ def train():
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
@@ -137,12 +89,12 @@ def train_with_sigana():
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
# predict and calculate ic and ric
|
||||
@@ -171,7 +123,7 @@ def fake_experiment():
|
||||
default_uri = R.get_uri()
|
||||
current_uri = "file:./temp-test-exp-mag"
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
|
||||
@@ -1,73 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
def train_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
@@ -77,10 +26,10 @@ def train_multiseg():
|
||||
|
||||
|
||||
def train_mse():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
DATA_DIR = Path(__file__).parent.joinpath("test_get_data")
|
||||
SOURCE_DIR = DATA_DIR.joinpath("source")
|
||||
@@ -37,7 +34,9 @@ class TestGetData(unittest.TestCase):
|
||||
|
||||
def test_0_qlib_data(self):
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False)
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False, exists_skip=True
|
||||
)
|
||||
df = D.features(D.instruments("csi300"), self.FIELDS)
|
||||
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class Diff(ElemOperator):
|
||||
|
||||
Reference in New Issue
Block a user