mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Merge pull request #471 from Derek-Wds/main
Update Recorder Wrapper to prevent reinitialization
This commit is contained in:
12
qlib/utils/exceptions.py
Normal file
12
qlib/utils/exceptions.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Base exception class
|
||||
class QlibException(Exception):
|
||||
def __init__(self, message):
|
||||
super(QlibException, self).__init__(message)
|
||||
|
||||
|
||||
# Error type for reinitialization when starting an experiment
|
||||
class RecorderInitializationError(QlibException):
|
||||
pass
|
||||
@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
from ..utils.exceptions import RecorderInitializationError
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
@@ -525,14 +526,29 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
class RecorderWrapper(Wrapper):
|
||||
"""
|
||||
Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
|
||||
"""
|
||||
|
||||
def register(self, provider):
|
||||
if self._provider is not None:
|
||||
expm = getattr(self._provider, "exp_manager")
|
||||
if expm.active_experiment is not None:
|
||||
raise RecorderInitializationError(
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
|
||||
)
|
||||
self._provider = provider
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
R: QlibRecorderWrapper = RecorderWrapper()
|
||||
|
||||
Reference in New Issue
Block a user