diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 5e62a141c..e901bc252 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: model.fit(dataset) recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) - R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype + R.save_objects(**{"task": task_config}) # keep the original format and datatype # generate records: prediction, backtest, and analysis records = task_config.get("record", []) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 2e4746f59..c022e6e76 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -2,6 +2,7 @@ from qlib.workflow import R import pandas as pd from typing import Union from typing import Callable + from qlib import get_module_logger @@ -17,13 +18,13 @@ class TaskCollector: def list_recorders(self, rec_filter_func=None, task_filter_func=None, only_finished=True, only_have_task=False): """ - Return a dict of {rid:recorder} by recorder filter and task filter. It is not necessary to use those filter. - If you don't train with "task_train", then there is no "task.pkl" which includes the task config. - If there is a "task.pkl", then it will become rec.task which can be get simply. + Return a dict of {rid:Recorder} by recorder filter and task filter. It is not necessary to use those filter. + If you don't train with "task_train", then there is no "task" which includes the task config. + If there is a "task", then it will become rec.task which can be get simply. Parameters ---------- - rec_filter_func : Callable[[MLflowRecorder], bool], optional + rec_filter_func : Callable[[Recorder], bool], optional judge whether you need this recorder, by default None task_filter_func : Callable[[dict], bool], optional judge whether you need this task, by default None @@ -35,30 +36,27 @@ class TaskCollector: Returns ------- dict - a dict of {rid:recorder} + a dict of {rid:Recorder} Raises ------ OSError - if you use a task filter, but there is no "task.pkl" which includes the task config + if you use a task filter, but there is no "task" which includes the task config """ recs = self.exp.list_recorders() - # return all recorders if the filter is None and you don't need task - if rec_filter_func==None and task_filter_func==None and only_have_task==False: - return recs recs_flt = {} + if task_filter_func is not None: + only_have_task = True for rid, rec in recs.items(): if (only_finished and rec.status == rec.STATUS_FI) or only_finished==False: if rec_filter_func is None or rec_filter_func(rec): task = None try: - task = rec.load_object("task.pkl") + task = rec.load_object("task") except OSError: - if task_filter_func is not None: - raise OSError('Can not find "task.pkl" in your records, have you train with "task_train" method in qlib.model.trainer?') + pass if task is None and only_have_task: continue - if task_filter_func is None or task_filter_func(task): rec.task = task recs_flt[rid] = rec @@ -68,7 +66,7 @@ class TaskCollector: def collect_predictions( self, get_key_func, - filter_func=None, + task_filter_func=None, ): """ @@ -85,7 +83,7 @@ class TaskCollector: dict the dict of predictions """ - recs_flt = self.list_recorders(task_filter_func=filter_func) + recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True) # group recs_group = {} @@ -108,11 +106,14 @@ class TaskCollector: def collect_latest_records( self, - filter_func=None, + task_filter_func=None, ): - recs_flt = self.list_recorders(task_filter_func=filter_func,only_have_task=True) - - max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values()) + recs_flt = self.list_recorders(task_filter_func=task_filter_func,only_have_task=True) + + if len(recs_flt) == 0: + self.logger.warning("Can not collect any recorders...") + return None, None + max_test = max(rec.task['dataset']['kwargs']['segments']['test'] for rec in recs_flt.values()) latest_record = {} for rid, rec in recs_flt.items(): @@ -120,52 +121,5 @@ class TaskCollector: latest_record[rid] = rec self.logger.info(f"Collect {len(latest_record)} latest records in {self.exp_name}") - return latest_record - - - -class RollingCollector: - """ - Rolling Models Ensemble based on (R)ecord - - This shares nothing with Ensemble - """ - - # TODO: speed up this class - def __init__(self, get_key_func, flt_func=None): - self.get_key_func = get_key_func # get the key of a task based on task config - self.flt_func = flt_func # determine whether a task can be retained based on task config - - def __call__(self, exp_name) -> Union[pd.Series, dict]: - # TODO; - # Should we split the scripts into several sub functions? - exp = R.get_exp(experiment_name=exp_name) - - # filter records - recs = exp.list_recorders() - - recs_flt = {} - for rid, rec in tqdm(recs.items(), desc="Loading data"): - params = rec.load_object("task.pkl") - if rec.status == rec.STATUS_FI: - if self.flt_func is None or self.flt_func(params): - rec.params = params - recs_flt[rid] = rec - - # group - recs_group = {} - for _, rec in recs_flt.items(): - params = rec.params - group_key = self.get_key_func(params) - recs_group.setdefault(group_key, []).append(rec) - - # reduce group - reduce_group = {} - for k, rec_l in recs_group.items(): - pred_l = [] - for rec in rec_l: - pred_l.append(rec.load_object("pred.pkl").iloc[:, 0]) - pred = pd.concat(pred_l).sort_index() - reduce_group[k] = pred - - return reduce_group + return latest_record, max_test + \ No newline at end of file diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index f27d02594..22b5430cc 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -10,10 +10,8 @@ A task consists of 3 parts from bson.binary import Binary import pickle from pymongo.errors import InvalidDocument -from fire import Fire from bson.objectid import ObjectId from contextlib import contextmanager -from loguru import logger from tqdm.cli import tqdm import time import concurrent @@ -21,7 +19,7 @@ import pymongo from qlib.config import C from .utils import get_mongodb from qlib import auto_init - +from qlib import get_module_logger class TaskManager: """TaskManager @@ -62,6 +60,7 @@ class TaskManager: """ self.mdb = get_mongodb() self.task_pool = task_pool + self.logger = get_module_logger("TaskManager") def list(self): return self.mdb.list_collection_names() @@ -210,9 +209,9 @@ class TaskManager: yield task except Exception: if task is not None: - logger.info("Returning task before raising error") + self.logger.info("Returning task before raising error") self.return_task(task) - logger.info("Task returned") + self.logger.info("Task returned") raise def task_fetcher_iter(self, query={}, task_pool=None): @@ -352,7 +351,7 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs): with tm.safe_fetch_task() as task: if task is None: break - logger.info(task["def"]) + get_module_logger("run_task").info(task["def"]) if force_release: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: res = executor.submit(task_func, task["def"], *args, **kwargs).result() diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index 5127a87da..9f1cc0a29 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union,List from qlib.workflow import R from tqdm.auto import tqdm from qlib.data import D @@ -7,8 +7,10 @@ from qlib.utils import init_instance_by_config from qlib import get_module_logger from qlib.workflow import R from qlib.model.trainer import task_train +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.collect import TaskCollector -class ModelUpdater: +class ModelUpdater(TaskCollector): """ The model updater to re-train model or update predictions """ @@ -29,58 +31,59 @@ class ModelUpdater: self.exp = R.get_exp(experiment_name=experiment_name) self.logger = get_module_logger("ModelUpdater") - def set_online_model(self, rid: str): + def set_online_model(self, recorder: Union[str,Recorder]): """online model will be identified at the tags of the record Parameters ---------- - rid : str - the id of a record + recorder: Union[str,Recorder] + the id of a Recorder or the Recorder instance """ - rec = self.exp.get_recorder(recorder_id=rid) - rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_TRUE}) + if isinstance(recorder,str): + recorder = self.exp.get_recorder(recorder_id=recorder) + recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_TRUE}) - def cancel_online_model(self, rid: str): - rec = self.exp.get_recorder(recorder_id=rid) - rec.set_tags(**{self.ONLINE_TAG: self.ONLINE_TAG_FALSE}) + def cancel_online_model(self, recorder: Union[str,Recorder]): + if isinstance(recorder,str): + recorder = self.exp.get_recorder(recorder_id=recorder) + recorder.set_tags(**{ModelUpdater.ONLINE_TAG: ModelUpdater.ONLINE_TAG_FALSE}) def cancel_all_online_model(self): recs = self.exp.list_recorders() for rid, rec in recs.items(): - self.cancel_online_model(rid) + self.cancel_online_model(rec) - def reset_online_model(self, rids: Union[str, list]): + def reset_online_model(self, recorders: List[Union[str,Recorder]]): """cancel all online model and reset the given model to online model Parameters ---------- - rids : Union[str, list] - the name of a record or the list of the name of records + recorders: List[Union[str,Recorder]] + the list of the id of a Recorder or the Recorder instance """ self.cancel_all_online_model() - if isinstance(rids, str): - rids = [rids] - for rid in rids: - self.set_online_model(rid) + for rec_or_rid in recorders: + self.set_online_model(rec_or_rid) - def update_pred(self, rid: str): + def update_pred(self, recorder: Union[str,Recorder]): """update predictions to the latest day in Calendar based on rid Parameters ---------- - rid : str - the id of the record + recorder: Union[str,Recorder] + the id of a Recorder or the Recorder instance """ - rec = self.exp.get_recorder(recorder_id=rid) - old_pred = rec.load_object("pred.pkl") + if isinstance(recorder,str): + recorder = self.exp.get_recorder(recorder_id=recorder) + old_pred = recorder.load_object("pred.pkl") last_end = old_pred.index.get_level_values("datetime").max() - task_config = rec.load_object("task.pkl") + task_config = recorder.load_object("task") # recorder.task # updated to the latest trading day cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None) if len(cal) == 0: - self.logger.info(f"All prediction in {rid} of {self.exp_name} are latest. No need to update.") + self.logger.info(f"The prediction in {recorder.info['id']} of {self.exp_name} are latest. No need to update.") return start_time, end_time = cal[0], cal[-1] @@ -89,32 +92,32 @@ class ModelUpdater: dataset = init_instance_by_config(task_config["dataset"]) - model = rec.load_object("params.pkl") + model = recorder.load_object("params.pkl") new_pred = model.predict(dataset) cb_pred = pd.concat([old_pred, new_pred.to_frame("score")], axis=0) cb_pred = cb_pred.sort_index() - rec.save_objects(**{"pred.pkl": cb_pred}) + recorder.save_objects(**{"pred.pkl": cb_pred}) - self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {rid} of {self.exp_name}.") + self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {recorder.info['id']} of {self.exp_name}.") - def update_all_pred(self, filter_func=None): + def update_all_pred(self, rec_filter_func=None): """update all predictions in this experiment after filter. An example of filter function: .. code-block:: python - def record_filter(record): - task_config = record.load_object("task.pkl") + def rec_filter_func(recorder): + task_config = recorder.load_object("task") if task_config["model"]["class"]=="LGBModel": return True return False Parameters ---------- - filter_func : function, optional + rec_filter_func : Callable[[Recorder], bool], optional the filter function to decide whether this record will be updated, by default None Returns @@ -123,20 +126,14 @@ class ModelUpdater: the count of updated record """ - cnt = 0 - recs = self.exp.list_recorders() + recs = self.list_recorders(rec_filter_func=rec_filter_func,only_have_task=True) for rid, rec in recs.items(): - if rec.status == rec.STATUS_FI: - if filter_func != None and filter_func(rec) == False: - # records that should be filtered out - continue - self.update_pred(rid) - cnt += 1 - return cnt + self.update_pred(rec) + return len(recs) - def online_filter(self, record): - tags = record.list_tags() - if tags.get(self.ONLINE_TAG, self.ONLINE_TAG_FALSE) == self.ONLINE_TAG_TRUE: + def online_filter(self, recorder): + tags = recorder.list_tags() + if tags.get(ModelUpdater.ONLINE_TAG, ModelUpdater.ONLINE_TAG_FALSE) == ModelUpdater.ONLINE_TAG_TRUE: return True return False @@ -151,11 +148,7 @@ class ModelUpdater: Returns ------- dict - {rid : record of the online model} + {rid : recorder of the online model} """ - recs = self.exp.list_recorders() - online_rec = {} - for rid, rec in recs.items(): - if self.online_filter(rec): - online_rec[rid] = rec - return online_rec + + return self.list_recorders(rec_filter_func=self.online_filter) diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 63563e2f6..9445d3c68 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -50,7 +50,6 @@ class TimeAdjuster: if idx >= len(self.cals): return None return self.cals[idx] - def max(self): """ (Deprecated) @@ -86,6 +85,9 @@ class TimeAdjuster: raise NotImplementedError(f"This type of input is not supported") return idx + def cal_interval(self, time_point_A, time_point_B): + return self.align_idx(time_point_A) - self.align_idx(time_point_B) + def align_time(self, time_point, tp_type="start"): """ Align time_point to trade date of calendar