mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
bug fixed & examples fire
This commit is contained in:
@@ -97,8 +97,8 @@ def task_generating():
|
||||
|
||||
|
||||
def task_training(tasks, task_pool, exp_name):
|
||||
trainer = TrainerRM()
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
trainer = TrainerRM(exp_name, task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
@@ -128,7 +128,7 @@ def main(
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
exp_name="rolling_exp",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
):
|
||||
mongo_conf = {
|
||||
@@ -137,11 +137,13 @@ def main(
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
# reset(task_pool, exp_name)
|
||||
# tasks = task_generating()
|
||||
# task_training(tasks, task_pool, exp_name)
|
||||
task_collecting(task_pool, exp_name)
|
||||
reset(task_pool, experiment_name)
|
||||
tasks = task_generating()
|
||||
task_training(tasks, task_pool, experiment_name)
|
||||
task_collecting(task_pool, experiment_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
|
||||
@@ -70,89 +70,106 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
class RollingOnlineExample:
|
||||
|
||||
def print_online_model():
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
|
||||
print(rid)
|
||||
print("Current 'next online' model:")
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool)
|
||||
self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer)
|
||||
|
||||
|
||||
|
||||
def print_online_model(self):
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
for rid, rec in list_recorders(self.exp_name).items():
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.ONLINE_TAG:
|
||||
print(rid)
|
||||
print("Current 'next online' model:")
|
||||
for rid, rec in list_recorders(self.exp_name).items():
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating():
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(self):
|
||||
|
||||
print("========== task_generating ==========")
|
||||
print("========== task_generating ==========")
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=rolling_gen, # generate different date segment
|
||||
)
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=self.rolling_gen, # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
pprint(tasks)
|
||||
|
||||
return tasks
|
||||
return tasks
|
||||
|
||||
|
||||
def task_training(tasks):
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
def task_training(self, tasks):
|
||||
self.trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
print("========== task_collecting ==========")
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
)
|
||||
print(artifact)
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset():
|
||||
print("========== reset ==========")
|
||||
task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run():
|
||||
print("========== first_run ==========")
|
||||
reset()
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = task_generating()
|
||||
task_training(tasks)
|
||||
task_collecting()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
latest_rec, _ = rolling_online_manager.list_latest_recorders()
|
||||
rolling_online_manager.reset_online_tag(latest_rec.values())
|
||||
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
|
||||
|
||||
def routine():
|
||||
print("========== routine ==========")
|
||||
print_online_model()
|
||||
rolling_online_manager.routine()
|
||||
print_online_model()
|
||||
task_collecting()
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
self.print_online_model()
|
||||
self.rolling_online_manager.routine()
|
||||
self.print_online_model()
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -161,26 +178,7 @@ if __name__ == "__main__":
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python task_manager_rolling_with_updating.py after_day
|
||||
|
||||
#################### you need to finish the configurations below #########################
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
|
||||
mongo_conf = {
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
|
||||
"task_db_name": "rolling_db", # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
rolling_step = 550
|
||||
|
||||
##########################################################################################
|
||||
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
task_manager = TaskManager(task_pool=task_pool)
|
||||
trainer = TrainerRM()
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
|
||||
)
|
||||
|
||||
fire.Fire()
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
|
||||
@@ -54,10 +54,10 @@ task = {
|
||||
|
||||
def first_train(experiment_name="online_srv"):
|
||||
|
||||
rid = task_train(task_config=task, experiment_name=experiment_name)
|
||||
rec = task_train(task_config=task, experiment_name=experiment_name)
|
||||
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
online_manager.reset_online_tag(rid)
|
||||
online_manager.reset_online_tag(rec)
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_srv"):
|
||||
@@ -71,13 +71,17 @@ def update_online_pred(experiment_name="online_srv"):
|
||||
|
||||
online_manager.update_online_pred()
|
||||
|
||||
def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
first_train(experiment_name)
|
||||
update_online_pred(experiment_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
|
||||
Reference in New Issue
Block a user