1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

bug fixed & examples fire

This commit is contained in:
lzh222333
2021-04-07 03:33:27 +00:00
parent 431a9c92c1
commit cb42e99bee
10 changed files with 250 additions and 232 deletions

View File

@@ -97,8 +97,8 @@ def task_generating():
def task_training(tasks, task_pool, exp_name):
trainer = TrainerRM()
trainer.train(tasks, exp_name, task_pool)
trainer = TrainerRM(exp_name, task_pool)
trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
)
print(artifact)
@@ -128,7 +128,7 @@ def main(
provider_uri="~/.qlib/qlib_data/cn_data",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
exp_name="rolling_exp",
experiment_name="rolling_exp",
task_pool="rolling_task",
):
mongo_conf = {
@@ -137,11 +137,13 @@ def main(
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
# reset(task_pool, exp_name)
# tasks = task_generating()
# task_training(tasks, task_pool, exp_name)
task_collecting(task_pool, exp_name)
reset(task_pool, experiment_name)
tasks = task_generating()
task_training(tasks, task_pool, experiment_name)
task_collecting(task_pool, experiment_name)
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()

View File

@@ -70,89 +70,106 @@ task_xgboost_config = {
"record": record_config,
}
class RollingOnlineExample:
def print_online_model():
print("========== print_online_model ==========")
print("Current 'online' model:")
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
print(rid)
print("Current 'next online' model:")
for rid, rec in list_recorders(exp_name).items():
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550):
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
self.trainer = TrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool)
self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer)
def print_online_model(self):
print("========== print_online_model ==========")
print("Current 'online' model:")
for rid, rec in list_recorders(self.exp_name).items():
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.ONLINE_TAG:
print(rid)
print("Current 'next online' model:")
for rid, rec in list_recorders(self.exp_name).items():
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
# This part corresponds to "Task Generating" in the document
def task_generating():
# This part corresponds to "Task Generating" in the document
def task_generating(self):
print("========== task_generating ==========")
print("========== task_generating ==========")
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=rolling_gen, # generate different date segment
)
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=self.rolling_gen, # generate different date segment
)
pprint(tasks)
pprint(tasks)
return tasks
return tasks
def task_training(tasks):
trainer.train(tasks, exp_name, task_pool)
def task_training(self, tasks):
self.trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
def task_collecting():
print("========== task_collecting ==========")
# This part corresponds to "Task Collecting" in the document
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
)
print(artifact)
artifact = ens_workflow(
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
)
print(artifact)
# Reset all things to the first status, be careful to save important data
def reset():
print("========== reset ==========")
task_manager.remove()
exp = R.get_exp(experiment_name=exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Reset everything to its initial state; back up any important data before running this
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this firstly to see the workflow in Task Management
def first_run():
print("========== first_run ==========")
reset()
# Run this first to see the workflow in Task Management
def first_run(self):
print("========== first_run ==========")
self.reset()
tasks = task_generating()
task_training(tasks)
task_collecting()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
latest_rec, _ = rolling_online_manager.list_latest_recorders()
rolling_online_manager.reset_online_tag(latest_rec.values())
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
def routine():
print("========== routine ==========")
print_online_model()
rolling_online_manager.routine()
print_online_model()
task_collecting()
def routine(self):
print("========== routine ==========")
self.print_online_model()
self.rolling_online_manager.routine()
self.print_online_model()
self.task_collecting()
if __name__ == "__main__":
@@ -161,26 +178,7 @@ if __name__ == "__main__":
####### to update the models and predictions after the trading time, use the command below
# python task_manager_rolling_with_updating.py routine
#################### you need to finish the configurations below #########################
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
mongo_conf = {
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
"task_db_name": "rolling_db", # database name
}
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
rolling_step = 550
##########################################################################################
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
task_manager = TaskManager(task_pool=task_pool)
trainer = TrainerRM()
rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
)
fire.Fire()
####### to define your own parameters, use `--`
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -54,10 +54,10 @@ task = {
def first_train(experiment_name="online_srv"):
rid = task_train(task_config=task, experiment_name=experiment_name)
rec = task_train(task_config=task, experiment_name=experiment_name)
online_manager = OnlineManagerR(experiment_name)
online_manager.reset_online_tag(rid)
online_manager.reset_online_tag(rec)
def update_online_pred(experiment_name="online_srv"):
@@ -71,13 +71,17 @@ def update_online_pred(experiment_name="online_srv"):
online_manager.update_online_pred()
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
    """Run the whole example: initialize qlib, train a model, then update predictions.

    Args:
        provider_uri: data directory (target_dir) passed to ``qlib.init``.
        region: region passed to ``qlib.init``.
        experiment_name: experiment name passed to ``first_train`` and
            ``update_online_pred``.
    """
    # Use the caller-supplied provider_uri as-is. It used to be overwritten
    # with a hard-coded path here, which made the argument (and the
    # corresponding command-line flag) a silent no-op.
    qlib.init(provider_uri=provider_uri, region=region)
    first_train(experiment_name)
    update_online_pred(experiment_name)
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
## to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire()