mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
Consider more situations about task_config.
Save the "param" file which is collect.py need.
This commit is contained in:
@@ -27,16 +27,22 @@ def task_train(task_config: dict, experiment_name):
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(param=task_config) # keep the original format and datatype
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
for record in task_config["record"]:
|
||||
records = task_config.get('record', [])
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record.setdefault("kwargs", {})
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record.setdefault("kwargs", {})
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
|
||||
Reference in New Issue
Block a user