From 3e04ded750e2735af05b54d100d58525083cf9b2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 16 Nov 2020 17:29:26 +0800 Subject: [PATCH] Add initial workflow_by_config --- examples/workflow_by_code.py | 27 ++++++------------ examples/workflow_by_config.py | 49 ++++++++++++++++++++++++++++++++ examples/workflow_config.yaml | 52 ++++++++++++++++++++++++++++++++++ setup.py | 1 + 4 files changed, 110 insertions(+), 19 deletions(-) create mode 100644 examples/workflow_by_config.py create mode 100644 examples/workflow_config.yaml diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 98cd1f928..cae890672 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -46,15 +46,6 @@ if __name__ == "__main__": "instruments": MARKET, } - TRAINER_CONFIG = { - "train_start_time": "2008-01-01", - "train_end_time": "2014-12-31", - "validate_start_time": "2015-01-01", - "validate_end_time": "2016-12-31", - "test_start_time": "2017-01-01", - "test_end_time": "2020-08-01", - } - task = { "model": { "class": "LGBModel", @@ -82,14 +73,8 @@ if __name__ == "__main__": }, "segments": { "train": ("2008-01-01", "2014-12-31"), - "valid": ( - "2015-01-01", - "2016-12-31", - ), - "test": ( - "2017-01-01", - "2020-08-01", - ), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), }, }, }, @@ -99,8 +84,12 @@ if __name__ == "__main__": port_analysis_config = { "strategy": { - "topk": 50, - "n_drop": 5, + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.strategy", + "kwargs": { + "topk": 50, + "n_drop": 5, + } }, "backtest": { "verbose": False, diff --git a/examples/workflow_by_config.py b/examples/workflow_by_config.py new file mode 100644 index 000000000..7955d29d0 --- /dev/null +++ b/examples/workflow_by_config.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path + +import qlib +import fire +import yaml +import pandas as pd +from qlib.config import REG_CN +from qlib.utils import exists_qlib_data, init_instance_by_config +from qlib.workflow import R +from qlib.workflow.record_temp import SignalRecord, PortAnaRecord + +# worflow handler function +def workflow(config_path): + with open(config_path) as fp: + config = yaml.load(fp, Loader=yaml.FullLoader) + + provider_uri = config.get("PROVIDER_URI") + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) + from get_data import GetData + + GetData().qlib_data_cn(target_dir=provider_uri) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + # model initiaiton + model = init_instance_by_config(config.get("TASK")["model"]) + dataset = init_instance_by_config(config.get("TASK")["dataset"]) + + # start exp + with R.start("workflow"): + model.fit(dataset) + + # prediction + recorder = R.get_recorder() + sr = SignalRecord(model, dataset, recorder) + sr.generate() + + # backtest + par = PortAnaRecord(recorder, config.get("PORT_ANALYSIS_CONFIG")) + par.generate() + +if __name__ == "__main__": + fire.Fire(workflow) \ No newline at end of file diff --git a/examples/workflow_config.yaml b/examples/workflow_config.yaml new file mode 100644 index 000000000..2698423df --- /dev/null +++ b/examples/workflow_config.yaml @@ -0,0 +1,52 @@ +PROVIDER_URI: "~/.qlib/qlib_data/cn_data" +MARKET: &market csi300 +BENCHMARK: &benchmark SH000300 +DATA_HANDLER_CONFIG: &data_handerler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market +TASK: + model: + class: LGBModel + module_path: qlib.contrib.model.gbdt + kwargs: + loss: mse + colsample_bytree: 0.8879 + learning_rate: 0.0421 + subsample: 0.8789 + lambda_l1: 205.6999 + lambda_l2: 580.9768 + max_depth: 8 + num_leaves: 210 + num_threads: 20 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handerler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: [SignalRecord, PortAnaRecord] +PORT_ANALYSIS_CONFIG: + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 \ No newline at end of file diff --git a/setup.py b/setup.py index 22e806d8d..38a84ef7c 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ REQUIRED = [ "lightgbm", "tornado", "joblib>=0.17.0", + "fire>=0.3.1", ] # Numpy include