1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Simplify TSDataset and async recorder

This commit is contained in:
Young
2021-11-02 11:03:23 +08:00
parent 7a884fa9f2
commit 2593185721
3 changed files with 82 additions and 14 deletions

View File

@@ -524,20 +524,18 @@ class TSDatasetH(DatasetH):
def setup_data(self, **kwargs):
    """
    Set up the dataset and rebuild the sorted datetime calendar ``self.cal``.

    Parameters
    ----------
    **kwargs :
        Forwarded unchanged to the parent class ``setup_data``.
    """
    super().setup_data(**kwargs)
    # make sure the calendar is updated to latest when loading data from new config
    cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
    # Sort once and store once; the sorted list is what the bisect-based padding relies on.
    self.cal = sorted(cal)
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
    """
    Fetch the raw segment for ``slc`` with its start padded backwards.

    Kept as a thin backward-compatible wrapper: the padding itself lives in
    `_extend_slice`, and the widened slice is handed to the parent class
    ``_prepare_seg``.

    Parameters
    ----------
    slc : slice
        Requested segment; only ``start`` and ``stop`` are used.
    **kwargs :
        Forwarded unchanged to the parent class ``_prepare_seg``.
    """
    # TSDatasetH will retrieve more data for complete time-series
    ext_slice = self._extend_slice(slc, self.cal, self.step_len)
    return super()._prepare_seg(ext_slice, **kwargs)

@staticmethod
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
    """
    Move the start of ``slc`` back by ``step_len`` entries of the calendar ``cal``.

    Parameters
    ----------
    slc : slice
        Requested segment; only ``start`` and ``stop`` are used.
    cal : list
        Calendar searched with ``bisect``; it must already be sorted.
    step_len : int
        Number of calendar entries to move the start backwards.

    Returns
    -------
    slice
        ``slice(padded_start, slc.stop)`` where ``padded_start`` is an element of ``cal``.
    """
    # Dataset decide how to slice data(Get more data for timeseries).
    start, end = slc.start, slc.stop
    start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
    # Clamp at 0 so a segment near the beginning of the calendar never wraps around.
    pad_start_idx = max(0, start_idx - step_len)
    pad_start = cal[pad_start_idx]
    return slice(pad_start, end)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
@@ -547,12 +545,14 @@ class TSDatasetH(DatasetH):
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete time-series
data = self._prepare_raw_seg(slc, **kwargs)
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
data = super()._prepare_seg(ext_slice, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None

View File

@@ -1,9 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from functools import partial
from threading import Thread
from typing import Callable
from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
class ParallelExt(Parallel):
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
return pd.concat(dfs, axis=axis).sort_index()
else:
return _naive_group_apply(df)
class AsyncCaller:
    """
    This AsyncCaller tries to make it easier to async call

    Calls submitted through ``__call__`` are put on a queue and executed one by one, in
    submission order, by a single worker thread that is started in ``__init__``.

    Currently, it is used in MLflowRecorder to make functions like `log_params` async

    NOTE:
    - This caller didn't consider the return value
    - An exception raised by a queued call is not caught in `run`, so it ends the worker thread
    """

    # Sentinel put on the queue to tell the worker thread to exit its loop.
    STOP_MARK = "__STOP"

    def __init__(self) -> None:
        self._q = Queue()
        self._stop = False  # not read anywhere in this class; kept so the attribute still exists
        self._t = Thread(target=self.run)
        self._t.start()

    def close(self):
        """Ask the worker thread to stop once everything queued before this call has run."""
        self._q.put(self.STOP_MARK)

    def run(self):
        """Worker loop: execute queued callables until the stop mark is received."""
        while True:
            data = self._q.get()
            if data == self.STOP_MARK:
                break
            data()

    def __call__(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` for execution on the worker thread; returns None."""
        self._q.put(partial(func, *args, **kwargs))

    def wait(self, close=True):
        """
        Block until the worker thread has finished.

        Parameters
        ----------
        close : bool
            When True (default) the stop mark is queued first; when False the caller must
            have called `close` already, otherwise the join does not return.
        """
        if close:
            self.close()
        self._t.join()

    @staticmethod
    def async_dec(ac_attr):
        """
        Build a method decorator that routes calls through an `AsyncCaller` stored on the instance.

        Parameters
        ----------
        ac_attr : str
            Name of the instance attribute looked up at call time. If the attribute is
            callable, the decorated method is submitted to it; otherwise (missing or
            not callable, e.g. None) the method runs synchronously.
        """
        # Local import keeps this block self-contained; the module only imports `partial`.
        from functools import wraps

        def decorator_func(func):
            @wraps(func)  # preserve __name__/__doc__ of the decorated method
            def wrapper(self, *args, **kwargs):
                caller = getattr(self, ac_attr, None)
                if callable(caller):
                    return caller(func, self, *args, **kwargs)
                return func(self, *args, **kwargs)

            return wrapper

        return decorator_func

View File

@@ -9,8 +9,9 @@ from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from qlib.utils.paral import AsyncCaller
from ..utils.objm import FileManager
from ..log import get_module_logger
from ..log import TimeInspector, get_module_logger
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
logger = get_module_logger("workflow", logging.INFO)
@@ -229,6 +230,7 @@ class MLflowRecorder(Recorder):
if mlflow_run.info.end_time is not None
else None
)
self.async_log = None
def __repr__(self):
name = self.__class__.__name__
@@ -287,6 +289,10 @@ class MLflowRecorder(Recorder):
self.status = Recorder.STATUS_R
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
# NOTE: making logging async.
# - This may cause delay when uploading results
# - The logging time may not be accurate
self.async_log = AsyncCaller()
return run
def end_run(self, status: str = Recorder.STATUS_S):
@@ -300,6 +306,8 @@ class MLflowRecorder(Recorder):
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.status != Recorder.STATUS_S:
self.status = status
with TimeInspector.logt("waiting `async_log`"):
self.async_log.wait()
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
@@ -345,14 +353,17 @@ class MLflowRecorder(Recorder):
except Exception as e:
raise LoadObjectError(message=str(e))
@AsyncCaller.async_dec(ac_attr="async_log")
def log_params(self, **kwargs):
    """
    Log every keyword argument as a parameter of this recorder.

    Each ``name=value`` pair is passed to ``self.client.log_param`` together with ``self.id``.
    The decorator submits the call to ``self.async_log`` when that attribute is callable;
    otherwise the method runs directly.
    """
    for param_name in kwargs:
        self.client.log_param(self.id, param_name, kwargs[param_name])
@AsyncCaller.async_dec(ac_attr="async_log")
def log_metrics(self, step=None, **kwargs):
    """
    Log every keyword argument as a metric of this recorder.

    Parameters
    ----------
    step :
        Passed through as the ``step`` keyword of ``self.client.log_metric`` for every metric.
    **kwargs :
        ``name=value`` pairs; each is passed to ``self.client.log_metric`` with ``self.id``.

    The decorator submits the call to ``self.async_log`` when that attribute is callable;
    otherwise the method runs directly.
    """
    for metric_name, metric_value in kwargs.items():
        self.client.log_metric(self.id, metric_name, metric_value, step=step)
@AsyncCaller.async_dec(ac_attr="async_log")
def set_tags(self, **kwargs):
    """
    Set every keyword argument as a tag of this recorder.

    Each ``name=value`` pair is passed to ``self.client.set_tag`` together with ``self.id``.
    The decorator submits the call to ``self.async_log`` when that attribute is callable;
    otherwise the method runs directly.
    """
    for tag_name, tag_value in kwargs.items():
        self.client.set_tag(self.id, tag_name, tag_value)