mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
bug fixed & examples fire
This commit is contained in:
@@ -97,8 +97,8 @@ def task_generating():
|
||||
|
||||
|
||||
def task_training(tasks, task_pool, exp_name):
|
||||
trainer = TrainerRM()
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
trainer = TrainerRM(exp_name, task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
@@ -128,7 +128,7 @@ def main(
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
exp_name="rolling_exp",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
):
|
||||
mongo_conf = {
|
||||
@@ -137,11 +137,13 @@ def main(
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
# reset(task_pool, exp_name)
|
||||
# tasks = task_generating()
|
||||
# task_training(tasks, task_pool, exp_name)
|
||||
task_collecting(task_pool, exp_name)
|
||||
reset(task_pool, experiment_name)
|
||||
tasks = task_generating()
|
||||
task_training(tasks, task_pool, experiment_name)
|
||||
task_collecting(task_pool, experiment_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
|
||||
Reference in New Issue
Block a user