mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
trainer & group & collect & ensemble
This commit is contained in:
@@ -8,9 +8,11 @@ from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.ensemble import RollingEnsemble
|
||||
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
|
||||
import pandas as pd
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -94,24 +96,16 @@ def task_generating():
|
||||
return tasks
|
||||
|
||||
|
||||
# This part corresponds to "Task Storing" in the document
|
||||
def task_storing(tasks, task_pool, exp_name):
|
||||
print("========== task_storing ==========")
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
|
||||
|
||||
# This part corresponds to "Task Running" in the document
|
||||
def task_running(task_pool, exp_name):
|
||||
print("========== task_running ==========")
|
||||
run_task(task_train, task_pool, experiment_name=exp_name) # all tasks will be trained using "task_train" method
|
||||
def task_training(tasks, task_pool, exp_name):
|
||||
trainer = TrainerRM()
|
||||
trainer.train(tasks, exp_name, task_pool)
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(task_pool, exp_name):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def get_group_key_func(recorder):
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
@@ -119,14 +113,14 @@ def task_collecting(task_pool, exp_name):
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = get_group_key_func(recorder)
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(exp_name)
|
||||
# group tasks by "get_task_key" and filter tasks by "my_filter"
|
||||
artifact = collector.collect(RollingEnsemble(), get_group_key_func, rec_filter_func=my_filter)
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(exp_name=exp_name, rec_key_func=rec_key), RollingGroup(), rec_filter_func=my_filter
|
||||
)
|
||||
print(artifact)
|
||||
|
||||
|
||||
@@ -143,10 +137,9 @@ def main(
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo=mongo_conf)
|
||||
|
||||
reset(task_pool, exp_name)
|
||||
tasks = task_generating()
|
||||
task_storing(tasks, task_pool, exp_name)
|
||||
task_running(task_pool, exp_name)
|
||||
# reset(task_pool, exp_name)
|
||||
# tasks = task_generating()
|
||||
# task_training(tasks, task_pool, exp_name)
|
||||
task_collecting(task_pool, exp_name)
|
||||
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@ from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.ensemble import RollingEnsemble
|
||||
from qlib.model.ens.ensemble import RollingEnsemble
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.online import RollingOnlineManager
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
data_handler_config = {
|
||||
@@ -155,10 +155,10 @@ def first_run():
|
||||
rolling_online_manager.reset_online_tag(latest_rec.values())
|
||||
|
||||
|
||||
def after_day():
|
||||
def routine():
|
||||
print("========== after_day ==========")
|
||||
print_online_model()
|
||||
rolling_online_manager.after_day()
|
||||
rolling_online_manager.routine()
|
||||
print_online_model()
|
||||
task_collecting()
|
||||
|
||||
@@ -2,7 +2,7 @@ import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.task.online import OnlineManagerR
|
||||
from qlib.workflow.online.manager import OnlineManagerR
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
data_handler_config = {
|
||||
@@ -52,7 +52,7 @@ task = {
|
||||
}
|
||||
|
||||
|
||||
def first_train(experiment_name="online_svr"):
|
||||
def first_train(experiment_name="online_srv"):
|
||||
|
||||
rid = task_train(task_config=task, experiment_name=experiment_name)
|
||||
|
||||
@@ -60,7 +60,7 @@ def first_train(experiment_name="online_svr"):
|
||||
online_manager.reset_online_tag(rid)
|
||||
|
||||
|
||||
def update_online_pred(experiment_name="online_svr"):
|
||||
def update_online_pred(experiment_name="online_srv"):
|
||||
|
||||
online_manager = OnlineManagerR(experiment_name)
|
||||
|
||||
98
qlib/model/ens/ensemble.py
Normal file
98
qlib/model/ens/ensemble.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib.workflow.task.collect import Collector
|
||||
|
||||
|
||||
def ens_workflow(collector: "Collector", process_list, artifacts_key=None, rec_filter_func=None, *args, **kwargs):
    """Collect artifacts and pipe each of them through a chain of dict processors.

    Args:
        collector (Collector): its ``collect`` method is called with ``artifacts_key`` and
            ``rec_filter_func`` and must return a dict like {artifact_key: things}.
        process_list (list, tuple or Callable): one processor or a sequence of processors.
            Processors are applied in sequence order; each receives the previous output.

            For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]

        artifacts_key (list, optional): the artifacts key you want to get. If None, get all artifacts.
        rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
        *args, **kwargs: forwarded unchanged to every processor call.

    Returns:
        dict: {artifact_key: value produced by the last processor}

    Raises:
        NotImplementedError: if any processor is not callable.
    """
    if not isinstance(process_list, (list, tuple)):
        process_list = [process_list]

    # Validate once, before collecting, so a bad processor fails fast
    # instead of after the collector has already loaded every artifact.
    for process in process_list:
        if not callable(process):
            raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.")

    collect_dict = collector.collect(artifacts_key=artifacts_key, rec_filter_func=rec_filter_func)

    ensemble = {}
    for artifact, value in collect_dict.items():
        for process in process_list:
            value = process(value, *args, **kwargs)
        ensemble[artifact] = value
    return ensemble
|
||||
|
||||
|
||||
class Ensemble:
    """Callable that merges a dict of objects into a single ensemble object."""

    def __init__(self, merge_func=None):
        """Remember the function used for merging.

        Args:
            merge_func (Callable, optional): receives a dict and returns the merged object.

                For example: {Rollinga_b: object, Rollingb_c: object} -> object

                Defaults to None, in which case calling the instance raises.
        """
        self._merge = merge_func

    def __call__(self, ensemble_dict: dict, *args, **kwargs):
        """Apply the merge function to ``ensemble_dict``.

        Args:
            ensemble_dict (dict): the dict waiting for merging, like {name: things}
            *args, **kwargs: passed through to the merge function.

        Returns:
            object: whatever the merge function returns.

        Raises:
            NotImplementedError: if no callable merge function is set.
        """
        merger = getattr(self, "_merge", None)
        if not callable(merger):
            raise NotImplementedError(f"Please specify valid merge_func.")
        return merger(ensemble_dict, *args, **kwargs)
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
    """Ensemble whose default merge stitches rolling results together."""

    @staticmethod
    def rolling_merge(rolling_dict: dict):
        """Concatenate rolling dataframes (e.g. `prediction` or `IC`) into one result.

        NOTE: every value must be a pd.DataFrame whose index has a "datetime" level.

        Args:
            rolling_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                Only the values are used; the keys are ignored.

        Returns:
            pd.DataFrame: the combined rolling result, sorted by index.
        """
        # Order the frames by their earliest datetime before concatenating.
        ordered = sorted(rolling_dict.values(), key=lambda df: df.index.get_level_values("datetime").min())
        merged = pd.concat(ordered)
        # For duplicated index entries keep the last one, i.e. the prediction
        # coming from the frame that starts latest.
        keep_mask = ~merged.index.duplicated(keep="last")
        return merged[keep_mask].sort_index()

    def __init__(self, merge_func=None):
        """Use ``merge_func`` if given, otherwise fall back to `rolling_merge`."""
        super().__init__(merge_func=merge_func)
        if merge_func is None:
            self._merge = RollingEnsemble.rolling_merge
|
||||
68
qlib/model/ens/group.py
Normal file
68
qlib/model/ens/group.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
|
||||
|
||||
class Group:
    """Callable that regroups a flat dict and optionally ensembles every group."""

    def __init__(self, group_func=None, ens: "Ensemble" = None):
        """Store the grouping function and the optional ensemble.

        Args:
            group_func (Callable, optional): receives a dict and returns a dict of groups.

                For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}

                Defaults to None, in which case calling the instance raises.

            ens (Ensemble, optional): if not None, it is called on each group's value
                after grouping and its result replaces that value.
        """
        self._group = group_func
        self._ens = ens

    def __call__(self, ungrouped_dict: dict, *args, **kwargs):
        """Group ``ungrouped_dict`` and, if an ensemble is set, reduce each group.

        Args:
            ungrouped_dict (dict): the dict waiting for grouping, like {name: things}
            *args, **kwargs: passed through to the grouping function only.

        Returns:
            dict: grouped dict like {G1: object, G2: object}

        Raises:
            NotImplementedError: if no callable grouping function is set.
        """
        if not callable(getattr(self, "_group", None)):
            # The message used to name `merge_func`, which belongs to Ensemble.
            raise NotImplementedError("Please specify valid group_func.")

        grouped_dict = self._group(ungrouped_dict, *args, **kwargs)
        if self._ens is None:
            return grouped_dict
        return {key: self._ens(value) for key, value in grouped_dict.items()}
|
||||
|
||||
|
||||
class RollingGroup(Group):
    """Group whose default grouping splits off the trailing rolling key."""

    @staticmethod
    def rolling_group(rolling_dict: dict):
        """Turn a rolling dict like {(A,B,R): things} into {(A,B): {R: things}}.

        NOTE: the rolling key is assumed to be the last element of the key tuple,
        because the rolling results always need to be ensembled first.

        Args:
            rolling_dict (dict): a rolling dict. Entries whose key is not a tuple
                are left out of the result.

        Returns:
            dict: grouped dict
        """
        groups = {}
        for key, values in rolling_dict.items():
            if not isinstance(key, tuple):
                continue
            outer_key, rolling_key = key[:-1], key[-1]
            groups.setdefault(outer_key, {})[rolling_key] = values
        return groups

    def __init__(self, group_func=None, ens: Ensemble = RollingEnsemble()):
        """Use ``group_func`` if given, otherwise fall back to `rolling_group`."""
        super().__init__(group_func=group_func, ens=ens)
        if group_func is None:
            self._group = RollingGroup.rolling_group
|
||||
@@ -4,6 +4,7 @@
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
@@ -57,3 +58,70 @@ def task_train(task_config: dict, experiment_name: str) -> str:
|
||||
ar.generate()
|
||||
|
||||
return recorder
|
||||
|
||||
|
||||
class Trainer:
    """Base interface for objects that train a list of models."""

    def train(self, *args, **kwargs):
        """Train the given model definitions and return their results.

        Subclasses must override this method; the base version always raises.

        Returns:
            list: a list of trained results

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(f"Please implement the `train` method.")
|
||||
|
||||
|
||||
class TrainerR(Trainer):
    """Trainer based on (R)ecorder.

    Assumption: models are defined by `task` and the results are saved to `Recorder`.
    """

    def train(self, tasks: list, experiment_name: str, train_func=task_train, *args, **kwargs):
        """Train every `task` in turn and return the results in the same order.

        Args:
            tasks (list): a list of definitions based on the `task` dict
            experiment_name (str): the experiment name
            train_func (Callable): called as ``train_func(task, experiment_name, *args, **kwargs)``
                for each task; defaults to `task_train`.

        Returns:
            list: one `train_func` result per task (a list of Recorders)
        """
        return [train_func(task, experiment_name, *args, **kwargs) for task in tasks]
|
||||
|
||||
|
||||
class TrainerRM(TrainerR):
    """Trainer based on (R)ecorder and Task(M)anager.

    Assumption: each `task` is saved to the TaskManager, then fetched from it and trained.
    """

    def train(self, tasks: list, experiment_name: str, task_pool: str, train_func=task_train, *args, **kwargs):
        """Store the tasks in a task pool, run them, and return their stored results.

        This method runs in a single process by default, but TaskManager offers a way to
        train in parallel. Users can customize `train_func` to use multiple processes or
        even multiple machines.

        Args:
            tasks (list): a list of definitions based on the `task` dict
            experiment_name (str): the experiment name
            task_pool (str): name of the task pool handed to TaskManager and `run_task`
            train_func (Callable): the train method, which needs at least `task` and `experiment_name`

        Returns:
            list: the "res" field of each task returned by `create_task`, in that order
        """
        manager = TaskManager(task_pool=task_pool)
        created_ids = manager.create_task(tasks)  # all tasks will be saved to MongoDB
        run_task(train_func, task_pool, *args, experiment_name=experiment_name, **kwargs)

        # Re-query each created task by id to read the result stored for it.
        return [manager.re_query(task_id)["res"] for task_id in created_ids]
|
||||
0
qlib/workflow/online/__init__.py
Normal file
0
qlib/workflow/online/__init__.py
Normal file
@@ -3,7 +3,7 @@ from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import MLflowRecorder, Recorder
|
||||
from qlib.workflow.task.update import ModelUpdater
|
||||
from qlib.workflow.online.update import ModelUpdater
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
@@ -37,6 +37,16 @@ class OnlineManager(Serializable):
|
||||
def get_online_tag(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def routine(self, *args, **kwargs):
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(*args, **kwargs)
|
||||
self.update_online_pred(*args, **kwargs)
|
||||
self.reset_online_tag(*args, **kwargs)
|
||||
|
||||
|
||||
class OnlineManagerR(OnlineManager):
|
||||
"""
|
||||
@@ -86,21 +96,18 @@ class OnlineManagerR(OnlineManager):
|
||||
cnt = mu.update_all_pred(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG)
|
||||
self.logger.info(f"Finish updating {cnt} online model predictions of {self.exp_name}.")
|
||||
|
||||
def after_day(self, *args, **kwargs):
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(*args, **kwargs)
|
||||
self.update_online_pred(*args, **kwargs)
|
||||
self.reset_online_tag()
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
def __init__(self, experiment_name: str, rolling_gen: RollingGen, task_pool) -> None:
|
||||
# FIXME: TaskManager不应该与onlinemanager强耦合
|
||||
def __init__(
|
||||
self, experiment_name: str, rolling_gen: RollingGen, task_manager: TaskManager, trainer=run_task
|
||||
) -> None:
|
||||
super().__init__(experiment_name)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.tm = TaskManager(task_pool=task_pool)
|
||||
self.tm = task_manager
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.trainer = trainer
|
||||
|
||||
def prepare_signals(self):
|
||||
pass
|
||||
@@ -122,13 +129,13 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
old_tasks.append(task)
|
||||
new_tasks = task_generator(old_tasks, self.rg)
|
||||
new_num = self.tm.create_task(new_tasks)
|
||||
self.logger.info(f"Finished prepare {new_num} tasks.")
|
||||
self.tm.create_task(new_tasks)
|
||||
|
||||
def prepare_new_models(self):
|
||||
"""prepare(train) new models based on online model"""
|
||||
run_task(task_train, self.tm.task_pool, experiment_name=self.exp_name)
|
||||
run_task(task_train, task_pool=self.tm.task_pool, experiment_name=self.exp_name)
|
||||
latest_records, _ = self.list_latest_recorders()
|
||||
# FIXME: 现有的流程,如果没有可更新的模型,仍会调用这个,导致会先将以前的模型设置成nextonline再去更新pred,但这个时候online已经没有了,pred无法更新
|
||||
self.set_online_tag(OnlineManager.NEXT_ONLINE_TAG, latest_records.values())
|
||||
self.logger.info(f"Finished prepare {len(latest_records)} new models and set them to next_online.")
|
||||
|
||||
@@ -45,8 +45,8 @@ class ModelUpdater:
|
||||
"""
|
||||
segments = {"test": (start_time, end_time)}
|
||||
dataset = recorder.load_object("dataset")
|
||||
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time})
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments)
|
||||
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
|
||||
return dataset
|
||||
|
||||
def update_pred(self, recorder: Recorder, frequency="day"):
|
||||
@@ -1,49 +1,54 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class Collector:
|
||||
"""The collector to collect different results based on experiment backend and ensemble method"""
|
||||
"""The collector to collect different results"""
|
||||
|
||||
def collect(self, ensemble, get_group_key_func, *args, **kwargs):
|
||||
"""To collect the results, we need to get the experiment record firstly and divided them into
|
||||
different groups. Then use ensemble methods to merge the group.
|
||||
def collect(self, *args, **kwargs):
|
||||
"""Collect the results and return a dict like {key: things}
|
||||
|
||||
Args:
|
||||
ensemble (Ensemble): an instance of Ensemble
|
||||
get_group_key_func (Callable): a function to get the group of a experiment record
|
||||
Returns:
|
||||
dict: the dict after collected.
|
||||
|
||||
For example:
|
||||
|
||||
{"prediction": pd.Series}
|
||||
|
||||
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
|
||||
|
||||
......
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `collect` method.")
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None:
|
||||
def __init__(
|
||||
self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}, rec_key_func=None
|
||||
) -> None:
|
||||
"""init RecorderCollector
|
||||
|
||||
Args:
|
||||
exp_name (str): the name of Experiment
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
self._get_key = rec_key_func
|
||||
|
||||
def collect(self, ensemble, get_group_key_func, artifacts_key=None, rec_filter_func=None):
|
||||
"""Collect different artifacts based on recorder after filtering and ensemble method.
|
||||
Group recorder by get_group_key_func.
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None): # ensemble, get_group_key_func,
|
||||
"""Collect different artifacts based on recorder after filtering.
|
||||
|
||||
Args:
|
||||
ensemble (Ensemble): an instance of Ensemble
|
||||
get_group_key_func (Callable): a function to get the group of a experiment record
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. Defaults to None.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: the dict after collected.
|
||||
dict: the dict after collected like {artifact: {rec_key: object}}
|
||||
"""
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
@@ -51,22 +56,13 @@ class RecorderCollector(Collector):
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
|
||||
# prepare_ensemble
|
||||
ensemble_dict = {}
|
||||
for key in artifacts_key:
|
||||
ensemble_dict.setdefault(key, {})
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
for _, rec in recs_flt.items():
|
||||
group_key = get_group_key_func(rec)
|
||||
rec_key = self._get_key(rec)
|
||||
for key in artifacts_key:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
ensemble_dict[key][group_key] = artifact
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
return ensemble(ensemble_dict[artifacts_key])
|
||||
|
||||
collect_dict = {}
|
||||
for key in artifacts_key:
|
||||
collect_dict[key] = ensemble(ensemble_dict[key])
|
||||
return collect_dict
|
||||
return collect_dict
|
||||
@@ -1,176 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class Ensemble:
    """Merge the objects in an Ensemble, MapReduce style: `group` first, then `reduce`."""

    def __init__(self, merge_func=None, get_grouped_key_func=None) -> None:
        """init Ensemble

        Args:
            merge_func (Callable, optional): The specific merge function. When given it
                replaces the `merge_func` method on this instance. Defaults to None.
            get_grouped_key_func (Callable, optional): Get group_outer_key and group_inner_key
                by group_key. When given it replaces the `get_grouped_key_func` method on
                this instance. Defaults to None.
        """
        self.logger = get_module_logger(self.__class__.__name__)
        if merge_func is not None:
            self.merge_func = merge_func
        if get_grouped_key_func is not None:
            self.get_grouped_key_func = get_grouped_key_func

    def merge_func(self, group_inner_dict):
        """Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object},
        merge it to object

        Args:
            group_inner_dict (dict): the inner group dict

        Raises:
            NotImplementedError: always, unless overridden or replaced in `__init__`.
        """
        raise NotImplementedError(f"Please implement the `merge_func` method.")

    def get_grouped_key_func(self, group_key):
        """Given a group_key and return the group_outer_key, group_inner_key.

        For example:
            (A,B,Rolling) -> (A,B):Rolling
            (A,B) -> C:(A,B)

        Args:
            group_key (tuple or str): the group key

        Raises:
            NotImplementedError: always, unless overridden or replaced in `__init__`.
        """
        raise NotImplementedError(f"Please implement the `get_grouped_key_func` method.")

    def group(
        self, group_dict: Dict[Union[tuple, str], object]
    ) -> Dict[Union[tuple, str], Dict[Union[tuple, str], object]]:
        """In a group of dict, further divide them into outgroups and innergroup.

        For example:

        .. code-block:: python

            RollingEnsemble:
                input:
                {
                    (ModelA,Horizon5,Rollinga_b): object
                    (ModelA,Horizon5,Rollingb_c): object
                    (ModelA,Horizon10,Rollinga_b): object
                    (ModelA,Horizon10,Rollingb_c): object
                    (ModelB,Horizon5,Rollinga_b): object
                    (ModelB,Horizon5,Rollingb_c): object
                    (ModelB,Horizon10,Rollinga_b): object
                    (ModelB,Horizon10,Rollingb_c): object
                }

                output:
                {
                    (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                }

        Args:
            group_dict (Dict[tuple or str, object]): a group of dict

        Returns:
            Dict[tuple or str, Dict[tuple or str, object]]: the dict after `group`
        """
        grouped_dict = {}
        for group_key, artifact in group_dict.items():
            # e.g. (A,B,Rolling) -> (A,B):Rolling
            group_outer_key, group_inner_key = self.get_grouped_key_func(group_key)
            grouped_dict.setdefault(group_outer_key, {})[group_inner_key] = artifact
        return grouped_dict

    def reduce(self, grouped_dict: dict):
        """After grouping, reduce the innergroup.

        For example:

        .. code-block:: python

            RollingEnsemble:
                input:
                {
                    (ModelA,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelA,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon5): {Rollinga_b: object, Rollingb_c: object}
                    (ModelB,Horizon10): {Rollinga_b: object, Rollingb_c: object}
                }

                output:
                {
                    (ModelA,Horizon5): object
                    (ModelA,Horizon10): object
                    (ModelB,Horizon5): object
                    (ModelB,Horizon10): object
                }

        Args:
            grouped_dict (dict): the dict after `group`

        Returns:
            dict: the dict after `reduce`
        """
        return {
            group_outer_key: self.merge_func(group_inner_dict)
            for group_outer_key, group_inner_dict in grouped_dict.items()
        }

    def __call__(self, group_dict):
        """The process of Ensemble is group it firstly and then reduce it, like MapReduce.

        Args:
            group_dict (Dict[tuple or str, object]): a group of dict

        Returns:
            dict: the dict after `reduce`
        """
        return self.reduce(self.group(group_dict))
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
    """Ensemble implementation for rolling results."""

    def merge_func(self, group_inner_dict):
        """Merge the artifacts of one inner group along datetime.

        Args:
            group_inner_dict (dict): the inner group dict; only its values are used.

        Returns:
            object: the merged artifact, sorted by index.
        """
        # Order the artifacts by their earliest datetime before concatenating.
        ordered = sorted(group_inner_dict.values(), key=lambda df: df.index.get_level_values("datetime").min())
        merged = pd.concat(ordered)
        # For duplicated index entries keep the last one, i.e. the prediction
        # coming from the artifact that starts latest.
        keep_mask = ~merged.index.duplicated(keep="last")
        return merged[keep_mask].sort_index()

    def get_grouped_key_func(self, group_key):
        """Split a group key into (everything but the last element, last element).

        The last element of group_key must be the Rolling key. A matching
        `get_group_key_func` for `collect` can look like this:

        .. code-block:: python

            def get_group_key_func(recorder):
                task_config = recorder.load_object("task")
                ......
                rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
                return ......, rolling_key

        Args:
            group_key (tuple or str): the group key, with at least two elements

        Returns:
            tuple or str, tuple or str: group_outer_key, group_inner_key
        """
        assert len(group_key) >= 2
        outer_key, inner_key = group_key[:-1], group_key[-1]
        return outer_key, inner_key
|
||||
@@ -60,7 +60,7 @@ class TaskManager:
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = task_pool
|
||||
self.logger = get_module_logger("TaskManager")
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self):
|
||||
return self.mdb.list_collection_names()
|
||||
@@ -105,10 +105,11 @@ class TaskManager:
|
||||
def insert_task(self, task, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
try:
|
||||
task_pool.insert_one(task)
|
||||
insert_result = task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
task["filter"] = self._dict_to_str(task["filter"])
|
||||
task_pool.insert_one(task)
|
||||
insert_result = task_pool.insert_one(task)
|
||||
return insert_result
|
||||
|
||||
def insert_task_def(self, task_def, task_pool=None):
|
||||
"""
|
||||
@@ -133,7 +134,8 @@ class TaskManager:
|
||||
"status": self.STATUS_WAITING,
|
||||
}
|
||||
)
|
||||
self.insert_task(task, task_pool)
|
||||
insert_result = self.insert_task(task, task_pool)
|
||||
return insert_result
|
||||
|
||||
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
|
||||
"""
|
||||
@@ -151,8 +153,8 @@ class TaskManager:
|
||||
if print new task
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
the length of new tasks
|
||||
list
|
||||
a list of the _id of new tasks
|
||||
"""
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
@@ -163,7 +165,7 @@ class TaskManager:
|
||||
r = task_pool.find_one({"filter": self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
|
||||
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
|
||||
|
||||
if print_nt: # print new task
|
||||
for t in new_tasks:
|
||||
@@ -172,10 +174,12 @@ class TaskManager:
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
_id_list = []
|
||||
for t in new_tasks:
|
||||
self.insert_task_def(t, task_pool)
|
||||
insert_result = self.insert_task_def(t, task_pool)
|
||||
_id_list.append(insert_result.inserted_id)
|
||||
|
||||
return len(new_tasks)
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
@@ -248,9 +252,9 @@ class TaskManager:
|
||||
for t in task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, task, task_pool=None):
|
||||
def re_query(self, _id, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
return task_pool.find_one({"_id": ObjectId(task["_id"])})
|
||||
return task_pool.find_one({"_id": ObjectId(_id)})
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
|
||||
Reference in New Issue
Block a user