diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index f0bc0b780..71cf9061f 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -27,16 +27,22 @@ def task_train(task_config: dict, experiment_name): model.fit(dataset) recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) + R.save_objects(param=task_config) # keep the original format and datatype # generate records: prediction, backtest, and analysis - for record in task_config["record"]: + records = task_config.get('record', []) + if isinstance(records, dict): # prevent only one dict + records = [records] + for record in records: if record["class"] == SignalRecord.__name__: srconf = {"model": model, "dataset": dataset, "recorder": recorder} + record.setdefault("kwargs", {}) record["kwargs"].update(srconf) sr = init_instance_by_config(record) sr.generate() else: rconf = {"recorder": recorder} + record.setdefault("kwargs", {}) record["kwargs"].update(rconf) ar = init_instance_by_config(record) ar.generate()