From d6984a3f2de2d1f007dbd54c129638fd17f48352 Mon Sep 17 00:00:00 2001 From: xixi <920435730@qq.com> Date: Sat, 19 Jun 2021 17:32:28 +0800 Subject: [PATCH] fill_placehorder --- qlib/model/trainer.py | 37 ++++++++++++++++++---- test.yaml | 72 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 test.yaml diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 28d854477..44a7e56d2 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -45,6 +45,35 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str return recorder +def fill_placeholder(kwargs, model, dataset): + """ + Detect placeholder( and ) in dict and fill them. + + Args: + kwargs (Dict): the parameter dict will be filled + model (Model): fill + dataset (Dataset): fill + + Returns: + Dict: the parameter dict + """ + top = 0 + tail = 1 + dict_quene = [kwargs] + while(top < tail): + now_dict = dict_quene[top] + top += 1 + for key in now_dict.keys(): + if(isinstance(now_dict[key], dict)): + dict_quene.append(now_dict[key]) + tail += 1 + elif(now_dict[key] == ""): + now_dict[key] = model + elif(now_dict[key] == ""): + now_dict[key] = dataset + return kwargs + + def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: """ Finish task training with real model fitting and saving. @@ -73,13 +102,9 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: records = [records] for record in records: cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp") - if cls is SignalRecord: - rconf = {"model": model, "dataset": dataset, "recorder": rec} - else: - rconf = {"recorder": rec} - r = cls(**kwargs, **rconf) + kwargs = fill_placeholder(kwargs, model, dataset) + r = cls(**kwargs, **{"record", record}) r.generate() - return rec diff --git a/test.yaml b/test.yaml new file mode 100644 index 000000000..c8287cf36 --- /dev/null +++ b/test.yaml @@ -0,0 +1,72 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + model: + dataset: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: LGBModel + module_path: qlib.contrib.model.gbdt + kwargs: + loss: mse + colsample_bytree: 0.8879 + learning_rate: 0.2 + subsample: 0.8789 + lambda_l1: 205.6999 + lambda_l2: 580.9768 + max_depth: 8 + num_leaves: 210 + num_threads: 20 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config \ No newline at end of file