1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

format code

This commit is contained in:
lzh222333
2021-04-08 03:30:24 +00:00
parent 71605794a2
commit c20eb5c8a6
7 changed files with 29 additions and 53 deletions

View File

@@ -119,7 +119,8 @@ def task_collecting(task_pool, exp_name):
return False
artifact = ens_workflow(
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter),
RollingGroup(),
)
print(artifact)

View File

@@ -70,9 +70,18 @@ task_xgboost_config = {
"record": record_config,
}
class RollingOnlineExample:
def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550):
class RollingOnlineExample:
def __init__(
self,
exp_name="rolling_exp",
task_pool="rolling_task",
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
):
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
@@ -84,9 +93,9 @@ class RollingOnlineExample:
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
self.trainer = TrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool)
self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer)
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer
)
def print_online_model(self):
print("========== print_online_model ==========")
@@ -99,7 +108,6 @@ class RollingOnlineExample:
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
print(rid)
# This part corresponds to "Task Generating" in the document
def task_generating(self):
@@ -114,11 +122,9 @@ class RollingOnlineExample:
return tasks
def task_training(self, tasks):
self.trainer.train(tasks)
# This part corresponds to "Task Collecting" in the document
def task_collecting(self):
print("========== task_collecting ==========")
@@ -141,7 +147,6 @@ class RollingOnlineExample:
)
print(artifact)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
@@ -150,7 +155,6 @@ class RollingOnlineExample:
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this firstly to see the workflow in Task Management
def first_run(self):
print("========== first_run ==========")
@@ -163,7 +167,6 @@ class RollingOnlineExample:
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
def routine(self):
print("========== routine ==========")
self.print_online_model()
@@ -178,7 +181,7 @@ if __name__ == "__main__":
####### to update the models and predictions after the trading time, use the command below
# python task_manager_rolling_with_updating.py after_day
####### to define your own parameters, use `--`
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -71,12 +71,14 @@ def update_online_pred(experiment_name="online_srv"):
online_manager.update_online_pred()
def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
def main(provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=region)
first_train(experiment_name)
update_online_pred(experiment_name)
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train