1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

bug fixed & examples fire

This commit is contained in:
lzh222333
2021-04-07 03:33:27 +00:00
parent 431a9c92c1
commit cb42e99bee
10 changed files with 250 additions and 232 deletions

View File

@@ -97,8 +97,8 @@ def task_generating():
def task_training(tasks, task_pool, exp_name):
trainer = TrainerRM()
trainer.train(tasks, exp_name, task_pool)
trainer = TrainerRM(exp_name, task_pool)
trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
)
print(artifact)
@@ -128,7 +128,7 @@ def main(
provider_uri="~/.qlib/qlib_data/cn_data",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
exp_name="rolling_exp",
experiment_name="rolling_exp",
task_pool="rolling_task",
):
mongo_conf = {
@@ -137,11 +137,13 @@ def main(
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
# reset(task_pool, exp_name)
# tasks = task_generating()
# task_training(tasks, task_pool, exp_name)
task_collecting(task_pool, exp_name)
reset(task_pool, experiment_name)
tasks = task_generating()
task_training(tasks, task_pool, experiment_name)
task_collecting(task_pool, experiment_name)
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()