From d96f7a67c60fc50e055dfa138b859c3641a38e14 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 23 Jun 2021 08:46:21 +0000 Subject: [PATCH] bug & docs fixed --- qlib/data/dataset/loader.py | 7 ++++++- qlib/log.py | 2 +- qlib/utils/serial.py | 6 +++--- qlib/workflow/online/utils.py | 5 +++-- qlib/workflow/recorder.py | 12 +++++++++--- 5 files changed, 22 insertions(+), 10 deletions(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 2ad110b89..b4c75e104 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -207,7 +207,12 @@ class StaticDataLoader(DataLoader): df = self._data.loc(axis=0)[:, instruments] if start_time is None and end_time is None: return df # NOTE: avoid copy by loc - return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)] + # pd.Timestamp(None) == NaT, use NaT as index can not fetch correct thing, so do not change None. + if start_time is not None: + start_time = pd.Timestamp(start_time) + if end_time is not None: + end_time = pd.Timestamp(end_time) + return df.loc[start_time:end_time] def _maybe_load_raw_data(self): if self._data is not None: diff --git a/qlib/log.py b/qlib/log.py index ad55e2200..0b8b04b48 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -70,7 +70,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge class TimeInspector: - timer_logger = get_module_logger("timer", level=logging.WARNING) + timer_logger = get_module_logger("timer", level=logging.INFO) time_marks = [] diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 263e632de..4189f8e61 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -92,16 +92,16 @@ class Serializable: @classmethod def load(cls, filepath): """ - Load the collector from a filepath. + Load the serializable class from a filepath. Args: filepath (str): the path of file Raises: - TypeError: the pickled file must be `Collector` + TypeError: the pickled file must be `type(cls)` Returns: - Collector: the instance of Collector + `type(cls)`: the instance of `type(cls)` """ with open(filepath, "rb") as f: object = cls.get_backend().load(f) diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 86763d5d6..83cacbb88 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -12,6 +12,7 @@ from qlib.data.dataset import TSDatasetH from qlib.log import get_module_logger from qlib.utils import get_cls_kwargs +from qlib.utils.exceptions import QlibException from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder from qlib.workflow.task.utils import list_recorders @@ -191,9 +192,9 @@ class OnlineToolR(OnlineTool): hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) try: updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref) - except OSError: + except QlibException as e: # skip the recorder without pred - self.logger.warn(f"Can't find `pred.pkl`, skip it.") + self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.") continue updater.update() diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 0c9abf731..e11287ca1 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -5,6 +5,9 @@ import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path from datetime import datetime + +from mlflow.exceptions import MlflowException +from qlib.utils.exceptions import QlibException from ..utils.objm import FileManager from ..log import get_module_logger @@ -308,9 +311,12 @@ class MLflowRecorder(Recorder): def load_object(self, name): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." - path = self.client.download_artifacts(self.id, name) - with Path(path).open("rb") as f: - return pickle.load(f) + try: + path = self.client.download_artifacts(self.id, name) + with Path(path).open("rb") as f: + return pickle.load(f) + except OSError as e: + raise QlibException(message=str(e)) def log_params(self, **kwargs): for name, data in kwargs.items():