mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix recorder related bugs
This commit is contained in:
@@ -83,11 +83,9 @@ def init(default_conf="client", **kwargs):
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
# set up QlibRecorder
|
||||
uri = C["exp_uri"]
|
||||
# exp manager module
|
||||
module = get_module_by_module_path("qlib.workflow.expm")
|
||||
exp_manager = init_instance_by_config(C["exp_manager"], module)
|
||||
qr = QlibRecorder(exp_manager, uri)
|
||||
qr = QlibRecorder(exp_manager)
|
||||
R.register(qr)
|
||||
|
||||
|
||||
|
||||
@@ -43,26 +43,26 @@ class QlibRecorder:
|
||||
def get_uri(self):
|
||||
return self.exp_manager.get_uri()
|
||||
|
||||
def get_recorder(self):
|
||||
return self.exp_manager.active_recorder
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
return self.exp_manager.active_experiment.get_recorder(recorder_id, recorder_name)
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs)
|
||||
self.exp_manager.active_experiment.active_recorder.save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name):
|
||||
return self.exp_manager.active_recorder.load_object(name)
|
||||
return self.exp_manager.active_experiment.active_recorder.load_object(name)
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
self.exp_manager.active_recorder.log_params(**kwargs)
|
||||
self.exp_manager.active_experiment.active_recorder.log_params(**kwargs)
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
|
||||
self.exp_manager.active_experiment.active_recorder.log_metrics(step, **kwargs)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
self.exp_manager.active_recorder.set_tags(**kwargs)
|
||||
self.exp_manager.active_experiment.active_recorder.set_tags(**kwargs)
|
||||
|
||||
def delete_tag(self, *key):
|
||||
self.exp_manager.active_recorder.delete_tag(*key)
|
||||
self.exp_manager.active_experiment.active_recorder.delete_tag(*key)
|
||||
|
||||
|
||||
# global record
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
import mlflow
|
||||
from pathlib import Path
|
||||
from .recorder import MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
class Experiment:
|
||||
"""
|
||||
|
||||
@@ -184,10 +184,12 @@ class MLflowExpManager(ExpManager):
|
||||
else:
|
||||
if experiment_name not in self.experiments:
|
||||
if mlflow.get_experiment_by_name(experiment_name) is not None:
|
||||
raise Exception(
|
||||
"The experiment has already been created before. Please pick another name or delete the files under uri."
|
||||
logger.info(
|
||||
"The experiment has already been created before. Try to resume the experiment..."
|
||||
)
|
||||
experiment_id = mlflow.create_experiment(experiment_name)
|
||||
experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
|
||||
else:
|
||||
experiment_id = mlflow.create_experiment(experiment_name)
|
||||
else:
|
||||
experiment_id = self.experiments[experiment_name].id
|
||||
experiment = self.experiments[experiment_name]
|
||||
|
||||
@@ -76,7 +76,7 @@ class SignalRecord(RecordTemp):
|
||||
def generate(self, **kwargs):
|
||||
# generate prediciton
|
||||
pred = self.model.predict(self.dataset)
|
||||
self.recorder.save_object(pred, "pred.pkl")
|
||||
self.recorder.save_objects(data=pred, name="pred.pkl")
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
@@ -132,9 +132,9 @@ class PortAnaRecord(SignalRecord):
|
||||
assert super().check(), "Make sure the parent process is completed and store the data properly."
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG)
|
||||
self.recorder.save_object(report_normal, "report_normal.pkl", self.artifact_path)
|
||||
self.recorder.save_object(positions_normal, "positions_normal.pkl", self.artifact_path)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
self.recorder.save_objects(data=report_normal, name="report_normal.pkl", artifact_path=self.artifact_path)
|
||||
self.recorder.save_objects(data=positions_normal, name="positions_normal.pkl", artifact_path=self.artifact_path)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
@@ -143,7 +143,7 @@ class PortAnaRecord(SignalRecord):
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
self.recorder.save_object(pred, "port_analysis.pkl", self.artifact_path)
|
||||
self.recorder.save_objects(data=analysis_df, name="port_analysis.pkl", artifact_path=self.artifact_path)
|
||||
|
||||
def load(self):
|
||||
# try to load the saved object
|
||||
|
||||
@@ -201,7 +201,7 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def load_object(self, name):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
path = client.download_artifacts(self.recorder_id, name)
|
||||
path = client.download_artifacts(self.id, name)
|
||||
try:
|
||||
with Path(path).open("rb") as f:
|
||||
f.seek(0)
|
||||
@@ -242,5 +242,5 @@ class MLflowRecorder(Recorder):
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
artifacts = client.list_artifacts(self.id, path)
|
||||
artifacts = client.list_artifacts(self.id, artifact_path)
|
||||
return artifacts
|
||||
|
||||
Reference in New Issue
Block a user