1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00

Update Exp related codes

This commit is contained in:
Jactus
2020-10-29 12:58:52 +08:00
parent 1a9ee6cef8
commit 60d0cfcf64
8 changed files with 426 additions and 599 deletions

View File

@@ -13,7 +13,7 @@ import platform
import yaml
from pathlib import Path
from .utils import can_use_cache
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
# init qlib
@@ -22,6 +22,7 @@ def init(default_conf="client", **kwargs):
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
from .data.cache import H
from .workflow import R, QlibRecorder
C.reset()
H.clear()
@@ -79,6 +80,15 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# set up QlibRecorder
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri
# exp manager module
module = get_module_by_module_path('qlib.workflow')
exp_manager = init_instance_by_config(C['exp_manager'], module)
qr = QlibRecorder(exp_manager, default_uri, current_uri)
R.register(qr)
def _mount_nfs_uri(C):

View File

@@ -124,6 +124,12 @@ _default_config = {
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
# Defatult config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"kwargs": {}
},
"exp_uri": None,
}
MODE_CONF = {

View File

@@ -24,6 +24,7 @@ from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, get_provider_obj, register_wrapper
class CalendarProvider(abc.ABC):
@@ -1019,44 +1020,6 @@ class ClientProvider(BaseProvider):
DatasetD.set_conn(self.client)
class Wrapper(object):
"""Data Provider Wrapper"""
def __init__(self):
self._provider = None
def register(self, provider):
self._provider = provider
def __getattr__(self, key):
if self._provider is None:
raise AttributeError("Please run qlib.init() first using qlib")
return getattr(self._provider, key)
def get_cls_from_name(cls_name):
return getattr(importlib.import_module(".data", package="qlib"), cls_name)
def get_provider_obj(config, **params):
if isinstance(config, dict):
params.update(config["kwargs"])
config = config["class"]
return get_cls_from_name(config)(**params)
def register_wrapper(wrapper, cls_or_obj):
"""register_wrapper
:param wrapper: A wrapper of all kinds of providers
:param cls_or_obj: A class or class name or object instance in data/data.py
"""
if isinstance(cls_or_obj, str):
cls_or_obj = get_cls_from_name(cls_or_obj)
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
wrapper.register(obj)
Cal = Wrapper()
Inst = Wrapper()
FeatureD = Wrapper()

View File

@@ -611,3 +611,39 @@ def exists_qlib_data(qlib_dir):
return False
return True
#################### Wrapper #####################
class Wrapper(object):
"""Data Provider Wrapper"""
def __init__(self):
self._provider = None
def register(self, provider):
self._provider = provider
def __getattr__(self, key):
if self._provider is None:
raise AttributeError("Please run qlib.init() first using qlib")
return getattr(self._provider, key)
def get_provider_obj(config, **params):
module = get_module_by_module_path("qlib.data")
klass, kwargs = get_cls_kwargs(config, module)
kwargs.update(params)
return klass(**kwargs)
def register_wrapper(wrapper, cls_or_obj):
"""register_wrapper
:param wrapper: A wrapper of all kinds of providers
:param cls_or_obj: A class or class name or object instance in data/data.py
"""
if isinstance(cls_or_obj, str):
module = get_module_by_module_path("qlib.data")
cls_or_obj = getattr(module, cls_or_obj)
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
wrapper.register(obj)

View File

@@ -2,156 +2,62 @@
# Licensed under the MIT License.
from contextlib import contextmanager
from .record import MLflowRecorder
from .exp import MLflowExpManager
from .expm import *
from ..utils import Wrapper
class Record:
def __init__(self):
pass
class QlibRecorder:
def __init__(self, exp_manager, default_uri, current_uri):
self.exp_manager = exp_manager
self.default_uri = default_uri
self.current_uri = current_uri
@contextmanager
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
raise NotImplementedError(f"Please implement the `start_exp` method.")
def start(self, experiment_name):
run = self.start_exp(experiment_name, self.current_uri)
yield run
self.end_exp()
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
raise NotImplementedError(f"Please implement the `search_runs` method.")
def start_exp(self, experiment_name=None):
return self.exp_manager.start_exp(experiment_name, self.current_uri)
def end_exp(self):
self.exp_manager.end_exp()
def get_exp(self, experiment_id):
raise NotImplementedError(f"Please implement the `get_exp` method.")
def get_exp_by_name(self, experiment_name):
raise NotImplementedError(f"Please implement the `get_exp_by_name` method.")
def search_records(self, experiment_ids, **kwargs):
return self.exp_manager.search_records(experiment_ids, **kwargs)
def create_exp(self, experiment_name, artifact_location=None):
raise NotImplementedError(f"Please implement the `create_exp` method.")
def set_exp(self, experiment_name):
raise NotImplementedError(f"Please implement the `set_exp` method.")
def delete_exp(self, experiment_id):
raise NotImplementedError(f"Please implement the `create_exp` method.")
def set_tracking_uri(self, uri):
raise NotImplementedError(f"Please implement the `set_tracking_uri` method.")
def get_tracking_uri(self):
raise NotImplementedError(f"Please implement the `get_tracking_uri` method.")
def get_recorder(self):
raise NotImplementedError(f"Please implement the `get_recorder` method.")
def save_object(self, name, data):
raise NotImplementedError(f"Please implement the `save_object` method.")
def save_objects(self, name_data_list):
raise NotImplementedError(f"Please implement the `save_objects` method.")
def load_object(self, name):
raise NotImplementedError(f"Please implement the `load_object` method.")
def log_param(self, key, value):
raise NotImplementedError(f"Please implement the `log_param` method.")
def log_params(self, params):
raise NotImplementedError(f"Please implement the `log_params` method.")
def log_metric(self, key, value, step=None):
raise NotImplementedError(f"Please implement the `log_metric` method.")
def log_metrics(self, metrics, step=None):
raise NotImplementedError(f"Please implement the `log_metrics` method.")
def set_tag(self, key, value):
raise NotImplementedError(f"Please implement the `set_tag` method.")
def set_tags(self, tags):
raise NotImplementedError(f"Please implement the `log_tags` method.")
def delete_tag(self, key):
raise NotImplementedError(f"Please implement the `delete_tag` method.")
def log_artifact(self, local_path, artifact_path=None):
raise NotImplementedError(f"Please implement the `log_artifact` method.")
def log_artifacts(self, local_dir, artifact_path=None):
raise NotImplementedError(f"Please implement the `log_artifacts` method.")
def get_artifact_uri(self, artifact_path=None):
raise NotImplementedError(f"Please implement the `get_artifact_uri` method.")
class MLflowRecord(Record):
def __init__(self):
self.exp_manager = MLflowExpManager()
@contextmanager
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
yield self.exp_manager.start_exp(experiment_name, uri, project_path, artifact_location, nested)
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
return self.exp_manager.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
def get_exp(self, experiment_id):
return self.exp_manager.get_exp(experiment_id)
def get_exp_by_name(self, experiment_name):
return self.exp_manager.get_exp_by_name(experiment_name)
def create_exp(self, experiment_name, artifact_location=None):
self.exp_manager.create_exp(experiment_name, artifact_location)
def set_exp(self, experiment_name):
self.exp_manager.set_exp(experiment_name)
def get_exp(self, experiment_id=None, experiment_name=None):
return self.exp_manager.get_exp(experiment_id, experiment_name)
def delete_exp(self, experiment_id):
self.exp_manager.delete_exp(experiment_id)
def set_tracking_uri(self, uri):
self.exp_manager.set_tracking_uri(uri)
def get_tracking_uri(self):
return self.exp_manager.get_tracking_uri()
def get_uri(self, type):
return self.exp_manager.get_uri(type)
def get_recorder(self):
return self.exp_manager.get_recorder()
return self.exp_manager.active_recorder
def save_object(self, name, data):
self.exp_manager.active_recorder.save_object(name, data)
def save_object(self, data=None, name=None, local_path=None):
self.exp_manager.active_recorder.save_object(data, name, local_path)
def save_objects(self, name_data_list):
self.exp_manager.active_recorder.save_objects(name_data_list)
def save_objects(self, data_name_list=None, local_path=None):
self.exp_manager.active_recorder.save_objects(data_name_list, local_path)
def load_object(self, name):
return self.exp_manager.active_recorder.load_object(name)
def log_params(self, **kwargs):
self.exp_manager.active_recorder.log_params(**kwargs)
def log_metrics(self, step=None, **kwargs):
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
def log_param(self, key, value):
self.exp_manager.active_recorder.log_param(key, value)
def log_params(self, params):
self.exp_manager.active_recorder.log_params(params)
def log_metric(self, key, value, step=None):
self.exp_manager.active_recorder.log_metric(key, value, step)
def log_metrics(self, metrics, step=None):
self.exp_manager.active_recorder.log_metrics(metrics, step)
def set_tag(self, key, value):
self.exp_manager.active_recorder.set_tag(key, value)
def set_tags(self, tags):
self.exp_manager.active_recorder.set_tags(tags)
def set_tags(self, **kwargs):
self.exp_manager.active_recorder.set_tags(**kwargs)
def delete_tag(self, key):
self.exp_manager.active_recorder.delete_tag(key)
def log_artifact(self, local_path, artifact_path=None):
self.exp_manager.active_recorder.log_artifact(local_path, artifact_path)
def log_artifacts(self, local_dir, artifact_path=None):
self.exp_manager.active_recorder.log_artifacts(local_dir, artifact_path)
def get_artifact_uri(self, artifact_path=None):
return self.exp_manager.active_recorder.get_artifact_uri(artifact_path)
# global record
R = MLflowRecord()
R = Wrapper()

View File

@@ -2,67 +2,23 @@
# Licensed under the MIT License.
import mlflow
from contextlib import contextmanager
from .record import MLflowRecorder
from pathlib import Path
class ExpManager:
class Experiment:
"""
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
Thie is the `Experiment` class for each experiment being run. The API is designed
"""
def __init__(self):
self.active_recorder = None
self.experiments = dict() # store the experiment names -> list of recorders.
self.exp_ids = list()
def _store_exp(self, id, name):
"""
Store the experiments in the experiments holder.
"""
if id in self.exp_ids:
raise Exception('Something went wrong when creating the experiment. Please check if the experiment is already created.')
if name in self.experiments:
assert int(id) == int(self.experiments[name][0]), 'Experiment id and name are not consistent when storing the experiment.'
else:
self.exp_ids.append(id)
self.experiments[name] = [id]
self.name = None
self.id = None
self.recorders = list()
def start_exp(self, project_path, experiment_name=None, uri=None, artifact_location=None, nested=False):
def search_records(self, **kwargs):
"""
Start running an experiment. This method can only work in the `with` statement.
Get a pandas DataFrame of records that fit the search criteria of the experiment.
Parameters
----------
project_path : str
path for the project.
experiment_name : str
name of the active experiment.
uri : str
the current tracking URI.
artifact_location : str
the location to store all the artifacts.
nested : boolean
controls whether run is nested in parent run.
Returns
None
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")
def end_exp(self):
"""
End an active experiment.
"""
raise NotImplementedError(f"Please implement the `end_exp` method.")
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
"""
Get a pandas DataFrame of runs that fit the search criteria.
Parameters
----------
experiment_ids : list
list of experiment IDs.
filter_string : str
filter query string, defaults to searching all runs.
run_view_type : int
@@ -74,192 +30,18 @@ class ExpManager:
Returns
-------
A pandas.DataFrame of runs.
A pandas.DataFrame of records.
"""
raise NotImplementedError(f"Please implement the `search_runs` method.")
def get_exp(self, experiment_id):
"""
Retrieve an experiment by experiment_id from the backend store.
Parameters
----------
experiment_id : str
the experiment id to return.
Returns
-------
An experiment object (e.g. mlflow.entities.Experiment).
"""
raise NotImplementedError(f"Please implement the `get_exp` method.")
def get_exp_by_name(self, experiment_name):
"""
Retrieve an experiment by experiment name from the backend store.
Parameters
----------
experiment_name : str
the experiment name to return.
Returns
-------
An experiment object (e.g. mlflow.entities.Experiment).
"""
raise NotImplementedError(f"Please implement the `get_exp_by_name` method.")
def create_exp(self, experiment_name, artifact_location=None):
"""
Create an experiment.
Parameters
----------
experiment_name : str
the experiment name, which must be unique.
artifact_location : str
the location to store run artifacts.
Returns
-------
String id of created experiment.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def set_exp(self, experiment_name):
"""
Set the experiment to be active.
Parameters
----------
experiment_name : str
the experiment name, which must be unique.
Returns
-------
String id of created experiment.
"""
raise NotImplementedError(f"Please implement the `set_exp` method.")
def delete_exp(self, experiment_id):
"""
Delete an experiment.
Parameters
----------
experiment_id : str
the experiment id.
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def set_tracking_uri(self, uri):
"""
Set the tracking server URI.
Parameters
----------
uri : str
the uri of the tracking server, can be An empty string, or a local file path, prefixed with file:/.
or An HTTP URI or A Databricks workspace.
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `set_tracking_uri` method.")
def get_tracking_uri(self):
"""
Get the tracking server URI.
Parameters
----------
Returns
-------
The tracking URI.
"""
raise NotImplementedError(f"Please implement the `get_tracking_uri` method.")
def get_recorder(self):
"""
Get the current active Recorder.
Parameters
----------
Returns
-------
An Recorder object.
"""
raise NotImplementedError(f"Please implement the `get_recorder` method.")
raise NotImplementedError(f"Please implement the `search_records` method.")
class MLflowExpManager(ExpManager):
'''
Use mlflow to implement ExpManager.
'''
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
# set the tracking uri
if uri is None:
assert project_path is not None, "Please provide the project_path if no uri is provided in order to set a proper tracking uri."
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the project path.')
mlflow.set_tracking_uri(str(project_path / "mlruns"))
else:
mlflow.set_tracking_uri(uri)
# start the experiment
if experiment_name is None:
print('No experiment name provided. The default experiment name is set as `experiment`.')
experiment_id = self.create_exp('experiment', artifact_location)
# set the active experiment
self.set_exp('experiment')
experiment_name = 'experiment'
else:
if experiment_name not in self.experiments:
if self.get_exp_by_name(experiment_name) is not None:
raise Exception('The experiment has already been created before. Please pick another name or delete the files under tracking uri.')
experiment_id = self.create_exp(experiment_name, artifact_location)
else:
experiment_id = self.experiments(experiment_name)[0]
# set the active experiment
self.set_exp(experiment_name)
# store the id and name
self._store_exp(experiment_id, experiment_name)
# set up recorder
recorder = MLflowRecorder(experiment_id)
self.active_recorder = recorder
# store the recorder
self.experiments[experiment_name].append(self.active_recorder)
return self.active_recorder.start_run(experiment_id=experiment_id, nested=nested)
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
def get_exp(self, experiment_id):
return mlflow.get_experiment(experiment_id)
def get_exp_by_name(self, experiment_name):
return mlflow.get_experiment_by_name(experiment_name)
def create_exp(self, experiment_name, artifact_location=None):
return mlflow.create_experiment(experiment_name, artifact_location)
def set_exp(self, experiment_name):
mlflow.set_experiment(experiment_name)
def delete_exp(self, experiment_id):
mlflow.delete_experiment(experiment_id)
self.experiments = {key:val for key, val in self.experiments.items() if val[0] != experiment_id}
def set_tracking_uri(self, uri):
mlflow.set_tracking_uri(uri)
def get_tracking_uri(self):
return mlflow.get_tracking_uri()
def get_recorder(self):
return self.active_recorder
class MLflowExperiment(Experiment):
"""
Use mlflow to implement Experiment.
"""
def search_records(self, **kwargs):
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
order_by = kwargs.get('order_by')
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)

236
qlib/workflow/expm.py Normal file
View File

@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import os
from pathlib import Path
from contextlib import contextmanager
from .exp import MLflowExperiment
from .record import MLflowRecorder
class ExpManager:
"""
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self):
self.default_uri = None
self.active_recorder = None # only one recorder can running each time
self.experiments = dict() # store the experiment name --> Experiment object
def start_exp(self, experiment_name=None, uri=None, **kwargs):
"""
Start running an experiment.
Parameters
----------
experiment_name : str
name of the active experiment.
uri : str
the current tracking URI.
artifact_location : str
the location to store all the artifacts.
nested : boolean
controls whether run is nested in parent run.
Returns
An object wrapped by context manager.
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")
def end_exp(self, **kwargs):
"""
End an running experiment.
Parameters
----------
experiment_name : str
name of the active experiment.
"""
raise NotImplementedError(f"Please implement the `end_exp` method.")
def search_records(self, experiment_ids=None, **kwargs):
"""
Get a pandas DataFrame of records that fit the search criteria.
Parameters
----------
experiment_ids : list
list of experiment IDs.
filter_string : str
filter query string, defaults to searching all runs.
run_view_type : int
one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
max_results : int
the maximum number of runs to put in the dataframe.
order_by : list
list of columns to order by (e.g., “metrics.rmse”).
Returns
-------
A pandas.DataFrame of runs.
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def __create_exp(self, experiment_name, artifact_location=None):
"""
Create an experiment.
Parameters
----------
experiment_name : str
the experiment name, which must be unique.
artifact_location : str
the location to store run artifacts.
Returns
-------
An experiment object.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_exp(self, experiment_id=None, experiment_name=None):
"""
Retrieve an experiment by experiment_id from the backend store.
Parameters
----------
experiment_id : str
the experiment id to return.
Returns
-------
An experiment object.
"""
raise NotImplementedError(f"Please implement the `get_exp` method.")
def delete_exp(self, experiment_id):
"""
Delete an experiment.
Parameters
----------
experiment_id : str
the experiment id.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_uri(self, type):
"""
Get the default tracking URI or current URI.
Parameters
----------
type : str
the type of the tracking URI one wants to retrieve.
Returns
-------
The tracking URI string.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_recorder(self):
"""
Get the current active Recorder.
Parameters
----------
Returns
-------
An Recorder object.
"""
raise NotImplementedError(f"Please implement the `get_recorder` method.")
class MLflowExpManager(ExpManager):
'''
Use mlflow to implement ExpManager.
'''
def __init__(self):
super(MLflowExpManager, self).__init__()
self.default_uri = None
self.current_uri = None
def start_exp(self, experiment_name=None, uri=None):
# create experiment
experiment = self.__create_exp(experiment_name, uri)
# set up recorder
recorder = MLflowRecorder(experiment.id)
self.active_recorder = recorder
# store the recorder
experiment.recorders.append(self.active_recorder)
# store the experiment
self.experiments[experiment_name] = experiment
return self.active_recorder.start_run(experiment_id=experiment.id)
def end_exp(self):
self.active_recorder.end_run()
self.active_recorder = None
def __create_exp(self, experiment_name=None, uri=None):
# init experiment
experiment = MLflowExperiment()
# set the tracking uri
if uri is None:
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.')
else:
self.current_uri = uri
mlflow.set_tracking_uri(self.current_uri)
# start the experiment
if experiment_name is None:
print('No experiment name provided. The default experiment name is set as `experiment`.')
experiment_id = mlflow.create_experiment('experiment')
# set the active experiment
mlflow.set_experiment('experiment')
experiment_name = 'experiment'
else:
if experiment_name not in self.experiments:
if mlflow.get_experiment_by_name(experiment_name) is not None:
raise Exception('The experiment has already been created before. Please pick another name or delete the files under uri.')
experiment_id = mlflow.create_experiment(experiment_name)
else:
experiment_id = self.experiments[experiment_name].id
experiment = self.experiments[experiment_name]
# set the active experiment
mlflow.set_experiment(experiment_name)
# set up experiment
experiment.id = experiment_id
experiment.name = experiment_name
return experiment
def search_records(self, experiment_ids, **kwargs):
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
order_by = kwargs.get('order_by')
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
def get_exp(self, experiment_id=None, experiment_name=None):
assert experiment_id is not None or experiment_name is not None, 'Please provide at least one of the experiment id or name to retrieve an experiment.'
if experiment_name is not None:
return self.experiments[experiment_name]
elif:
for name in self.experiments:
if self.experiments[name].id == experiment_id:
return self.experiments[name]
else:
print('No valid experiment is found. Please make sure the id and name are correctly given.')
def delete_exp(self, experiment_id):
mlflow.delete_experiment(experiment_id)
self.experiments = {key:val for key, val in self.experiments.items() if val.id != experiment_id}
def get_uri(self, type):
if uri == 'default':
return self.default_uri
elif uri == 'current':
return self.current_uri
else:
raise ValueError('Input type is not supported. Please choose type default or current to get the uri.')
def get_recorder(self):
return self.active_recorder

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
import mlflow
import shutil
import shutil, os, pickle, tempfile, codecs
from pathlib import Path
from ..utils.objm import FileManager
@@ -12,45 +12,39 @@ class Recorder:
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self, experiment_id, project_path=None):
def __init__(self, experiment_id):
self.experiment_id = experiment_id
self.recorder_id = None
self.recorder_name = None
self.fm = None
self.artifact_uri = None
def set_recorder_name(self, rname):
self.recorder_name = rname
def save_object(self, name, data):
def save_object(self, data, name, local_path=None):
"""
Save object such as prediction file or model checkpoints.
Save object such as prediction file or model checkpoints to the artifact URI.
Parameters
----------
name : str
name of the file to be saved.
data : any type
the data to be saved.
Returns
-------
None.
name : str
name of the file to be saved.
local_path : str
if provided, them save the file or directory to the artifact URI.
"""
raise NotImplementedError(f"Please implement the `save_object` method.")
def save_objects(self, name_data_list):
def save_objects(self, data_name_list, local_path=None):
"""
Save objects such as prediction file or model checkpoints.
Save objects such as prediction file or model checkpoints to the artifact URI.
Parameters
----------
name_data_list : list
list of (name, data) pairs
Returns
-------
None.
data_name_list : list
list of (data, name) pairs
local_path : str
if provided, them save the file or directory to the artifact URI.
"""
raise NotImplementedError(f"Please implement the `save_objects` method.")
@@ -98,99 +92,36 @@ class Recorder:
"""
raise NotImplementedError(f"Please implement the `end_run` method.")
def log_param(self, key, value):
"""
Log a parameter under the current run.
Parameters
----------
key : str
the name of the parameter
value : str
the value of the parameter
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `log_param` method.")
def log_params(self, params):
def log_params(self, **kwargs):
"""
Log a batch of params for the current run.
Parameters
----------
params : dict
dictionary of param_name: String -> value: String.
Returns
-------
None
keyword arguments
key, value pair to be logged as parameters.
"""
raise NotImplementedError(f"Please implement the `log_params` method.")
def log_metric(self, key, value, step=None):
"""
Log a metric under the current run.
Parameters
----------
key : str
the name of the metric
value : float
the value of the metric
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `log_metric` method.")
def log_metrics(self, metrics, step=None):
def log_metrics(self, step=None, **kwargs):
"""
Log multiple metrics for the current run.
Parameters
----------
metrics : dict
dictionary of metric_name: String -> value: Float.
Returns
-------
None
keyword arguments
key, value pair to be logged as metrics.
"""
raise NotImplementedError(f"Please implement the `log_metrics` method.")
def set_tag(self, key, value):
"""
Set a tag under the current run.
Parameters
----------
key : str
the name of the tag
value : str
the value of the tag
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `set_tag` method.")
def set_tags(self, tags):
def set_tags(self, **kwargs):
"""
Log a batch of tags for the current run.
Parameters
----------
tags : dict
dictionary of tag_name: String -> value: String.
Returns
-------
None
keyword arguments
key, value pair to be logged as tags.
"""
raise NotImplementedError(f"Please implement the `log_tags` method.")
@@ -202,67 +133,22 @@ class Recorder:
----------
key : str
the name of the tag to be deleted.
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `delete_tag` method.")
def log_artifact(self, local_path, artifact_path=None):
"""
Log a local file or directory as an artifact of the currently active run.
Parameters
----------
local_path : str
path to the file to write.
artifact_path : str
the directory in `artifact_uri` to write to.
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `log_artifact` method.")
def log_artifacts(self, local_dir, artifact_path=None):
"""
Log all the contents of a local directory as artifacts of the run.
Parameters
----------
local_dir : str
path to the directory of files to write.
artifact_path : str
the directory in `artifact_uri` to write to.
Returns
-------
None
"""
raise NotImplementedError(f"Please implement the `log_artifacts` method.")
def get_artifact_uri(self, artifact_path=None):
"""
Get the absolute URI of the specified artifact in the currently active run.
Parameters
----------
artifact_path : str
the directory in `artifact_uri` to write to.
Returns
-------
An absolute URI referring to the specified artifact or currently active Recorder.
"""
raise NotImplementedError(f"Please implement the `get_artifact_uri` method.")
class MLflowRecorder(Recorder):
'''
Use mlflow to implement a Recorder.
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
use file manager to help maintain the objects in the project.
'''
def __init__(self, experiment_id):
super(MLflowRecorder, self).__init__(experiment_id)
self.fm = None
self.temp_dir = None
def start_run(self, run_id=None, experiment_id=None,
run_name=None, nested=False):
if run_id is None:
@@ -277,65 +163,67 @@ class MLflowRecorder(Recorder):
self.recorder_id = run.info.run_id
self.artifact_uri = run.info.artifact_uri
# set up file manager for saving objects
if self.artifact_uri.startswith('file:/'):
self.fm = FileManager(Path(urllib.parse.urlparse(self.artifact_uri).path))
else:
self.fm = FileManager(Path(self.artifact_uri))
print(self.artifact_uri)
self.temp_dir = tempfile.mkdtemp()
self.fm = FileManager(Path(self.temp_dir).absolute())
return run
def end_run(self):
mlflow.end_run()
shutil.rmtree(self.temp_dir)
def save_object(self, name, data):
self.fm.save_obj(data, name)
import urllib
print(urllib.parse.urlparse(self.artifact_uri).scheme)
try:
self.log_artifact(self.fm.path / name)
except shutil.SameFileError:
pass
except Exception as e:
print(e)
def save_object(self, data, name, local_path=None):
if local_path is None:
assert data is not None and name is not None, "Please provide data and name input."
self.fm.save_obj(data, name)
mlflow.log_artifact(self.fm.path / name)
self.fm.remove(name)
else:
mlflow.log_artifact(local_path)
def save_objects(self, name_data_list):
self.fm.save_objs(name_data_list)
try:
self.log_artifacts(self.fm.path)
except shutil.SameFileError:
pass
except Exception as e:
print(e)
def save_objects(self, data_name_list, local_path=None):
if local_path is None:
assert data_name_list is not None, "Please provide data_name_list input."
self.fm.save_objs(data_name_list)
mlflow.log_artifacts(self.fm.path)
for obj, name in data_name_list:
self.fm.remove(name)
else:
mlflow.log_artifacts(local_path)
def load_object(self, name):
return self.fm.load_obj(name)
client = mlflow.tracking.MlflowClient()
path = client.download_artifacts(self.recorder_id, name)
try:
with Path(path).open('rb') as f:
f.seek(0)
return pickle.load(f)
except:
with codecs.open(path, mode="r", encoding='utf-8') as f:
return f.read()
def log_params(self, **kwargs):
keys = list(kwargs.keys())
if len(keys) == 0:
mlflow.log_param(keys[0], kwargs.get(keys[0]))
else:
mlflow.log_params(dict(kwargs))
def log_param(self, key, value):
mlflow.log_param(key, value)
def log_params(self, params):
mlflow.log_params(params)
def log_metric(self, key, value, step=None):
mlflow.log_metric(key, value, step)
def log_metrics(self, metrics, step=None):
mlflow.log_metrics(metrics, step)
def log_metrics(self, step=None, **kwargs):
keys = list(kwargs.keys())
if len(keys) == 0:
mlflow.log_metric(keys[0], kwargs.get(keys[0]))
else:
mlflow.log_metrics(dict(kwargs))
def set_tag(self, key, value):
mlflow.set_tag(key, value)
def set_tags(self, tags):
mlflow.set_tags(tags)
def set_tags(self, **kwargs):
keys = list(kwargs.keys())
if len(keys) == 0:
mlflow.set_tag(keys[0], kwargs.get(keys[0]))
else:
mlflow.set_tags(dict(kwargs))
def delete_tag(self, key):
mlflow.delete_tag(key)
def log_artifact(self, local_path, artifact_path=None):
mlflow.log_artifact(local_path, artifact_path)
def log_artifacts(self, local_dir, artifact_path=None):
mlflow.log_artifacts(local_dir, artifact_path)
def get_artifact_uri(self, artifact_path=None):
if self.artifact_uri is not None: