mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
add test/config.py
This commit is contained in:
@@ -1,24 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -96,9 +85,7 @@ class HighfreqWorkflow:
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
|
||||
@@ -1,46 +1,9 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(scripts_dir))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
dataset_task = {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.config import CSI300_DATASET_CONFIG
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
def objective(trial):
|
||||
@@ -65,12 +28,19 @@ def objective(trial):
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
|
||||
@@ -1,46 +1,11 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(scripts_dir))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region="cn")
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
dataset_task = {
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha360",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
@@ -72,5 +37,13 @@ def objective(trial):
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
dataset = init_instance_by_config(DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# config
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
32
examples/model_interpreter/feature.py
Normal file
32
examples/model_interpreter/feature.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
@@ -17,63 +17,7 @@ from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
@@ -85,11 +29,13 @@ class RollingTaskExample:
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=[task_xgboost_config, task_lgb_config],
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
if task_config is None:
|
||||
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
|
||||
@@ -13,63 +13,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -84,7 +28,7 @@ class OnlineSimulationExample:
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
tasks=None,
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
@@ -101,6 +45,8 @@ class OnlineSimulationExample:
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
|
||||
@@ -17,62 +17,7 @@ from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -83,9 +28,13 @@ class RollingOnlineExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config],
|
||||
add_tasks=[task_lgb_config],
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
|
||||
@@ -7,56 +7,19 @@ There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"record": {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,13 +4,11 @@
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
@@ -25,9 +23,7 @@ class RollingDataWorkflow:
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
|
||||
@@ -5,13 +5,11 @@ import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
@@ -23,8 +21,7 @@ from pprint import pprint
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
# init qlib
|
||||
@@ -39,12 +36,8 @@ exp_manager = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
# decorator to check the arguments
|
||||
|
||||
@@ -1,82 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
@@ -90,7 +30,7 @@ if __name__ == "__main__":
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -100,8 +40,8 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
@@ -110,7 +50,7 @@ if __name__ == "__main__":
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
|
||||
|
||||
@@ -14,7 +14,14 @@ class FeatureInt:
|
||||
|
||||
@abstractmethod
|
||||
def get_feature_importance(self) -> pd.Series:
|
||||
...
|
||||
"""get feature importance
|
||||
|
||||
Returns
|
||||
-------
|
||||
The index is the feature name.
|
||||
|
||||
The greater the value, the higher importance.
|
||||
"""
|
||||
|
||||
|
||||
class LightGBMFInt(FeatureInt):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import sys
|
||||
import unittest
|
||||
from ..utils import exists_qlib_data
|
||||
from .data import GetData
|
||||
from .. import init
|
||||
from ..config import REG_CN
|
||||
@@ -14,14 +12,12 @@ class TestAutoData(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
if not exists_qlib_data(cls.provider_uri):
|
||||
print(f"Qlib data is not found in {cls.provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region="cn",
|
||||
interval="1d",
|
||||
target_dir=cls.provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple",
|
||||
region=REG_CN,
|
||||
interval="1d",
|
||||
target_dir=cls.provider_uri,
|
||||
delete_old=False,
|
||||
)
|
||||
init(provider_uri=cls.provider_uri, region=REG_CN, **cls._setup_kwargs)
|
||||
|
||||
108
qlib/tests/config.py
Normal file
108
qlib/tests/config.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
CSI300_MARKET = "csi300"
|
||||
CSI100_MARKET = "csi100"
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
|
||||
DATASET_ALPHA158_CLASS = "Alpha158"
|
||||
DATASET_ALPHA360_CLASS = "Alpha360"
|
||||
|
||||
###################################
|
||||
# config
|
||||
###################################
|
||||
|
||||
|
||||
GBDT_MODEL = {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
RECORD_CONFIG = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
@@ -10,6 +10,7 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
@@ -112,6 +113,7 @@ class GetData:
|
||||
interval="1d",
|
||||
region="cn",
|
||||
delete_old=True,
|
||||
exists_skip=True,
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
@@ -129,6 +131,8 @@ class GetData:
|
||||
data region, value from [cn, us], by default cn
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
exists_skip: bool
|
||||
exists skip, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -140,6 +144,9 @@ class GetData:
|
||||
-------
|
||||
|
||||
"""
|
||||
if exists_skip and exists_qlib_data(target_dir):
|
||||
return
|
||||
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name(v):
|
||||
|
||||
@@ -1,26 +1,10 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.config import REG_CN
|
||||
import unittest
|
||||
import numpy as np
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.data import D
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class TestDataset(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
class TestDataset(TestAutoData):
|
||||
def testCSI300(self):
|
||||
close_p = D.features(D.instruments("csi300"), ["$close"])
|
||||
size = close_p.groupby("datetime").size()
|
||||
|
||||
@@ -12,55 +12,7 @@ from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
@@ -75,7 +27,7 @@ port_analysis_config = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -96,15 +48,15 @@ def train():
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
@@ -137,12 +89,12 @@ def train_with_sigana():
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
# predict and calculate ic and ric
|
||||
@@ -171,7 +123,7 @@ def fake_experiment():
|
||||
default_uri = R.get_uri()
|
||||
current_uri = "file:./temp-test-exp-mag"
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
|
||||
@@ -1,73 +1,22 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
from qlib.config import C
|
||||
from qlib.contrib.workflow import MultiSegRecord, SignalMseRecord
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
def train_multiseg():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = MultiSegRecord(model, dataset, recorder)
|
||||
@@ -77,10 +26,10 @@ def train_multiseg():
|
||||
|
||||
|
||||
def train_mse():
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalMseRecord(recorder, model=model, dataset=dataset)
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
DATA_DIR = Path(__file__).parent.joinpath("test_get_data")
|
||||
SOURCE_DIR = DATA_DIR.joinpath("source")
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.data.ops import ElemOperator, PairOperator
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class Diff(ElemOperator):
|
||||
|
||||
Reference in New Issue
Block a user