diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
index 7ad5f4c6d..5cc7d3c2d 100644
--- a/qlib/data/dataset/__init__.py
+++ b/qlib/data/dataset/__init__.py
@@ -524,20 +524,19 @@ class TSDatasetH(DatasetH):
 
     def setup_data(self, **kwargs):
         super().setup_data(**kwargs)
+        # make sure the calendar is updated to latest when loading data from new config
         cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
-        cal = sorted(cal)
-        self.cal = cal
+        self.cal = sorted(cal)
 
-    def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
+    @staticmethod
+    def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
+        """Return `slc` with its start moved back by up to `step_len` entries of the sorted calendar `cal`."""
         # Dataset decide how to slice data(Get more data for timeseries).
         start, end = slc.start, slc.stop
-        start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
-        pad_start_idx = max(0, start_idx - self.step_len)
-        pad_start = self.cal[pad_start_idx]
-
-        # TSDatasetH will retrieve more data for complete
-        data = super()._prepare_seg(slice(pad_start, end), **kwargs)
-        return data
+        start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
+        pad_start_idx = max(0, start_idx - step_len)
+        pad_start = cal[pad_start_idx]
+        return slice(pad_start, end)
 
     def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
         """
@@ -547,12 +546,16 @@ class TSDatasetH(DatasetH):
         start, end = slc.start, slc.stop
         flt_col = kwargs.pop("flt_col", None)
         # TSDatasetH will retrieve more data for complete time-series
+
+        ext_slice = self._extend_slice(slc, self.cal, self.step_len)
+        data = super()._prepare_seg(ext_slice, **kwargs)
 
         flt_kwargs = deepcopy(kwargs)
         if flt_col is not None:
             flt_kwargs["col_set"] = flt_col
+            # Call the parent implementation: `self._prepare_seg` would extend the slice a second
+            # time and return a TSDataSampler, while the assert below reads `.columns` of raw data.
+            flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)
             assert len(flt_data.columns) == 1
         else:
             flt_data = None
diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py
index 075a1adb8..48b427a28 100644
--- a/qlib/utils/paral.py
+++ b/qlib/utils/paral.py
@@ -1,6 +1,11 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+import logging
+from functools import partial, wraps
+from queue import Queue
+from threading import Thread
+
 import pandas as pd
 from joblib import Parallel, delayed
 from joblib._parallel_backends import MultiprocessingBackend
@@ -46,3 +51,70 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
         return pd.concat(dfs, axis=axis).sort_index()
     else:
         return _naive_group_apply(df)
+
+
+class AsyncCaller:
+    """
+    Run callables one by one on a background thread so that the caller does not block.
+
+    Currently, it is used in MLflowRecorder to make functions like `log_params` async
+
+    NOTE:
+    - Return values of the submitted callables are discarded
+    - Exceptions raised by a submitted callable are logged and the worker keeps running
+    """
+
+    STOP_MARK = "__STOP"
+
+    def __init__(self) -> None:
+        self._q = Queue()
+        self._stop = False
+        self._t = Thread(target=self.run)
+        self._t.start()
+
+    def close(self):
+        """Ask the worker thread to exit once everything queued before this call has run."""
+        self._q.put(self.STOP_MARK)
+
+    def run(self):
+        """Worker loop: execute queued callables in FIFO order until `STOP_MARK` is received."""
+        while True:
+            data = self._q.get()
+            if data == self.STOP_MARK:
+                break
+            try:
+                data()
+            except Exception:
+                # An uncaught exception would end this thread, and every call queued
+                # afterwards would be dropped silently; log it and keep consuming.
+                logging.getLogger(__name__).exception("AsyncCaller: queued call failed")
+
+    def __call__(self, func, *args, **kwargs):
+        """Queue `func(*args, **kwargs)` for execution on the worker thread."""
+        self._q.put(partial(func, *args, **kwargs))
+
+    def wait(self, close=True):
+        """Block until the worker thread exits; with `close=True` the stop mark is queued first."""
+        if close:
+            self.close()
+        self._t.join()
+
+    @staticmethod
+    def async_dec(ac_attr):
+        """
+        Decorate a method so that it is dispatched through the callable stored in `self.<ac_attr>`.
+
+        If that attribute is missing or not callable, the method runs synchronously instead.
+        """
+
+        def decorator_func(func):
+            @wraps(func)
+            def wrapper(self, *args, **kwargs):
+                async_caller = getattr(self, ac_attr, None)
+                if callable(async_caller):
+                    return async_caller(func, self, *args, **kwargs)
+                return func(self, *args, **kwargs)
+
+            return wrapper
+
+        return decorator_func
diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py
index 13c4bc7a0..2fff37eaa 100644
--- a/qlib/workflow/recorder.py
+++ b/qlib/workflow/recorder.py
@@ -9,8 +9,9 @@ from pathlib import Path
 from datetime import datetime
 
 from qlib.utils.exceptions import LoadObjectError
+from qlib.utils.paral import AsyncCaller
 from ..utils.objm import FileManager
-from ..log import get_module_logger
+from ..log import TimeInspector, get_module_logger
 from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
 
 logger = get_module_logger("workflow", logging.INFO)
@@ -229,6 +230,7 @@ class MLflowRecorder(Recorder):
                 if mlflow_run.info.end_time is not None
                 else None
             )
+        self.async_log = None
 
     def __repr__(self):
         name = self.__class__.__name__
@@ -287,6 +289,10 @@ class MLflowRecorder(Recorder):
         self.status = Recorder.STATUS_R
         logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
 
+        # NOTE: making logging async.
+        # - This may cause delay when uploading results
+        # - The logging time may not be accurate
+        self.async_log = AsyncCaller()
         return run
 
     def end_run(self, status: str = Recorder.STATUS_S):
@@ -300,6 +306,12 @@ class MLflowRecorder(Recorder):
         self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         if self.status != Recorder.STATUS_S:
             self.status = status
+        if self.async_log is not None:
+            # `async_log` is None until a run has been started; flush the queued log calls, then
+            # drop the stopped caller so that later log calls fall back to the synchronous path.
+            with TimeInspector.logt("waiting `async_log`"):
+                self.async_log.wait()
+            self.async_log = None
 
     def save_objects(self, local_path=None, artifact_path=None, **kwargs):
         assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
@@ -345,14 +357,17 @@ class MLflowRecorder(Recorder):
         except Exception as e:
             raise LoadObjectError(message=str(e))
 
+    @AsyncCaller.async_dec(ac_attr="async_log")
     def log_params(self, **kwargs):
         for name, data in kwargs.items():
             self.client.log_param(self.id, name, data)
 
+    @AsyncCaller.async_dec(ac_attr="async_log")
     def log_metrics(self, step=None, **kwargs):
         for name, data in kwargs.items():
             self.client.log_metric(self.id, name, data, step=step)
 
+    @AsyncCaller.async_dec(ac_attr="async_log")
     def set_tags(self, **kwargs):
         for name, data in kwargs.items():
             self.client.set_tag(self.id, name, data)