mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Simplify TSDataset and async recorder
This commit is contained in:
@@ -524,20 +524,18 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
# make sure the calendar is updated to latest when loading data from new config
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
self.cal = sorted(cal)
|
||||
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
@staticmethod
|
||||
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - self.step_len)
|
||||
pad_start = self.cal[pad_start_idx]
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - step_len)
|
||||
pad_start = cal[pad_start_idx]
|
||||
return slice(pad_start, end)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
@@ -547,12 +545,14 @@ class TSDatasetH(DatasetH):
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
|
||||
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
|
||||
data = super()._prepare_seg(ext_slice, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
|
||||
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
return pd.concat(dfs, axis=axis).sort_index()
|
||||
else:
|
||||
return _naive_group_apply(df)
|
||||
|
||||
|
||||
class AsyncCaller:
|
||||
"""
|
||||
This AsyncCaller tries to make it easier to async call
|
||||
|
||||
Currently, it is used in MLflowRecorder to make functions like `log_params` async
|
||||
|
||||
NOTE:
|
||||
- This caller didn't consider the return value
|
||||
"""
|
||||
|
||||
STOP_MARK = "__STOP"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._q = Queue()
|
||||
self._stop = False
|
||||
self._t = Thread(target=self.run)
|
||||
self._t.start()
|
||||
|
||||
def close(self):
|
||||
self._q.put(self.STOP_MARK)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
data = self._q.get()
|
||||
if data == self.STOP_MARK:
|
||||
break
|
||||
else:
|
||||
data()
|
||||
|
||||
def __call__(self, func, *args, **kwargs):
|
||||
self._q.put(partial(func, *args, **kwargs))
|
||||
|
||||
def wait(self, close=True):
|
||||
if close:
|
||||
self.close()
|
||||
self._t.join()
|
||||
|
||||
@staticmethod
|
||||
def async_dec(ac_attr):
|
||||
def decorator_func(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if isinstance(getattr(self, ac_attr, None), Callable):
|
||||
return getattr(self, ac_attr)(func, self, *args, **kwargs)
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator_func
|
||||
|
||||
@@ -9,8 +9,9 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.utils.paral import AsyncCaller
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
from ..log import TimeInspector, get_module_logger
|
||||
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
@@ -229,6 +230,7 @@ class MLflowRecorder(Recorder):
|
||||
if mlflow_run.info.end_time is not None
|
||||
else None
|
||||
)
|
||||
self.async_log = None
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
@@ -287,6 +289,10 @@ class MLflowRecorder(Recorder):
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
|
||||
# NOTE: making logging async.
|
||||
# - This may cause delay when uploading results
|
||||
# - The logging time may not be accurate
|
||||
self.async_log = AsyncCaller()
|
||||
return run
|
||||
|
||||
def end_run(self, status: str = Recorder.STATUS_S):
|
||||
@@ -300,6 +306,8 @@ class MLflowRecorder(Recorder):
|
||||
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
with TimeInspector.logt("waiting `async_log`"):
|
||||
self.async_log.wait()
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
@@ -345,14 +353,17 @@ class MLflowRecorder(Recorder):
|
||||
except Exception as e:
|
||||
raise LoadObjectError(message=str(e))
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_params(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_param(self.id, name, data)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.set_tag(self.id, name, data)
|
||||
|
||||
Reference in New Issue
Block a user