1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

Consider more situations about task_config.

Save the "param" file which is collect.py need.
This commit is contained in:
lzh222333
2021-03-03 11:25:37 +08:00
parent c4733f601f
commit b84156fde8

View File

@@ -27,16 +27,22 @@ def task_train(task_config: dict, experiment_name):
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(param=task_config) # keep the original format and datatype
# generate records: prediction, backtest, and analysis
for record in task_config["record"]:
records = task_config.get('record', [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
if record["class"] == SignalRecord.__name__:
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record.setdefault("kwargs", {})
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record.setdefault("kwargs", {})
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()