1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 09:31:18 +08:00
Files
qlib/qlib/workflow/recorder.py
2020-12-23 21:14:53 +08:00

307 lines
10 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from ..utils.objm import FileManager
from ..log import get_module_logger
# Module-level logger for this file, obtained from qlib's `get_module_logger` with the name "workflow"
# and level "INFO" (used e.g. by `MLflowRecorder.start_run` to announce a started run).
logger = get_module_logger("workflow", "INFO")
class Recorder:
    """
    This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
    (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)

    The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.

    This base class only keeps the recorder's metadata (id, name, experiment id, start/end time and
    status). Every logging / storage method raises ``NotImplementedError`` and has to be provided by
    a concrete subclass.
    """

    # status type
    STATUS_S = "SCHEDULED"
    STATUS_R = "RUNNING"
    STATUS_FI = "FINISHED"
    STATUS_FA = "FAILED"

    def __init__(self, experiment_id, name):
        # `id` stays None until a concrete recorder assigns one (e.g. when its run is started).
        self.id = None
        self.name = name
        self.experiment_id = experiment_id
        self.start_time = None
        self.end_time = None
        self.status = Recorder.STATUS_S

    def __repr__(self):
        return str(self.info)

    def __str__(self):
        return str(self.info)

    @property
    def info(self):
        """
        A fresh dictionary snapshot of the recorder's metadata.

        Returns
        -------
        dict
            with the keys ``class``, ``id``, ``name``, ``experiment_id``, ``start_time``,
            ``end_time`` and ``status`` (in that order).
        """
        output = dict()
        output["class"] = "Recorder"
        output["id"] = self.id
        output["name"] = self.name
        output["experiment_id"] = self.experiment_id
        output["start_time"] = self.start_time
        output["end_time"] = self.end_time
        output["status"] = self.status
        return output

    def set_recorder_name(self, rname):
        """
        Rename the recorder.

        Parameters
        ----------
        rname : str
            the new name of the recorder.
        """
        # `name` is the attribute reported by `info` (and therefore by str/repr). Previously only
        # `recorder_name` was written, which nothing in this module reads, so the rename was lost.
        self.name = rname
        # kept for backward compatibility with callers that read `recorder_name`
        self.recorder_name = rname

    def save_objects(self, local_path=None, artifact_path=None, **kwargs):
        """
        Save objects such as prediction file or model checkpoints to the artifact URI. User
        can save object through keywords arguments (name:value).

        Parameters
        ----------
        local_path : str, optional
            if provided, then save the file or directory to the artifact URI.
        artifact_path : str, optional
            the relative path for the artifact to be stored in the URI.
        **kwargs
            name / object pairs to be saved as artifacts.
        """
        raise NotImplementedError("Please implement the `save_objects` method.")

    def load_object(self, name):
        """
        Load objects such as prediction file or model checkpoints.

        Parameters
        ----------
        name : str
            name of the file to be loaded.

        Returns
        -------
        The saved object.
        """
        raise NotImplementedError("Please implement the `load_object` method.")

    def start_run(self):
        """
        Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;
        otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)

        Returns
        -------
        An active running object (e.g. mlflow.ActiveRun object).
        """
        raise NotImplementedError("Please implement the `start_run` method.")

    def end_run(self, status=STATUS_S):
        """
        End an active Recorder.

        Parameters
        ----------
        status : str, optional
            the final status of the recorder; one of the ``STATUS_*`` constants. Accepted here so the
            base interface matches subclasses that take a status (default: ``STATUS_S``).
        """
        raise NotImplementedError("Please implement the `end_run` method.")

    def log_params(self, **kwargs):
        """
        Log a batch of params for the current run.

        Parameters
        ----------
        keyword arguments
            key, value pair to be logged as parameters.
        """
        raise NotImplementedError("Please implement the `log_params` method.")

    def log_metrics(self, step=None, **kwargs):
        """
        Log multiple metrics for the current run.

        Parameters
        ----------
        step : int, optional
            the step to associate with every metric logged in this call.
        keyword arguments
            key, value pair to be logged as metrics.
        """
        raise NotImplementedError("Please implement the `log_metrics` method.")

    def set_tags(self, **kwargs):
        """
        Log a batch of tags for the current run.

        Parameters
        ----------
        keyword arguments
            key, value pair to be logged as tags.
        """
        raise NotImplementedError("Please implement the `set_tags` method.")

    def delete_tags(self, *keys):
        """
        Delete some tags from a run.

        Parameters
        ----------
        keys : series of strs of the keys
            all the name of the tag to be deleted.
        """
        raise NotImplementedError("Please implement the `delete_tags` method.")

    def list_artifacts(self, artifact_path: str = None):
        """
        List all the artifacts of a recorder.

        Parameters
        ----------
        artifact_path : str, optional
            the relative path (inside the artifact URI) whose artifacts should be listed.

        Returns
        -------
        A list of artifacts information (name, path, etc.) that being stored.
        """
        raise NotImplementedError("Please implement the `list_artifacts` method.")

    def list_metrics(self):
        """
        List all the metrics of a recorder.

        Returns
        -------
        A dictionary of metrics that being stored.
        """
        raise NotImplementedError("Please implement the `list_metrics` method.")

    def list_params(self):
        """
        List all the params of a recorder.

        Returns
        -------
        A dictionary of params that being stored.
        """
        raise NotImplementedError("Please implement the `list_params` method.")

    def list_tags(self):
        """
        List all the tags of a recorder.

        Returns
        -------
        A dictionary of tags that being stored.
        """
        raise NotImplementedError("Please implement the `list_tags` method.")
class MLflowRecorder(Recorder):
    """
    Use mlflow to implement a Recorder.

    Because mlflow only logs artifacts from a file or a directory, objects passed as keyword
    arguments to `save_objects` are first pickled into a temporary directory and then uploaded
    from there; `load_object` downloads and unpickles them again.
    """

    def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
        """
        Parameters
        ----------
        experiment_id : str
            id of the experiment this recorder belongs to.
        uri : str
            mlflow tracking URI used by the client and by `start_run`.
        name : str, optional
            name of the recorder; used as fallback when `mlflow_run` carries no run-name tag.
        mlflow_run : mlflow.entities.run.Run, optional
            if given, the recorder's id, name, status and times are taken from this existing run.
        """
        super(MLflowRecorder, self).__init__(experiment_id, name)
        self._uri = uri
        self.artifact_uri = None
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
        # construct from mlflow run
        if mlflow_run is not None:
            assert isinstance(mlflow_run, mlflow.entities.run.Run), "Please input with a MLflow Run object."
            # A run created without a run name may lack the `mlflow.runName` tag; fall back to the
            # constructor's `name` instead of failing with KeyError.
            self.name = mlflow_run.data.tags.get("mlflow.runName", self.name)
            self.id = mlflow_run.info.run_id
            self.status = mlflow_run.info.status
            self.start_time = self._format_timestamp(mlflow_run.info.start_time)
            self.end_time = self._format_timestamp(mlflow_run.info.end_time)

    @staticmethod
    def _format_timestamp(timestamp_ms):
        """Convert a millisecond epoch timestamp to a local-time "%Y-%m-%d %H:%M:%S" string (None stays None)."""
        if timestamp_ms is None:
            return None
        return datetime.fromtimestamp(float(timestamp_ms) / 1000.0).strftime("%Y-%m-%d %H:%M:%S")

    def start_run(self):
        """
        Start (or, when `self.id` is already set, resume) the mlflow run backing this recorder.

        Note that this sets mlflow's tracking URI via `mlflow.set_tracking_uri`.

        Returns
        -------
        The object returned by `mlflow.start_run` (usable as a context manager).
        """
        # set the tracking uri
        mlflow.set_tracking_uri(self._uri)
        # start the run; keyword arguments keep this independent of mlflow's positional order
        run = mlflow.start_run(run_id=self.id, experiment_id=self.experiment_id, run_name=self.name)
        # save the run id and artifact_uri
        self.id = run.info.run_id
        self.artifact_uri = run.info.artifact_uri
        self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.status = Recorder.STATUS_R
        logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
        return run

    def end_run(self, status: str = Recorder.STATUS_S):
        """
        End the active mlflow run.

        Parameters
        ----------
        status : str
            one of the Recorder ``STATUS_*`` constants; forwarded to `mlflow.end_run`.

        Notes
        -----
        `end_time` is always refreshed, but `status` is only updated when the recorder's current
        status is not SCHEDULED.
        """
        assert status in [
            Recorder.STATUS_S,
            Recorder.STATUS_R,
            Recorder.STATUS_FI,
            Recorder.STATUS_FA,
        ], f"The status type {status} is not supported."
        mlflow.end_run(status)
        self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if self.status != Recorder.STATUS_S:
            self.status = status

    def save_objects(self, local_path=None, artifact_path=None, **kwargs):
        """
        Upload a local file/directory, or pickle and upload keyword-argument objects, as artifacts.

        Parameters
        ----------
        local_path : str, optional
            a file or directory to upload. When given, `kwargs` are ignored.
        artifact_path : str, optional
            the relative path for the artifact to be stored in the URI.
        **kwargs
            name / object pairs; each object is pickled to a file called `name` and uploaded.
        """
        assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
        if local_path is not None:
            # `log_artifacts` copies a directory tree; a single file must go through `log_artifact`.
            if Path(local_path).is_dir():
                self.client.log_artifacts(self.id, local_path, artifact_path)
            else:
                self.client.log_artifact(self.id, local_path, artifact_path)
        else:
            # TemporaryDirectory removes the pickles even if dumping or uploading raises.
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir).resolve()
                for name, data in kwargs.items():
                    file_path = temp_path / name
                    with file_path.open("wb") as f:
                        pickle.dump(data, f)
                    self.client.log_artifact(self.id, str(file_path), artifact_path)

    def load_object(self, name):
        """
        Download the artifact `name` of this run and unpickle it.

        Parameters
        ----------
        name : str
            relative artifact path of the object.

        Returns
        -------
        The unpickled object.
        """
        assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
        path = self.client.download_artifacts(self.id, name)
        # NOTE(security): pickle.load executes arbitrary code; only load artifacts from trusted stores.
        with Path(path).open("rb") as f:
            return pickle.load(f)

    def log_params(self, **kwargs):
        """Log each keyword argument as an mlflow param of this run."""
        for name, data in kwargs.items():
            self.client.log_param(self.id, name, data)

    def log_metrics(self, step=None, **kwargs):
        """Log each keyword argument as an mlflow metric of this run, all with the given `step`."""
        for name, data in kwargs.items():
            self.client.log_metric(self.id, name, data, step=step)

    def set_tags(self, **kwargs):
        """Set each keyword argument as an mlflow tag of this run."""
        for name, data in kwargs.items():
            self.client.set_tag(self.id, name, data)

    def delete_tags(self, *keys):
        """Delete the given tag keys from this run."""
        for key in keys:
            self.client.delete_tag(self.id, key)

    def get_artifact_uri(self):
        """
        Return the artifact URI recorded by `start_run`.

        Raises
        ------
        Exception
            if the recorder has not been started yet (no artifact URI known).
        """
        if self.artifact_uri is not None:
            return self.artifact_uri
        else:
            raise Exception(
                "Please make sure the recorder has been created and started properly before getting artifact uri."
            )

    def list_artifacts(self, artifact_path=None):
        """Return the paths of the artifacts stored under `artifact_path` (root when None)."""
        assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
        artifacts = self.client.list_artifacts(self.id, artifact_path)
        return [art.path for art in artifacts]

    def list_metrics(self):
        """Return the metrics dictionary of this run as reported by the mlflow client."""
        run = self.client.get_run(self.id)
        return run.data.metrics

    def list_params(self):
        """Return the params dictionary of this run as reported by the mlflow client."""
        run = self.client.get_run(self.id)
        return run.data.params

    def list_tags(self):
        """Return the tags dictionary of this run as reported by the mlflow client."""
        run = self.client.get_run(self.id)
        return run.data.tags