mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
bug fixed & examples fire
This commit is contained in:
@@ -97,8 +97,8 @@ def task_generating():
|
||||
|
||||
|
||||
def task_training(tasks, task_pool, exp_name):
|
||||
trainer = TrainerRM()
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
trainer = TrainerRM(exp_name, task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
@@ -119,7 +119,7 @@ def task_collecting(task_pool, exp_name):
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup(),
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
@@ -128,7 +128,7 @@ def main(
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
exp_name="rolling_exp",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
):
|
||||
mongo_conf = {
|
||||
@@ -137,11 +137,13 @@ def main(
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
# reset(task_pool, exp_name)
|
||||
# tasks = task_generating()
|
||||
# task_training(tasks, task_pool, exp_name)
|
||||
task_collecting(task_pool, exp_name)
|
||||
reset(task_pool, experiment_name)
|
||||
tasks = task_generating()
|
||||
task_training(tasks, task_pool, experiment_name)
|
||||
task_collecting(task_pool, experiment_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
|
||||
@@ -70,89 +70,106 @@ task_xgboost_config = {
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
class RollingOnlineExample:
|
||||
|
||||
def print_online_model():
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.ONLINE_TAG:
|
||||
print(rid)
|
||||
print("Current 'next online' model:")
|
||||
for rid, rec in list_recorders(exp_name).items():
|
||||
if rolling_online_manager.get_online_tag(rec) == rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
def __init__(self, exp_name="rolling_exp", task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool)
|
||||
self.rolling_online_manager = RollingOnlineManager(experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer)
|
||||
|
||||
|
||||
|
||||
def print_online_model(self):
|
||||
print("========== print_online_model ==========")
|
||||
print("Current 'online' model:")
|
||||
for rid, rec in list_recorders(self.exp_name).items():
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.ONLINE_TAG:
|
||||
print(rid)
|
||||
print("Current 'next online' model:")
|
||||
for rid, rec in list_recorders(self.exp_name).items():
|
||||
if self.rolling_online_manager.get_online_tag(rec) == self.rolling_online_manager.NEXT_ONLINE_TAG:
|
||||
print(rid)
|
||||
|
||||
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating():
|
||||
# This part corresponds to "Task Generating" in the document
|
||||
def task_generating(self):
|
||||
|
||||
print("========== task_generating ==========")
|
||||
print("========== task_generating ==========")
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=rolling_gen, # generate different date segment
|
||||
)
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=self.rolling_gen, # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
pprint(tasks)
|
||||
|
||||
return tasks
|
||||
return tasks
|
||||
|
||||
|
||||
def task_training(tasks):
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
def task_training(self, tasks):
|
||||
self.trainer.train(tasks)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting():
|
||||
print("========== task_collecting ==========")
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
)
|
||||
print(artifact)
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=self.exp_name, rec_key_func=rec_key, rec_filter_func=my_filter), RollingGroup()
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset():
|
||||
print("========== reset ==========")
|
||||
task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run():
|
||||
print("========== first_run ==========")
|
||||
reset()
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = task_generating()
|
||||
task_training(tasks)
|
||||
task_collecting()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
latest_rec, _ = rolling_online_manager.list_latest_recorders()
|
||||
rolling_online_manager.reset_online_tag(latest_rec.values())
|
||||
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
|
||||
|
||||
def routine():
|
||||
print("========== routine ==========")
|
||||
print_online_model()
|
||||
rolling_online_manager.routine()
|
||||
print_online_model()
|
||||
task_collecting()
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
self.print_online_model()
|
||||
self.rolling_online_manager.routine()
|
||||
self.print_online_model()
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -161,26 +178,7 @@ if __name__ == "__main__":
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python task_manager_rolling_with_updating.py after_day
|
||||
|
||||
#################### you need to finish the configurations below #########################
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # data_dir
|
||||
mongo_conf = {
|
||||
"task_url": "mongodb://10.0.0.4:27017/", # your MongoDB url
|
||||
"task_db_name": "rolling_db", # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
exp_name = "rolling_exp" # experiment name, will be used as the experiment in MLflow
|
||||
task_pool = "rolling_task" # task pool name, will be used as the document in MongoDB
|
||||
rolling_step = 550
|
||||
|
||||
##########################################################################################
|
||||
rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
task_manager = TaskManager(task_pool=task_pool)
|
||||
trainer = TrainerRM()
|
||||
rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=rolling_gen, task_manager=task_manager, trainer=trainer
|
||||
)
|
||||
|
||||
fire.Fire()
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
|
||||
@@ -54,10 +54,10 @@ task = {
|
||||
|
||||
def first_train(experiment_name="online_srv"):
|
||||
|
||||
rid = task_train(task_config=task, experiment_name=experiment_name)
|
||||
rec = task_train(task_config=task, experiment_name=experiment_name)
|
||||
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
online_manager.reset_online_tag(rid)
|
||||
online_manager.reset_online_tag(rec)
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_srv"):
|
||||
@@ -71,13 +71,17 @@ def update_online_pred(experiment_name="online_srv"):
|
||||
|
||||
online_manager.update_online_pred()
|
||||
|
||||
def main(provider_uri = "~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv"):
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
first_train(experiment_name)
|
||||
update_online_pred(experiment_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire()
|
||||
|
||||
@@ -147,7 +147,7 @@ _default_config = {
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -3,9 +3,10 @@ from typing import Callable, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
|
||||
def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs):
|
||||
def ens_workflow(collector: Collector, process_list, *args, **kwargs):
|
||||
"""the ensemble workflow based on collector and different dict processors.
|
||||
|
||||
Args:
|
||||
@@ -21,7 +22,7 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil
|
||||
Returns:
|
||||
dict: the ensemble dict
|
||||
"""
|
||||
collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func)
|
||||
collect_dict = collector.collect()
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
|
||||
@@ -37,23 +38,12 @@ def ens_workflow(collector: Collector, process_list, artifacts_key=None, rec_fil
|
||||
return ensemble
|
||||
|
||||
|
||||
class Ensemble:
|
||||
class Ensemble(Serializable):
|
||||
"""Merge the objects in an Ensemble."""
|
||||
|
||||
def __init__(self, merge_func=None):
|
||||
"""init Ensemble
|
||||
|
||||
Args:
|
||||
merge_func (Callable, optional): Given a dict and return the ensemble.
|
||||
|
||||
For example: {Rollinga_b: object, Rollingb_c: object} -> object
|
||||
|
||||
Defaults to None.
|
||||
"""
|
||||
self._merge = merge_func
|
||||
|
||||
def __call__(self, ensemble_dict: dict, *args, **kwargs):
|
||||
"""Merge the ensemble_dict into an ensemble object.
|
||||
For example: {Rollinga_b: object, Rollingb_c: object} -> object
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): the ensemble dict waiting for merging like {name: things}
|
||||
@@ -61,38 +51,29 @@ class Ensemble:
|
||||
Returns:
|
||||
object: the ensemble object
|
||||
"""
|
||||
if isinstance(getattr(self, "_merge", None), Callable):
|
||||
return self._merge(ensemble_dict, *args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid merge_func.")
|
||||
raise NotImplementedError(f"Please implement the `__call__` method.")
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge the rolling objects in an Ensemble"""
|
||||
|
||||
@staticmethod
|
||||
def rolling_merge(rolling_dict: dict):
|
||||
def __call__(self, ensemble_dict: dict, *args, **kwargs):
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.Dataframe, and have the index "datetime"
|
||||
|
||||
Args:
|
||||
rolling_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
|
||||
ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.Dataframe: the complete result of rolling.
|
||||
"""
|
||||
artifact_list = list(rolling_dict.values())
|
||||
artifact_list = list(ensemble_dict.values())
|
||||
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
|
||||
artifact = pd.concat(artifact_list)
|
||||
# If there are duplicated predition, use the latest perdiction
|
||||
artifact = artifact[~artifact.index.duplicated(keep="last")]
|
||||
artifact = artifact.sort_index()
|
||||
return artifact
|
||||
|
||||
def __init__(self, merge_func=None):
|
||||
super().__init__(merge_func=merge_func)
|
||||
if merge_func is None:
|
||||
self._merge = RollingEnsemble.rolling_merge
|
||||
@@ -1,8 +1,9 @@
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
|
||||
class Group:
|
||||
class Group(Serializable):
|
||||
"""Group the objects based on dict"""
|
||||
|
||||
def __init__(self, group_func=None, ens: Ensemble = None):
|
||||
@@ -17,8 +18,8 @@ class Group:
|
||||
|
||||
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
|
||||
"""
|
||||
self._group = group_func
|
||||
self._ens = ens
|
||||
self.group = group_func
|
||||
self.ens = ens
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, *args, **kwargs):
|
||||
"""Group the ungrouped_dict into different groups.
|
||||
@@ -29,16 +30,16 @@ class Group:
|
||||
Returns:
|
||||
dict: grouped_dict like {G1: object, G2: object}
|
||||
"""
|
||||
if isinstance(getattr(self, "_group", None), Callable):
|
||||
grouped_dict = self._group(ungrouped_dict, *args, **kwargs)
|
||||
if self._ens is not None:
|
||||
if isinstance(getattr(self, "group", None), Callable):
|
||||
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
|
||||
if self.ens is not None:
|
||||
ens_dict = {}
|
||||
for key, value in grouped_dict.items():
|
||||
ens_dict[key] = self._ens(value)
|
||||
ens_dict[key] = self.ens(value)
|
||||
grouped_dict = ens_dict
|
||||
return grouped_dict
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid merge_func.")
|
||||
raise NotImplementedError(f"Please specify valid group_func.")
|
||||
|
||||
|
||||
class RollingGroup(Group):
|
||||
@@ -65,4 +66,4 @@ class RollingGroup(Group):
|
||||
def __init__(self, group_func=None, ens: Ensemble = RollingEnsemble()):
|
||||
super().__init__(group_func=group_func, ens=ens)
|
||||
if group_func is None:
|
||||
self._group = RollingGroup.rolling_group
|
||||
self.group = RollingGroup.rolling_group
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
task based training
|
||||
|
||||
@@ -20,8 +21,7 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
|
||||
Returns
|
||||
----------
|
||||
rid : str
|
||||
The id of the recorder of this task
|
||||
Recorder : The instance of the recorder
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
@@ -80,30 +80,40 @@ class TrainerR(Trainer):
|
||||
Assumption: models were defined by `task` and the results will saved to `Recorder`
|
||||
"""
|
||||
|
||||
def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs):
|
||||
def __init__(self, experiment_name, train_func=task_train):
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func=None, *args, **kwargs):
|
||||
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
experiment_name (str): the experiment name
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
recs = []
|
||||
for task in tasks:
|
||||
recs.append(train_func(task, experiment_name, *args, **kwargs))
|
||||
recs.append(train_func(task, self.experiment_name, *args, **kwargs))
|
||||
return recs
|
||||
|
||||
|
||||
class TrainerRM(TrainerR):
|
||||
class TrainerRM(Trainer):
|
||||
"""Trainer based on (R)ecorder and Task(M)anager
|
||||
|
||||
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
|
||||
"""
|
||||
|
||||
def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs):
|
||||
def __init__(self, experiment_name: str, task_pool: str, train_func=task_train):
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func=None, *args, **kwargs):
|
||||
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
This method defaults to a single process, but TaskManager offered a great way to parallel training.
|
||||
@@ -111,17 +121,18 @@ class TrainerRM(TrainerR):
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
experiment_name (str): the experiment name
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
tm = TaskManager(task_pool=self.task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(train_func, task_pool, experiment_name=experiment_name, *args, **kwargs)
|
||||
run_task(train_func, self.task_pool, experiment_name=self.experiment_name, *args, **kwargs)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
recs.append(tm.re_query(_id)["res"])
|
||||
return recs
|
||||
return recs
|
||||
|
||||
@@ -20,7 +20,7 @@ class OnlineManager(Serializable):
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, trainer: Trainer = None) -> None:
|
||||
def __init__(self, trainer: Trainer = None):
|
||||
self._trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@@ -81,7 +81,8 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = TrainerR()) -> None:
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None):
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.exp_name = experiment_name
|
||||
@@ -105,20 +106,22 @@ class OnlineManagerR(OnlineManager):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
"""
|
||||
if recorder is None:
|
||||
recorder = list_recorders(
|
||||
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
|
||||
).values()
|
||||
recorder = list(
|
||||
list_recorders(
|
||||
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
|
||||
).values()
|
||||
)
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
if len(recorder) == 0:
|
||||
self.logger.info("No 'next online' model, just use current 'online' models.")
|
||||
return
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(OnlineManager.OFFLINE_TAG, recs.values())
|
||||
self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
|
||||
self.logger.info(f"Reset {len(recorder)} models to 'online'.")
|
||||
|
||||
def update_online_pred(self):
|
||||
def update_online_pred(self, *args, **kwargs):
|
||||
"""update all online model predictions to the latest day in Calendar"""
|
||||
mu = ModelUpdater(self.exp_name)
|
||||
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
|
||||
@@ -126,25 +129,24 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
"""An implementation of OnlineManager based on Rolling.
|
||||
|
||||
"""
|
||||
"""An implementation of OnlineManager based on Rolling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = TrainerR(),
|
||||
) -> None:
|
||||
trainer: Trainer = None,
|
||||
):
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(experiment_name, trainer)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def prepare_signals(self):
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def prepare_tasks(self):
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""prepare new tasks based on new date.
|
||||
|
||||
Returns:
|
||||
@@ -155,7 +157,7 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return None
|
||||
return []
|
||||
calendar_latest = self.ta.last_date()
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
|
||||
old_tasks = []
|
||||
@@ -168,7 +170,7 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
new_tasks_tmp = task_generator(old_tasks, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return None
|
||||
return []
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
"""find latest recorders based on test segments.
|
||||
@@ -187,4 +189,4 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec, max_test
|
||||
return latest_rec, max_test
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
|
||||
class Collector:
|
||||
class Collector(Serializable):
|
||||
"""The collector to collect different results"""
|
||||
|
||||
def collect(self, *args, **kwargs):
|
||||
@@ -25,33 +26,46 @@ class Collector:
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
def __init__(
|
||||
self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, rec_key_func=None
|
||||
) -> None:
|
||||
self,
|
||||
exp_name,
|
||||
artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"},
|
||||
rec_key_func=None,
|
||||
artifacts_key=None,
|
||||
rec_filter_func=None,
|
||||
):
|
||||
"""init RecorderCollector
|
||||
|
||||
Args:
|
||||
exp_name (str): the name of Experiment
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
self._get_key = rec_key_func
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
self.rec_key = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self.rec_filter = rec_filter_func
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func,
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None):
|
||||
"""Collect different artifacts based on recorder after filtering.
|
||||
|
||||
Args:
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collected like {artifact: {rec_key: object}}
|
||||
"""
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
artifacts_key = self.artifacts_key
|
||||
if rec_filter_func is None:
|
||||
rec_filter_func = self.rec_filter
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
@@ -60,9 +74,9 @@ class RecorderCollector(Collector):
|
||||
# filter records
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self._get_key(rec)
|
||||
rec_key = self.rec_key(rec)
|
||||
for key in artifacts_key:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
return collect_dict
|
||||
|
||||
@@ -49,7 +49,7 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool=None):
|
||||
def __init__(self, task_pool: str):
|
||||
"""
|
||||
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
|
||||
@@ -59,9 +59,13 @@ class TaskManager:
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = task_pool
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
# @property
|
||||
# def task_pool(self):
|
||||
# return self._task_pool
|
||||
|
||||
def list(self):
|
||||
return self.mdb.list_collection_names()
|
||||
|
||||
@@ -79,39 +83,39 @@ class TaskManager:
|
||||
task[k] = pickle.loads(task[k])
|
||||
return task
|
||||
|
||||
def _get_task_pool(self, task_pool=None):
|
||||
if task_pool is None:
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
raise ValueError("You must specify a task pool.")
|
||||
if isinstance(task_pool, str):
|
||||
return getattr(self.mdb, task_pool)
|
||||
return task_pool
|
||||
# def _get_task_pool(self, task_pool=None):
|
||||
# if task_pool is None:
|
||||
# task_pool = self.task_pool
|
||||
# if task_pool is None:
|
||||
# raise ValueError("You must specify a task pool.")
|
||||
# if isinstance(task_pool, str):
|
||||
# return getattr(self.mdb, task_pool)
|
||||
# return task_pool
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task, task_pool=None):
|
||||
def replace_task(self, task, new_task):
|
||||
# assume that the data out of interface was decoded and the data in interface was encoded
|
||||
new_task = self._encode_task(new_task)
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
try:
|
||||
task_pool.replace_one(query, new_task)
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
except InvalidDocument:
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
task_pool.replace_one(query, new_task)
|
||||
self.task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
def insert_task(self, task):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
try:
|
||||
insert_result = task_pool.insert_one(task)
|
||||
insert_result = self.task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
insert_result = task_pool.insert_one(task)
|
||||
insert_result = self.task_pool.insert_one(task)
|
||||
return insert_result
|
||||
|
||||
def insert_task_def(self, task_def, task_pool=None):
|
||||
def insert_task_def(self, task_def):
|
||||
"""
|
||||
insert a task to task_pool
|
||||
|
||||
@@ -126,7 +130,7 @@ class TaskManager:
|
||||
-------
|
||||
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
task = self._encode_task(
|
||||
{
|
||||
"def": task_def,
|
||||
@@ -134,10 +138,10 @@ class TaskManager:
|
||||
"status": self.STATUS_WAITING,
|
||||
}
|
||||
)
|
||||
insert_result = self.insert_task(task, task_pool)
|
||||
insert_result = self.insert_task(task)
|
||||
return insert_result
|
||||
|
||||
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False):
|
||||
"""
|
||||
if the tasks in task_def_l is new, then insert new tasks into the task_pool
|
||||
|
||||
@@ -156,13 +160,13 @@ class TaskManager:
|
||||
list
|
||||
a list of the _id of new tasks
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
r = task_pool.find_one({"filter": t})
|
||||
r = self.task_pool.find_one({"filter": t})
|
||||
except InvalidDocument:
|
||||
r = task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
|
||||
@@ -176,18 +180,18 @@ class TaskManager:
|
||||
|
||||
_id_list = []
|
||||
for t in new_tasks:
|
||||
insert_result = self.insert_task_def(t, task_pool)
|
||||
insert_result = self.insert_task_def(t)
|
||||
_id_list.append(insert_result.inserted_id)
|
||||
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
def fetch_task(self, query={}):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query.update({"status": self.STATUS_WAITING})
|
||||
task = task_pool.find_one_and_update(
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
)
|
||||
# null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority
|
||||
@@ -197,7 +201,7 @@ class TaskManager:
|
||||
return self._decode_task(task)
|
||||
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}, task_pool=None):
|
||||
def safe_fetch_task(self, query={}):
|
||||
"""
|
||||
fetch task from task_pool using query with contextmanager
|
||||
|
||||
@@ -212,7 +216,7 @@ class TaskManager:
|
||||
-------
|
||||
|
||||
"""
|
||||
task = self.fetch_task(query=query, task_pool=task_pool)
|
||||
task = self.fetch_task(query=query)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
@@ -229,7 +233,7 @@ class TaskManager:
|
||||
break
|
||||
yield task
|
||||
|
||||
def query(self, query={}, decode=True, task_pool=None):
|
||||
def query(self, query={}, decode=True):
|
||||
"""
|
||||
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
|
||||
|
||||
@@ -248,29 +252,30 @@ class TaskManager:
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
for t in task_pool.find(query):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
return task_pool.find_one({"_id": ObjectId(_id)})
|
||||
def re_query(self, _id):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
t = self.task_pool.find_one({"_id": ObjectId(_id)})
|
||||
return self._decode_task(t)
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
def commit_task_res(self, task, res, status=None):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
def return_task(self, task, status=None):
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {"$set": {"status": status}}
|
||||
task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def remove(self, query={}, task_pool=None):
|
||||
def remove(self, query={}):
|
||||
"""
|
||||
remove the task using query
|
||||
|
||||
@@ -286,16 +291,16 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
task_pool.delete_many(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}, task_pool=None):
|
||||
def task_stat(self, query={}):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
tasks = self.query(task_pool=task_pool, query=query, decode=False)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
|
||||
@@ -306,14 +311,14 @@ class TaskManager:
|
||||
# default query
|
||||
if "status" not in query:
|
||||
query["status"] = self.STATUS_RUNNING
|
||||
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
|
||||
return self.reset_status(query=query, status=self.STATUS_WAITING)
|
||||
|
||||
def reset_status(self, query, status, task_pool=None):
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# task_pool = self._get_task_pool(task_pool)
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
print(task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
@@ -321,14 +326,14 @@ class TaskManager:
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}, task_pool=None):
|
||||
task_stat = self.task_stat(query, task_pool)
|
||||
def wait(self, query={}):
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
undone_n = self._get_undone_n(self.task_stat(query, task_pool))
|
||||
undone_n = self._get_undone_n(self.task_stat(query))
|
||||
pbar.update(last_undone_n - undone_n)
|
||||
last_undone_n = undone_n
|
||||
if undone_n == 0:
|
||||
@@ -365,7 +370,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
break
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: # what this means?
|
||||
res = executor.submit(task_func, task["def"], *args, **kwargs).result()
|
||||
else:
|
||||
res = task_func(task["def"], *args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user