diff --git a/qlib/utils/exceptions.py b/qlib/utils/exceptions.py new file mode 100644 index 000000000..69712172b --- /dev/null +++ b/qlib/utils/exceptions.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Base exception class +class QlibException(Exception): + def __init__(self, message): + super(QlibException, self).__init__(message) + + +# Error type for reinitialization when starting an experiment +class RecorderInitializationError(QlibException): + pass diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 2b2535edc..98b2c9925 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -7,6 +7,7 @@ from .expm import MLflowExpManager from .exp import Experiment from .recorder import Recorder from ..utils import Wrapper +from ..utils.exceptions import RecorderInitializationError class QlibRecorder: @@ -525,14 +526,29 @@ class QlibRecorder: self.get_exp().get_recorder().set_tags(**kwargs) +class RecorderWrapper(Wrapper): + """ + Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment. + """ + + def register(self, provider): + if self._provider is not None: + expm = getattr(self._provider, "exp_manager") + if expm.active_experiment is not None: + raise RecorderInitializationError( + "Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified." + ) + self._provider = provider + + import sys if sys.version_info >= (3, 9): from typing import Annotated - QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper] + QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper] else: QlibRecorderWrapper = QlibRecorder # global record -R: QlibRecorderWrapper = Wrapper() +R: QlibRecorderWrapper = RecorderWrapper()