mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 02:21:18 +08:00
Update Exp related codes
This commit is contained in:
@@ -13,7 +13,7 @@ import platform
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import can_use_cache
|
||||
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
|
||||
|
||||
|
||||
# init qlib
|
||||
@@ -22,6 +22,7 @@ def init(default_conf="client", **kwargs):
|
||||
from .data.data import register_all_wrappers
|
||||
from .log import get_module_logger, set_log_with_config
|
||||
from .data.cache import H
|
||||
from .workflow import R, QlibRecorder
|
||||
|
||||
C.reset()
|
||||
H.clear()
|
||||
@@ -79,6 +80,15 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
if "flask_server" in C:
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
# set up QlibRecorder
|
||||
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
|
||||
current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri
|
||||
# exp manager module
|
||||
module = get_module_by_module_path('qlib.workflow')
|
||||
exp_manager = init_instance_by_config(C['exp_manager'], module)
|
||||
qr = QlibRecorder(exp_manager, default_uri, current_uri)
|
||||
R.register(qr)
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
|
||||
@@ -124,6 +124,12 @@ _default_config = {
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"kwargs": {}
|
||||
},
|
||||
"exp_uri": None,
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
|
||||
@@ -24,6 +24,7 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, get_provider_obj, register_wrapper
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -1019,44 +1020,6 @@ class ClientProvider(BaseProvider):
|
||||
DatasetD.set_conn(self.client)
|
||||
|
||||
|
||||
class Wrapper(object):
|
||||
"""Data Provider Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._provider = None
|
||||
|
||||
def register(self, provider):
|
||||
self._provider = provider
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
return getattr(self._provider, key)
|
||||
|
||||
|
||||
def get_cls_from_name(cls_name):
|
||||
return getattr(importlib.import_module(".data", package="qlib"), cls_name)
|
||||
|
||||
|
||||
def get_provider_obj(config, **params):
|
||||
if isinstance(config, dict):
|
||||
params.update(config["kwargs"])
|
||||
config = config["class"]
|
||||
return get_cls_from_name(config)(**params)
|
||||
|
||||
|
||||
def register_wrapper(wrapper, cls_or_obj):
|
||||
"""register_wrapper
|
||||
|
||||
:param wrapper: A wrapper of all kinds of providers
|
||||
:param cls_or_obj: A class or class name or object instance in data/data.py
|
||||
"""
|
||||
if isinstance(cls_or_obj, str):
|
||||
cls_or_obj = get_cls_from_name(cls_or_obj)
|
||||
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
|
||||
wrapper.register(obj)
|
||||
|
||||
|
||||
Cal = Wrapper()
|
||||
Inst = Wrapper()
|
||||
FeatureD = Wrapper()
|
||||
|
||||
@@ -611,3 +611,39 @@ def exists_qlib_data(qlib_dir):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
#################### Wrapper #####################
|
||||
class Wrapper(object):
|
||||
"""Data Provider Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
self._provider = None
|
||||
|
||||
def register(self, provider):
|
||||
self._provider = provider
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
return getattr(self._provider, key)
|
||||
|
||||
|
||||
def get_provider_obj(config, **params):
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
klass, kwargs = get_cls_kwargs(config, module)
|
||||
kwargs.update(params)
|
||||
return klass(**kwargs)
|
||||
|
||||
|
||||
def register_wrapper(wrapper, cls_or_obj):
|
||||
"""register_wrapper
|
||||
|
||||
:param wrapper: A wrapper of all kinds of providers
|
||||
:param cls_or_obj: A class or class name or object instance in data/data.py
|
||||
"""
|
||||
if isinstance(cls_or_obj, str):
|
||||
module = get_module_by_module_path("qlib.data")
|
||||
cls_or_obj = getattr(module, cls_or_obj)
|
||||
obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
|
||||
wrapper.register(obj)
|
||||
@@ -2,156 +2,62 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from .record import MLflowRecorder
|
||||
from .exp import MLflowExpManager
|
||||
from .expm import *
|
||||
from ..utils import Wrapper
|
||||
|
||||
class Record:
|
||||
def __init__(self):
|
||||
pass
|
||||
class QlibRecorder:
|
||||
def __init__(self, exp_manager, default_uri, current_uri):
|
||||
self.exp_manager = exp_manager
|
||||
self.default_uri = default_uri
|
||||
self.current_uri = current_uri
|
||||
|
||||
@contextmanager
|
||||
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
def start(self, experiment_name):
|
||||
run = self.start_exp(experiment_name, self.current_uri)
|
||||
yield run
|
||||
self.end_exp()
|
||||
|
||||
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
|
||||
raise NotImplementedError(f"Please implement the `search_runs` method.")
|
||||
def start_exp(self, experiment_name=None):
|
||||
return self.exp_manager.start_exp(experiment_name, self.current_uri)
|
||||
|
||||
def end_exp(self):
|
||||
self.exp_manager.end_exp()
|
||||
|
||||
def get_exp(self, experiment_id):
|
||||
raise NotImplementedError(f"Please implement the `get_exp` method.")
|
||||
|
||||
def get_exp_by_name(self, experiment_name):
|
||||
raise NotImplementedError(f"Please implement the `get_exp_by_name` method.")
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
return self.exp_manager.search_records(experiment_ids, **kwargs)
|
||||
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def set_exp(self, experiment_name):
|
||||
raise NotImplementedError(f"Please implement the `set_exp` method.")
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def set_tracking_uri(self, uri):
|
||||
raise NotImplementedError(f"Please implement the `set_tracking_uri` method.")
|
||||
|
||||
def get_tracking_uri(self):
|
||||
raise NotImplementedError(f"Please implement the `get_tracking_uri` method.")
|
||||
|
||||
def get_recorder(self):
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
|
||||
def save_object(self, name, data):
|
||||
raise NotImplementedError(f"Please implement the `save_object` method.")
|
||||
|
||||
def save_objects(self, name_data_list):
|
||||
raise NotImplementedError(f"Please implement the `save_objects` method.")
|
||||
|
||||
def load_object(self, name):
|
||||
raise NotImplementedError(f"Please implement the `load_object` method.")
|
||||
|
||||
def log_param(self, key, value):
|
||||
raise NotImplementedError(f"Please implement the `log_param` method.")
|
||||
|
||||
def log_params(self, params):
|
||||
raise NotImplementedError(f"Please implement the `log_params` method.")
|
||||
|
||||
def log_metric(self, key, value, step=None):
|
||||
raise NotImplementedError(f"Please implement the `log_metric` method.")
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
raise NotImplementedError(f"Please implement the `log_metrics` method.")
|
||||
|
||||
def set_tag(self, key, value):
|
||||
raise NotImplementedError(f"Please implement the `set_tag` method.")
|
||||
|
||||
def set_tags(self, tags):
|
||||
raise NotImplementedError(f"Please implement the `log_tags` method.")
|
||||
|
||||
def delete_tag(self, key):
|
||||
raise NotImplementedError(f"Please implement the `delete_tag` method.")
|
||||
|
||||
def log_artifact(self, local_path, artifact_path=None):
|
||||
raise NotImplementedError(f"Please implement the `log_artifact` method.")
|
||||
|
||||
def log_artifacts(self, local_dir, artifact_path=None):
|
||||
raise NotImplementedError(f"Please implement the `log_artifacts` method.")
|
||||
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
raise NotImplementedError(f"Please implement the `get_artifact_uri` method.")
|
||||
|
||||
class MLflowRecord(Record):
|
||||
def __init__(self):
|
||||
self.exp_manager = MLflowExpManager()
|
||||
|
||||
@contextmanager
|
||||
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
|
||||
yield self.exp_manager.start_exp(experiment_name, uri, project_path, artifact_location, nested)
|
||||
|
||||
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
|
||||
return self.exp_manager.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def get_exp(self, experiment_id):
|
||||
return self.exp_manager.get_exp(experiment_id)
|
||||
|
||||
def get_exp_by_name(self, experiment_name):
|
||||
return self.exp_manager.get_exp_by_name(experiment_name)
|
||||
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
self.exp_manager.create_exp(experiment_name, artifact_location)
|
||||
|
||||
def set_exp(self, experiment_name):
|
||||
self.exp_manager.set_exp(experiment_name)
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name)
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
self.exp_manager.delete_exp(experiment_id)
|
||||
|
||||
def set_tracking_uri(self, uri):
|
||||
self.exp_manager.set_tracking_uri(uri)
|
||||
|
||||
def get_tracking_uri(self):
|
||||
return self.exp_manager.get_tracking_uri()
|
||||
|
||||
def get_uri(self, type):
|
||||
return self.exp_manager.get_uri(type)
|
||||
|
||||
def get_recorder(self):
|
||||
return self.exp_manager.get_recorder()
|
||||
return self.exp_manager.active_recorder
|
||||
|
||||
def save_object(self, name, data):
|
||||
self.exp_manager.active_recorder.save_object(name, data)
|
||||
def save_object(self, data=None, name=None, local_path=None):
|
||||
self.exp_manager.active_recorder.save_object(data, name, local_path)
|
||||
|
||||
def save_objects(self, name_data_list):
|
||||
self.exp_manager.active_recorder.save_objects(name_data_list)
|
||||
def save_objects(self, data_name_list=None, local_path=None):
|
||||
self.exp_manager.active_recorder.save_objects(data_name_list, local_path)
|
||||
|
||||
def load_object(self, name):
|
||||
return self.exp_manager.active_recorder.load_object(name)
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
self.exp_manager.active_recorder.log_params(**kwargs)
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
|
||||
|
||||
def log_param(self, key, value):
|
||||
self.exp_manager.active_recorder.log_param(key, value)
|
||||
|
||||
def log_params(self, params):
|
||||
self.exp_manager.active_recorder.log_params(params)
|
||||
|
||||
def log_metric(self, key, value, step=None):
|
||||
self.exp_manager.active_recorder.log_metric(key, value, step)
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
self.exp_manager.active_recorder.log_metrics(metrics, step)
|
||||
|
||||
def set_tag(self, key, value):
|
||||
self.exp_manager.active_recorder.set_tag(key, value)
|
||||
|
||||
def set_tags(self, tags):
|
||||
self.exp_manager.active_recorder.set_tags(tags)
|
||||
def set_tags(self, **kwargs):
|
||||
self.exp_manager.active_recorder.set_tags(**kwargs)
|
||||
|
||||
def delete_tag(self, key):
|
||||
self.exp_manager.active_recorder.delete_tag(key)
|
||||
|
||||
def log_artifact(self, local_path, artifact_path=None):
|
||||
self.exp_manager.active_recorder.log_artifact(local_path, artifact_path)
|
||||
|
||||
def log_artifacts(self, local_dir, artifact_path=None):
|
||||
self.exp_manager.active_recorder.log_artifacts(local_dir, artifact_path)
|
||||
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
return self.exp_manager.active_recorder.get_artifact_uri(artifact_path)
|
||||
|
||||
# global record
|
||||
R = MLflowRecord()
|
||||
R = Wrapper()
|
||||
@@ -2,67 +2,23 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
from contextlib import contextmanager
|
||||
from .record import MLflowRecorder
|
||||
from pathlib import Path
|
||||
|
||||
class ExpManager:
|
||||
class Experiment:
|
||||
"""
|
||||
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
Thie is the `Experiment` class for each experiment being run. The API is designed
|
||||
"""
|
||||
def __init__(self):
|
||||
self.active_recorder = None
|
||||
self.experiments = dict() # store the experiment names -> list of recorders.
|
||||
self.exp_ids = list()
|
||||
|
||||
def _store_exp(self, id, name):
|
||||
"""
|
||||
Store the experiments in the experiments holder.
|
||||
"""
|
||||
if id in self.exp_ids:
|
||||
raise Exception('Something went wrong when creating the experiment. Please check if the experiment is already created.')
|
||||
if name in self.experiments:
|
||||
assert int(id) == int(self.experiments[name][0]), 'Experiment id and name are not consistent when storing the experiment.'
|
||||
else:
|
||||
self.exp_ids.append(id)
|
||||
self.experiments[name] = [id]
|
||||
self.name = None
|
||||
self.id = None
|
||||
self.recorders = list()
|
||||
|
||||
def start_exp(self, project_path, experiment_name=None, uri=None, artifact_location=None, nested=False):
|
||||
def search_records(self, **kwargs):
|
||||
"""
|
||||
Start running an experiment. This method can only work in the `with` statement.
|
||||
Get a pandas DataFrame of records that fit the search criteria of the experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
project_path : str
|
||||
path for the project.
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
uri : str
|
||||
the current tracking URI.
|
||||
artifact_location : str
|
||||
the location to store all the artifacts.
|
||||
nested : boolean
|
||||
controls whether run is nested in parent run.
|
||||
|
||||
Returns
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self):
|
||||
"""
|
||||
End an active experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
|
||||
"""
|
||||
Get a pandas DataFrame of runs that fit the search criteria.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_ids : list
|
||||
list of experiment IDs.
|
||||
filter_string : str
|
||||
filter query string, defaults to searching all runs.
|
||||
run_view_type : int
|
||||
@@ -74,192 +30,18 @@ class ExpManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
A pandas.DataFrame of runs.
|
||||
A pandas.DataFrame of records.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_runs` method.")
|
||||
|
||||
def get_exp(self, experiment_id):
|
||||
"""
|
||||
Retrieve an experiment by experiment_id from the backend store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An experiment object (e.g. mlflow.entities.Experiment).
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_exp` method.")
|
||||
|
||||
def get_exp_by_name(self, experiment_name):
|
||||
"""
|
||||
Retrieve an experiment by experiment name from the backend store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
the experiment name to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An experiment object (e.g. mlflow.entities.Experiment).
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_exp_by_name` method.")
|
||||
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
the experiment name, which must be unique.
|
||||
artifact_location : str
|
||||
the location to store run artifacts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
String id of created experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def set_exp(self, experiment_name):
|
||||
"""
|
||||
Set the experiment to be active.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
the experiment name, which must be unique.
|
||||
|
||||
Returns
|
||||
-------
|
||||
String id of created experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_exp` method.")
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
"""
|
||||
Delete an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def set_tracking_uri(self, uri):
|
||||
"""
|
||||
Set the tracking server URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
the uri of the tracking server, can be An empty string, or a local file path, prefixed with file:/.
|
||||
or An HTTP URI or A Databricks workspace.
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_tracking_uri` method.")
|
||||
|
||||
def get_tracking_uri(self):
|
||||
"""
|
||||
Get the tracking server URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
The tracking URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_tracking_uri` method.")
|
||||
|
||||
def get_recorder(self):
|
||||
"""
|
||||
Get the current active Recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
An Recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
|
||||
class MLflowExpManager(ExpManager):
|
||||
'''
|
||||
Use mlflow to implement ExpManager.
|
||||
'''
|
||||
def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
assert project_path is not None, "Please provide the project_path if no uri is provided in order to set a proper tracking uri."
|
||||
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the project path.')
|
||||
mlflow.set_tracking_uri(str(project_path / "mlruns"))
|
||||
else:
|
||||
mlflow.set_tracking_uri(uri)
|
||||
# start the experiment
|
||||
if experiment_name is None:
|
||||
print('No experiment name provided. The default experiment name is set as `experiment`.')
|
||||
experiment_id = self.create_exp('experiment', artifact_location)
|
||||
# set the active experiment
|
||||
self.set_exp('experiment')
|
||||
experiment_name = 'experiment'
|
||||
else:
|
||||
if experiment_name not in self.experiments:
|
||||
if self.get_exp_by_name(experiment_name) is not None:
|
||||
raise Exception('The experiment has already been created before. Please pick another name or delete the files under tracking uri.')
|
||||
experiment_id = self.create_exp(experiment_name, artifact_location)
|
||||
else:
|
||||
experiment_id = self.experiments(experiment_name)[0]
|
||||
# set the active experiment
|
||||
self.set_exp(experiment_name)
|
||||
|
||||
# store the id and name
|
||||
self._store_exp(experiment_id, experiment_name)
|
||||
# set up recorder
|
||||
recorder = MLflowRecorder(experiment_id)
|
||||
self.active_recorder = recorder
|
||||
# store the recorder
|
||||
self.experiments[experiment_name].append(self.active_recorder)
|
||||
|
||||
return self.active_recorder.start_run(experiment_id=experiment_id, nested=nested)
|
||||
|
||||
def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None):
|
||||
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def get_exp(self, experiment_id):
|
||||
return mlflow.get_experiment(experiment_id)
|
||||
|
||||
def get_exp_by_name(self, experiment_name):
|
||||
return mlflow.get_experiment_by_name(experiment_name)
|
||||
|
||||
def create_exp(self, experiment_name, artifact_location=None):
|
||||
return mlflow.create_experiment(experiment_name, artifact_location)
|
||||
|
||||
def set_exp(self, experiment_name):
|
||||
mlflow.set_experiment(experiment_name)
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
self.experiments = {key:val for key, val in self.experiments.items() if val[0] != experiment_id}
|
||||
|
||||
def set_tracking_uri(self, uri):
|
||||
mlflow.set_tracking_uri(uri)
|
||||
|
||||
def get_tracking_uri(self):
|
||||
return mlflow.get_tracking_uri()
|
||||
|
||||
def get_recorder(self):
|
||||
return self.active_recorder
|
||||
class MLflowExperiment(Experiment):
|
||||
"""
|
||||
Use mlflow to implement Experiment.
|
||||
"""
|
||||
def search_records(self, **kwargs):
|
||||
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
|
||||
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
|
||||
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
|
||||
order_by = kwargs.get('order_by')
|
||||
return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by)
|
||||
236
qlib/workflow/expm.py
Normal file
236
qlib/workflow/expm.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from .exp import MLflowExperiment
|
||||
from .record import MLflowRecorder
|
||||
|
||||
class ExpManager:
|
||||
"""
|
||||
This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow.
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
def __init__(self):
|
||||
self.default_uri = None
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
self.experiments = dict() # store the experiment name --> Experiment object
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None, **kwargs):
|
||||
"""
|
||||
Start running an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
uri : str
|
||||
the current tracking URI.
|
||||
artifact_location : str
|
||||
the location to store all the artifacts.
|
||||
nested : boolean
|
||||
controls whether run is nested in parent run.
|
||||
|
||||
Returns
|
||||
An object wrapped by context manager.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, **kwargs):
|
||||
"""
|
||||
End an running experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def search_records(self, experiment_ids=None, **kwargs):
|
||||
"""
|
||||
Get a pandas DataFrame of records that fit the search criteria.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_ids : list
|
||||
list of experiment IDs.
|
||||
filter_string : str
|
||||
filter query string, defaults to searching all runs.
|
||||
run_view_type : int
|
||||
one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
|
||||
max_results : int
|
||||
the maximum number of runs to put in the dataframe.
|
||||
order_by : list
|
||||
list of columns to order by (e.g., “metrics.rmse”).
|
||||
|
||||
Returns
|
||||
-------
|
||||
A pandas.DataFrame of runs.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def __create_exp(self, experiment_name, artifact_location=None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
the experiment name, which must be unique.
|
||||
artifact_location : str
|
||||
the location to store run artifacts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An experiment object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
Retrieve an experiment by experiment_id from the backend store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An experiment object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_exp` method.")
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
"""
|
||||
Delete an experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
the experiment id.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def get_uri(self, type):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
type : str
|
||||
the type of the tracking URI one wants to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
def get_recorder(self):
|
||||
"""
|
||||
Get the current active Recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Returns
|
||||
-------
|
||||
An Recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
|
||||
|
||||
class MLflowExpManager(ExpManager):
|
||||
'''
|
||||
Use mlflow to implement ExpManager.
|
||||
'''
|
||||
def __init__(self):
|
||||
super(MLflowExpManager, self).__init__()
|
||||
self.default_uri = None
|
||||
self.current_uri = None
|
||||
|
||||
def start_exp(self, experiment_name=None, uri=None):
|
||||
# create experiment
|
||||
experiment = self.__create_exp(experiment_name, uri)
|
||||
# set up recorder
|
||||
recorder = MLflowRecorder(experiment.id)
|
||||
self.active_recorder = recorder
|
||||
# store the recorder
|
||||
experiment.recorders.append(self.active_recorder)
|
||||
# store the experiment
|
||||
self.experiments[experiment_name] = experiment
|
||||
|
||||
return self.active_recorder.start_run(experiment_id=experiment.id)
|
||||
|
||||
def end_exp(self):
|
||||
self.active_recorder.end_run()
|
||||
self.active_recorder = None
|
||||
|
||||
def __create_exp(self, experiment_name=None, uri=None):
|
||||
# init experiment
|
||||
experiment = MLflowExperiment()
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.')
|
||||
else:
|
||||
self.current_uri = uri
|
||||
mlflow.set_tracking_uri(self.current_uri)
|
||||
# start the experiment
|
||||
if experiment_name is None:
|
||||
print('No experiment name provided. The default experiment name is set as `experiment`.')
|
||||
experiment_id = mlflow.create_experiment('experiment')
|
||||
# set the active experiment
|
||||
mlflow.set_experiment('experiment')
|
||||
experiment_name = 'experiment'
|
||||
else:
|
||||
if experiment_name not in self.experiments:
|
||||
if mlflow.get_experiment_by_name(experiment_name) is not None:
|
||||
raise Exception('The experiment has already been created before. Please pick another name or delete the files under uri.')
|
||||
experiment_id = mlflow.create_experiment(experiment_name)
|
||||
else:
|
||||
experiment_id = self.experiments[experiment_name].id
|
||||
experiment = self.experiments[experiment_name]
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(experiment_name)
|
||||
# set up experiment
|
||||
experiment.id = experiment_id
|
||||
experiment.name = experiment_name
|
||||
|
||||
return experiment
|
||||
|
||||
def search_records(self, experiment_ids, **kwargs):
|
||||
filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string')
|
||||
run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type')
|
||||
max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results')
|
||||
order_by = kwargs.get('order_by')
|
||||
return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None):
|
||||
assert experiment_id is not None or experiment_name is not None, 'Please provide at least one of the experiment id or name to retrieve an experiment.'
|
||||
if experiment_name is not None:
|
||||
return self.experiments[experiment_name]
|
||||
elif:
|
||||
for name in self.experiments:
|
||||
if self.experiments[name].id == experiment_id:
|
||||
return self.experiments[name]
|
||||
else:
|
||||
print('No valid experiment is found. Please make sure the id and name are correctly given.')
|
||||
|
||||
def delete_exp(self, experiment_id):
|
||||
mlflow.delete_experiment(experiment_id)
|
||||
self.experiments = {key:val for key, val in self.experiments.items() if val.id != experiment_id}
|
||||
|
||||
def get_uri(self, type):
|
||||
if uri == 'default':
|
||||
return self.default_uri
|
||||
elif uri == 'current':
|
||||
return self.current_uri
|
||||
else:
|
||||
raise ValueError('Input type is not supported. Please choose type default or current to get the uri.')
|
||||
|
||||
def get_recorder(self):
|
||||
return self.active_recorder
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import shutil
|
||||
import shutil, os, pickle, tempfile, codecs
|
||||
from pathlib import Path
|
||||
from ..utils.objm import FileManager
|
||||
|
||||
@@ -12,45 +12,39 @@ class Recorder:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_id, project_path=None):
|
||||
def __init__(self, experiment_id):
|
||||
self.experiment_id = experiment_id
|
||||
self.recorder_id = None
|
||||
self.recorder_name = None
|
||||
self.fm = None
|
||||
self.artifact_uri = None
|
||||
|
||||
def set_recorder_name(self, rname):
|
||||
self.recorder_name = rname
|
||||
|
||||
def save_object(self, name, data):
|
||||
def save_object(self, data, name, local_path=None):
|
||||
"""
|
||||
Save object such as prediction file or model checkpoints.
|
||||
Save object such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
name of the file to be saved.
|
||||
data : any type
|
||||
the data to be saved.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None.
|
||||
name : str
|
||||
name of the file to be saved.
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_object` method.")
|
||||
|
||||
def save_objects(self, name_data_list):
|
||||
def save_objects(self, data_name_list, local_path=None):
|
||||
"""
|
||||
Save objects such as prediction file or model checkpoints.
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name_data_list : list
|
||||
list of (name, data) pairs
|
||||
|
||||
Returns
|
||||
-------
|
||||
None.
|
||||
data_name_list : list
|
||||
list of (data, name) pairs
|
||||
local_path : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `save_objects` method.")
|
||||
|
||||
@@ -98,99 +92,36 @@ class Recorder:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_run` method.")
|
||||
|
||||
def log_param(self, key, value):
|
||||
"""
|
||||
Log a parameter under the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
the name of the parameter
|
||||
value : str
|
||||
the value of the parameter
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_param` method.")
|
||||
|
||||
def log_params(self, params):
|
||||
def log_params(self, **kwargs):
|
||||
"""
|
||||
Log a batch of params for the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params : dict
|
||||
dictionary of param_name: String -> value: String.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
keyword arguments
|
||||
key, value pair to be logged as parameters.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_params` method.")
|
||||
|
||||
def log_metric(self, key, value, step=None):
|
||||
"""
|
||||
Log a metric under the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
the name of the metric
|
||||
value : float
|
||||
the value of the metric
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_metric` method.")
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
"""
|
||||
Log multiple metrics for the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metrics : dict
|
||||
dictionary of metric_name: String -> value: Float.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
keyword arguments
|
||||
key, value pair to be logged as metrics.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_metrics` method.")
|
||||
|
||||
def set_tag(self, key, value):
|
||||
"""
|
||||
Set a tag under the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
the name of the tag
|
||||
value : str
|
||||
the value of the tag
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_tag` method.")
|
||||
|
||||
def set_tags(self, tags):
|
||||
def set_tags(self, **kwargs):
|
||||
"""
|
||||
Log a batch of tags for the current run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tags : dict
|
||||
dictionary of tag_name: String -> value: String.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
keyword arguments
|
||||
key, value pair to be logged as tags.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_tags` method.")
|
||||
|
||||
@@ -202,67 +133,22 @@ class Recorder:
|
||||
----------
|
||||
key : str
|
||||
the name of the tag to be deleted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_tag` method.")
|
||||
|
||||
def log_artifact(self, local_path, artifact_path=None):
|
||||
"""
|
||||
Log a local file or directory as an artifact of the currently active run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_path : str
|
||||
path to the file to write.
|
||||
artifact_path : str
|
||||
the directory in `artifact_uri` to write to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_artifact` method.")
|
||||
|
||||
def log_artifacts(self, local_dir, artifact_path=None):
|
||||
"""
|
||||
Log all the contents of a local directory as artifacts of the run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_dir : str
|
||||
path to the directory of files to write.
|
||||
artifact_path : str
|
||||
the directory in `artifact_uri` to write to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `log_artifacts` method.")
|
||||
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
"""
|
||||
Get the absolute URI of the specified artifact in the currently active run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
artifact_path : str
|
||||
the directory in `artifact_uri` to write to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An absolute URI referring to the specified artifact or currently active Recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_artifact_uri` method.")
|
||||
|
||||
|
||||
class MLflowRecorder(Recorder):
|
||||
'''
|
||||
Use mlflow to implement a Recorder.
|
||||
|
||||
Due to the fact that mlflow will only log artifact from a file or directory, we decide to
|
||||
use file manager to help maintain the objects in the project.
|
||||
'''
|
||||
def __init__(self, experiment_id):
|
||||
super(MLflowRecorder, self).__init__(experiment_id)
|
||||
self.fm = None
|
||||
self.temp_dir = None
|
||||
|
||||
def start_run(self, run_id=None, experiment_id=None,
|
||||
run_name=None, nested=False):
|
||||
if run_id is None:
|
||||
@@ -277,65 +163,67 @@ class MLflowRecorder(Recorder):
|
||||
self.recorder_id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
# set up file manager for saving objects
|
||||
if self.artifact_uri.startswith('file:/'):
|
||||
self.fm = FileManager(Path(urllib.parse.urlparse(self.artifact_uri).path))
|
||||
else:
|
||||
self.fm = FileManager(Path(self.artifact_uri))
|
||||
print(self.artifact_uri)
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.fm = FileManager(Path(self.temp_dir).absolute())
|
||||
return run
|
||||
|
||||
def end_run(self):
|
||||
mlflow.end_run()
|
||||
shutil.rmtree(self.temp_dir)
|
||||
|
||||
def save_object(self, name, data):
|
||||
self.fm.save_obj(data, name)
|
||||
import urllib
|
||||
print(urllib.parse.urlparse(self.artifact_uri).scheme)
|
||||
try:
|
||||
self.log_artifact(self.fm.path / name)
|
||||
except shutil.SameFileError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
def save_object(self, data, name, local_path=None):
|
||||
if local_path is None:
|
||||
assert data is not None and name is not None, "Please provide data and name input."
|
||||
self.fm.save_obj(data, name)
|
||||
mlflow.log_artifact(self.fm.path / name)
|
||||
self.fm.remove(name)
|
||||
else:
|
||||
mlflow.log_artifact(local_path)
|
||||
|
||||
def save_objects(self, name_data_list):
|
||||
self.fm.save_objs(name_data_list)
|
||||
try:
|
||||
self.log_artifacts(self.fm.path)
|
||||
except shutil.SameFileError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
def save_objects(self, data_name_list, local_path=None):
|
||||
if local_path is None:
|
||||
assert data_name_list is not None, "Please provide data_name_list input."
|
||||
self.fm.save_objs(data_name_list)
|
||||
mlflow.log_artifacts(self.fm.path)
|
||||
for obj, name in data_name_list:
|
||||
self.fm.remove(name)
|
||||
else:
|
||||
mlflow.log_artifacts(local_path)
|
||||
|
||||
def load_object(self, name):
|
||||
return self.fm.load_obj(name)
|
||||
client = mlflow.tracking.MlflowClient()
|
||||
path = client.download_artifacts(self.recorder_id, name)
|
||||
try:
|
||||
with Path(path).open('rb') as f:
|
||||
f.seek(0)
|
||||
return pickle.load(f)
|
||||
except:
|
||||
with codecs.open(path, mode="r", encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.log_param(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.log_params(dict(kwargs))
|
||||
|
||||
def log_param(self, key, value):
|
||||
mlflow.log_param(key, value)
|
||||
|
||||
def log_params(self, params):
|
||||
mlflow.log_params(params)
|
||||
|
||||
def log_metric(self, key, value, step=None):
|
||||
mlflow.log_metric(key, value, step)
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
mlflow.log_metrics(metrics, step)
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.log_metric(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.log_metrics(dict(kwargs))
|
||||
|
||||
def set_tag(self, key, value):
|
||||
mlflow.set_tag(key, value)
|
||||
|
||||
def set_tags(self, tags):
|
||||
mlflow.set_tags(tags)
|
||||
def set_tags(self, **kwargs):
|
||||
keys = list(kwargs.keys())
|
||||
if len(keys) == 0:
|
||||
mlflow.set_tag(keys[0], kwargs.get(keys[0]))
|
||||
else:
|
||||
mlflow.set_tags(dict(kwargs))
|
||||
|
||||
def delete_tag(self, key):
|
||||
mlflow.delete_tag(key)
|
||||
|
||||
def log_artifact(self, local_path, artifact_path=None):
|
||||
mlflow.log_artifact(local_path, artifact_path)
|
||||
|
||||
def log_artifacts(self, local_dir, artifact_path=None):
|
||||
mlflow.log_artifacts(local_dir, artifact_path)
|
||||
|
||||
def get_artifact_uri(self, artifact_path=None):
|
||||
if self.artifact_uri is not None:
|
||||
|
||||
Reference in New Issue
Block a user