1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00
Files
qlib/qlib/workflow/recorder.py
you-n-g c0ce712be9 more detailed docs for workflow (#639)
* more detailed docs for workflow

* add more detailed docs for workflow
2021-10-11 15:38:18 +08:00

380 lines
13 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from ..utils.objm import FileManager
from ..log import get_module_logger
logger = get_module_logger("workflow", logging.INFO)
class Recorder:
"""
This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.
"""
# status type
STATUS_S = "SCHEDULED"
STATUS_R = "RUNNING"
STATUS_FI = "FINISHED"
STATUS_FA = "FAILED"
def __init__(self, experiment_id, name):
self.id = None
self.name = name
self.experiment_id = experiment_id
self.start_time = None
self.end_time = None
self.status = Recorder.STATUS_S
def __repr__(self):
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
def __str__(self):
return str(self.info)
def __hash__(self) -> int:
return hash(self.info["id"])
@property
def info(self):
output = dict()
output["class"] = "Recorder"
output["id"] = self.id
output["name"] = self.name
output["experiment_id"] = self.experiment_id
output["start_time"] = self.start_time
output["end_time"] = self.end_time
output["status"] = self.status
return output
def set_recorder_name(self, rname):
self.recorder_name = rname
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
"""
Save objects such as prediction file or model checkpoints to the artifact URI. User
can save object through keywords arguments (name:value).
Please refer to the docs of qlib.workflow:R.save_objects
Parameters
----------
local_path : str
if provided, them save the file or directory to the artifact URI.
artifact_path=None : str
the relative path for the artifact to be stored in the URI.
"""
raise NotImplementedError(f"Please implement the `save_objects` method.")
def load_object(self, name):
"""
Load objects such as prediction file or model checkpoints.
Parameters
----------
name : str
name of the file to be loaded.
Returns
-------
The saved object.
"""
raise NotImplementedError(f"Please implement the `load_object` method.")
def start_run(self):
"""
Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;
otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
Returns
-------
An active running object (e.g. mlflow.ActiveRun object).
"""
raise NotImplementedError(f"Please implement the `start_run` method.")
def end_run(self):
"""
End an active Recorder.
"""
raise NotImplementedError(f"Please implement the `end_run` method.")
def log_params(self, **kwargs):
"""
Log a batch of params for the current run.
Parameters
----------
keyword arguments
key, value pair to be logged as parameters.
"""
raise NotImplementedError(f"Please implement the `log_params` method.")
def log_metrics(self, step=None, **kwargs):
"""
Log multiple metrics for the current run.
Parameters
----------
keyword arguments
key, value pair to be logged as metrics.
"""
raise NotImplementedError(f"Please implement the `log_metrics` method.")
def set_tags(self, **kwargs):
"""
Log a batch of tags for the current run.
Parameters
----------
keyword arguments
key, value pair to be logged as tags.
"""
raise NotImplementedError(f"Please implement the `set_tags` method.")
def delete_tags(self, *keys):
"""
Delete some tags from a run.
Parameters
----------
keys : series of strs of the keys
all the name of the tag to be deleted.
"""
raise NotImplementedError(f"Please implement the `delete_tags` method.")
def list_artifacts(self, artifact_path: str = None):
"""
List all the artifacts of a recorder.
Parameters
----------
artifact_path : str
the relative path for the artifact to be stored in the URI.
Returns
-------
A list of artifacts information (name, path, etc.) that being stored.
"""
raise NotImplementedError(f"Please implement the `list_artifacts` method.")
def list_metrics(self):
"""
List all the metrics of a recorder.
Returns
-------
A dictionary of metrics that being stored.
"""
raise NotImplementedError(f"Please implement the `list_metrics` method.")
def list_params(self):
"""
List all the params of a recorder.
Returns
-------
A dictionary of params that being stored.
"""
raise NotImplementedError(f"Please implement the `list_params` method.")
def list_tags(self):
"""
List all the tags of a recorder.
Returns
-------
A dictionary of tags that being stored.
"""
raise NotImplementedError(f"Please implement the `list_tags` method.")
class MLflowRecorder(Recorder):
"""
Use mlflow to implement a Recorder.
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
use file manager to help maintain the objects in the project.
"""
def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
super(MLflowRecorder, self).__init__(experiment_id, name)
self._uri = uri
self._artifact_uri = None
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
# construct from mlflow run
if mlflow_run is not None:
assert isinstance(mlflow_run, mlflow.entities.run.Run), "Please input with a MLflow Run object."
self.name = mlflow_run.data.tags["mlflow.runName"]
self.id = mlflow_run.info.run_id
self.status = mlflow_run.info.status
self.start_time = (
datetime.fromtimestamp(float(mlflow_run.info.start_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S")
if mlflow_run.info.start_time is not None
else None
)
self.end_time = (
datetime.fromtimestamp(float(mlflow_run.info.end_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S")
if mlflow_run.info.end_time is not None
else None
)
def __repr__(self):
name = self.__class__.__name__
space_length = len(name) + 1
return "{name}(info={info},\n{space}uri={uri},\n{space}artifact_uri={artifact_uri},\n{space}client={client})".format(
name=name,
space=" " * space_length,
info=self.info,
uri=self.uri,
artifact_uri=self.artifact_uri,
client=self.client,
)
def __hash__(self) -> int:
return hash(self.info["id"])
def __eq__(self, o: object) -> bool:
if isinstance(o, MLflowRecorder):
return self.info["id"] == o.info["id"]
return False
@property
def uri(self):
return self._uri
@property
def artifact_uri(self):
return self._artifact_uri
def get_local_dir(self):
"""
This function will return the directory path of this recorder.
"""
if self.artifact_uri is not None:
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
local_dir_path = str(local_dir_path.resolve())
if os.path.isdir(local_dir_path):
return local_dir_path
else:
raise RuntimeError("This recorder is not saved in the local file system.")
else:
raise Exception(
"Please make sure the recorder has been created and started properly before getting artifact uri."
)
def start_run(self):
# set the tracking uri
mlflow.set_tracking_uri(self.uri)
# start the run
run = mlflow.start_run(self.id, self.experiment_id, self.name)
# save the run id and artifact_uri
self.id = run.info.run_id
self._artifact_uri = run.info.artifact_uri
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.status = Recorder.STATUS_R
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
return run
def end_run(self, status: str = Recorder.STATUS_S):
assert status in [
Recorder.STATUS_S,
Recorder.STATUS_R,
Recorder.STATUS_FI,
Recorder.STATUS_FA,
], f"The status type {status} is not supported."
mlflow.end_run(status)
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.status != Recorder.STATUS_S:
self.status = status
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
if local_path is not None:
path = Path(local_path)
if path.is_dir():
self.client.log_artifacts(self.id, local_path, artifact_path)
else:
self.client.log_artifact(self.id, local_path, artifact_path)
else:
temp_dir = Path(tempfile.mkdtemp()).resolve()
for name, data in kwargs.items():
path = temp_dir / name
Serializable.general_dump(data, path)
self.client.log_artifact(self.id, temp_dir / name, artifact_path)
shutil.rmtree(temp_dir)
def load_object(self, name):
"""
Load object such as prediction file or model checkpoint in mlflow.
Args:
name (str): the object name
Raises:
LoadObjectError: if raise some exceptions when load the object
Returns:
object: the saved object in mlflow.
"""
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
except Exception as e:
raise LoadObjectError(message=str(e))
def log_params(self, **kwargs):
for name, data in kwargs.items():
self.client.log_param(self.id, name, data)
def log_metrics(self, step=None, **kwargs):
for name, data in kwargs.items():
self.client.log_metric(self.id, name, data, step=step)
def set_tags(self, **kwargs):
for name, data in kwargs.items():
self.client.set_tag(self.id, name, data)
def delete_tags(self, *keys):
for key in keys:
self.client.delete_tag(self.id, key)
def get_artifact_uri(self):
if self.artifact_uri is not None:
return self.artifact_uri
else:
raise Exception(
"Please make sure the recorder has been created and started properly before getting artifact uri."
)
def list_artifacts(self, artifact_path=None):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
artifacts = self.client.list_artifacts(self.id, artifact_path)
return [art.path for art in artifacts]
def list_metrics(self):
run = self.client.get_run(self.id)
return run.data.metrics
def list_params(self):
run = self.client.get_run(self.id)
return run.data.params
def list_tags(self):
run = self.client.get_run(self.id)
return run.data.tags