diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index e5de1ef60..75d360fa1 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -73,7 +73,7 @@ def reset(task_pool, exp_name): print("========== reset ==========") TaskManager(task_pool=task_pool).remove() - exp, _ = R.get_exp(experiment_name=exp_name) + exp = R.get_exp(experiment_name=exp_name) for rid in exp.list_recorders(): exp.delete_recorder(rid) @@ -115,7 +115,7 @@ def task_collecting(task_pool, exp_name): task_config = recorder.load_object("task") model_key = task_config["model"]["class"] rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, model_key, rolling_key + return model_key, rolling_key def my_filter(recorder): # only choose the results of "LGBModel" diff --git a/examples/online_svr/task_manager_rolling_with_updating.py b/examples/online_svr/task_manager_rolling_with_updating.py index 4e9fdd336..fff470c86 100644 --- a/examples/online_svr/task_manager_rolling_with_updating.py +++ b/examples/online_svr/task_manager_rolling_with_updating.py @@ -117,7 +117,7 @@ def task_collecting(): task_config = recorder.load_object("task") model_key = task_config["model"]["class"] rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, model_key, rolling_key + return model_key, rolling_key def my_filter(recorder): # only choose the results of "LGBModel" @@ -136,7 +136,7 @@ def task_collecting(): def reset(): print("========== reset ==========") task_manager.remove() - exp, _ = R.get_exp(experiment_name=exp_name) + exp = R.get_exp(experiment_name=exp_name) for rid in exp.list_recorders(): exp.delete_recorder(rid) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index a7a6ce4bb..91b713ef8 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -7,8 +7,7 @@ from qlib.workflow.task.utils import list_recorders class Collector: - """The collector to collect different results based on experiment backend and ensemble method - """ + """The collector to collect different results based on experiment backend and ensemble method""" def collect(self, ensemble, get_group_key_func, *args, **kwargs): """To collect the results, we need to get the experiment record firstly and divided them into @@ -23,7 +22,7 @@ class Collector: class RecorderCollector(Collector): - def __init__(self, exp_name, artifacts_path = {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None: + def __init__(self, exp_name, artifacts_path={"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}) -> None: """init RecorderCollector Args: @@ -48,14 +47,14 @@ class RecorderCollector(Collector): """ if artifacts_key is None: artifacts_key = self.artifacts_path.keys() - + if isinstance(artifacts_key, str): artifacts_key = [artifacts_key] # prepare_ensemble ensemble_dict = {} for key in artifacts_key: - ensemble_dict.setdefault(key,{}) + ensemble_dict.setdefault(key, {}) # filter records recs_flt = list_recorders(self.exp_name, rec_filter_func) for _, rec in recs_flt.items(): @@ -64,7 +63,6 @@ class RecorderCollector(Collector): artifact = rec.load_object(self.artifacts_path[key]) ensemble_dict[key][group_key] = artifact - if isinstance(artifacts_key, str): return ensemble(ensemble_dict[artifacts_key]) @@ -72,4 +70,3 @@ class RecorderCollector(Collector): for key in artifacts_key: collect_dict[key] = ensemble(ensemble_dict[key]) return collect_dict - \ No newline at end of file diff --git a/qlib/workflow/task/ensemble.py b/qlib/workflow/task/ensemble.py index 649ce9415..dca0dee3e 100644 --- a/qlib/workflow/task/ensemble.py +++ b/qlib/workflow/task/ensemble.py @@ -7,12 +7,10 @@ from qlib.workflow.task.utils import list_recorders from typing import Dict - class Ensemble: - """Merge the objects in an Ensemble. - """ + """Merge the objects in an Ensemble.""" - def __init__(self, merge_func = None, get_grouped_key_func = None) -> None: + def __init__(self, merge_func=None, get_grouped_key_func=None) -> None: """init Ensemble Args: @@ -26,7 +24,7 @@ class Ensemble: self.get_grouped_key_func = get_grouped_key_func def merge_func(self, group_inner_dict): - """Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object}, + """Given a group_inner_dict such as {Rollinga_b: object, Rollingb_c: object}, merge it to object Args: @@ -34,10 +32,10 @@ class Ensemble: """ raise NotImplementedError(f"Please implement the `merge_func` method.") - + def get_grouped_key_func(self, group_key): """Given a group_key and return the group_outer_key, group_inner_key. - + For example: (A,B,Rolling) -> (A,B):Rolling (A,B) -> C:(A,B) @@ -135,10 +133,10 @@ class Ensemble: grouped_dict = self.group(group_dict) return self.reduce(grouped_dict) -class RollingEnsemble(Ensemble): - """A specific implementation of Ensemble for Rolling. - """ +class RollingEnsemble(Ensemble): + """A specific implementation of Ensemble for Rolling.""" + def merge_func(self, group_inner_dict): """merge group_inner_dict by datetime. @@ -155,7 +153,7 @@ class RollingEnsemble(Ensemble): artifact = artifact[~artifact.index.duplicated(keep="last")] artifact = artifact.sort_index() return artifact - + def get_grouped_key_func(self, group_key): """The final axis of group_key must be the Rolling key. When `collect`, get_group_key_func can add the statement below. @@ -174,7 +172,5 @@ class RollingEnsemble(Ensemble): Returns: tuple or str, tuple or str: group_outer_key, group_inner_key """ - assert len(group_key)>=2 + assert len(group_key) >= 2 return group_key[:-1], group_key[-1] - - diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 6e9fa6571..a62164207 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -174,7 +174,7 @@ class TaskManager: for t in new_tasks: self.insert_task_def(t, task_pool) - + return len(new_tasks) def fetch_task(self, query={}, task_pool=None): @@ -250,7 +250,7 @@ class TaskManager: def re_query(self, task, task_pool=None): task_pool = self._get_task_pool(task_pool) - return task_pool.find_one({"_id":ObjectId(task["_id"])}) + return task_pool.find_one({"_id": ObjectId(task["_id"])}) def commit_task_res(self, task, res, status=None, task_pool=None): task_pool = self._get_task_pool(task_pool) diff --git a/qlib/workflow/task/online.py b/qlib/workflow/task/online.py index d23fc88c8..f7ffbd18a 100644 --- a/qlib/workflow/task/online.py +++ b/qlib/workflow/task/online.py @@ -106,7 +106,9 @@ class RollingOnlineManager(OnlineManagerR): pass def prepare_tasks(self): - latest_records, max_test = self.list_latest_recorders(lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG) + latest_records, max_test = self.list_latest_recorders( + lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG + ) if max_test is None: self.logger.warn(f"No latest_recorders.") return diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index 43c304239..002f1128f 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -9,6 +9,7 @@ from qlib.workflow.recorder import Recorder from qlib.workflow.task.utils import list_recorders from qlib.data.dataset.handler import DataHandlerLP + class ModelUpdater: """ The model updater to update model results in new data. @@ -48,7 +49,7 @@ class ModelUpdater: dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}, segments=segments) return dataset - def update_pred(self, recorder: Recorder, frequency='day'): + def update_pred(self, recorder: Recorder, frequency="day"): """update predictions to the latest day in Calendar based on rid Parameters @@ -60,10 +61,10 @@ class ModelUpdater: last_end = old_pred.index.get_level_values("datetime").max() # updated to the latest trading day - if frequency=='day': + if frequency == "day": cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None) else: - raise NotImplementedError("Now Qlib only support update daily frequency prediction") + raise NotImplementedError("Now `ModelUpdater` only support update daily frequency prediction") if len(cal) == 0: self.logger.info( diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index b34b75306..b6287abc2 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -42,7 +42,7 @@ def list_recorders(experiment, rec_filter_func=None): dict: a dict {rid: recorder} after filtering. """ if isinstance(experiment, str): - experiment, _ = R.get_exp(experiment_name=experiment) + experiment = R.get_exp(experiment_name=experiment) recs = experiment.list_recorders() recs_flt = {} for rid, rec in recs.items():