1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00

Add RecordTemp & update

This commit is contained in:
Jactus
2020-11-02 11:05:40 +08:00
parent da9d1c8ac6
commit 5f9c8be33d
11 changed files with 263 additions and 81 deletions

View File

@@ -82,12 +82,11 @@ def init(default_conf="client", **kwargs):
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# set up QlibRecorder
default_uri = str(Path(os.getcwd()).resolve() / "mlruns")
current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri
uri = C["exp_uri"]
# exp manager module
module = get_module_by_module_path("qlib.workflow")
module = get_module_by_module_path("qlib.workflow.expm")
exp_manager = init_instance_by_config(C["exp_manager"], module)
qr = QlibRecorder(exp_manager, default_uri, current_uri)
qr = QlibRecorder(exp_manager, uri)
R.register(qr)

View File

@@ -126,7 +126,7 @@ _default_config = {
},
# Default config for experiment manager
"exp_manager": {"class": "MLflowExpManager", "kwargs": {}},
"exp_uri": None,
"exp_uri": str(Path(os.getcwd()).resolve() / "mlruns"),
}
MODE_CONF = {

View File

@@ -24,7 +24,7 @@ from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, get_provider_obj, register_wrapper
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
class CalendarProvider(abc.ABC):
@@ -1031,34 +1031,44 @@ D = Wrapper()
def register_all_wrappers():
    """Instantiate the configured providers and register them on the global wrappers.

    Each provider is built from its entry in ``C`` with classes looked up in
    ``qlib.data``. When a matching ``*_cache`` entry is configured, the provider
    is wrapped by the cache instance, which receives it as its ``provider``
    keyword argument. ``Inst`` and ``D`` are registered directly from their
    config values.
    """
    logger = get_module_logger("data")
    module = get_module_by_module_path("qlib.data")

    def _wrap_with_cache(cache_config, provider):
        # Build a private copy (including the nested kwargs dict) so that the
        # global configuration in `C` is never mutated by injecting `provider`.
        config = dict(cache_config)
        config["kwargs"] = dict(config.get("kwargs") or {}, provider=provider)
        return init_instance_by_config(config, module)

    # calendar
    _calendar_provider = init_instance_by_config(C.calendar_provider, module)
    calendar_cache = getattr(C, "calendar_cache", None)
    if calendar_cache is not None:
        _calendar_provider = _wrap_with_cache(calendar_cache, _calendar_provider)
    register_wrapper(Cal, _calendar_provider, "qlib.data")
    logger.debug(f"registering Cal {C.calendar_provider}-{calendar_cache}")

    # instrument
    register_wrapper(Inst, C.instrument_provider, "qlib.data")
    logger.debug(f"registering Inst {C.instrument_provider}")

    # feature
    if getattr(C, "feature_provider", None) is not None:
        feature_provider = init_instance_by_config(C.feature_provider, module)
        register_wrapper(FeatureD, feature_provider, "qlib.data")
        logger.debug(f"registering FeatureD {C.feature_provider}")

    # expression
    if getattr(C, "expression_provider", None) is not None:
        # This provider is unnecessary in client provider
        _eprovider = init_instance_by_config(C.expression_provider, module)
        expression_cache = getattr(C, "expression_cache", None)
        if expression_cache is not None:
            _eprovider = _wrap_with_cache(expression_cache, _eprovider)
        register_wrapper(ExpressionD, _eprovider, "qlib.data")
        logger.debug(f"registering ExpressioneD {C.expression_provider}-{expression_cache}")

    # dataset
    _dprovider = init_instance_by_config(C.dataset_provider, module)
    dataset_cache = getattr(C, "dataset_cache", None)
    if dataset_cache is not None:
        _dprovider = _wrap_with_cache(dataset_cache, _dprovider)
    register_wrapper(DatasetD, _dprovider, "qlib.data")
    logger.debug(f"registering DataseteD {C.dataset_provider}-{dataset_cache}")

    # top-level provider
    register_wrapper(D, C.provider, "qlib.data")
    logger.debug(f"registering D {C.provider}")

View File

@@ -632,21 +632,14 @@ class Wrapper(object):
return getattr(self._provider, key)
def get_provider_obj(config, **params):
module = get_module_by_module_path("qlib.data")
klass, kwargs = get_cls_kwargs(config, module)
kwargs.update(params)
return klass(**kwargs)
def register_wrapper(wrapper, cls_or_obj, module_path=None):
    """Register a provider on a wrapper.

    :param wrapper: A wrapper exposing ``register(obj)``.
    :param cls_or_obj: A class, a class name, or an object instance. A name is
        resolved to an attribute of the module at ``module_path``; a class is
        instantiated with no arguments; an instance is registered as-is.
    :param module_path: Module used to resolve ``cls_or_obj`` when it is a
        name. Defaults to ``"qlib.data"``, which keeps the original
        two-argument call working.
    """
    if isinstance(cls_or_obj, str):
        resolved_path = "qlib.data" if module_path is None else module_path
        module = get_module_by_module_path(resolved_path)
        cls_or_obj = getattr(module, cls_or_obj)
    obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
    wrapper.register(obj)

View File

@@ -2,24 +2,29 @@
# Licensed under the MIT License.
from contextlib import contextmanager
from .expm import *
from .expm import MLflowExpManager
from ..utils import Wrapper
class QlibRecorder:
def __init__(self, exp_manager, default_uri, current_uri):
"""
A global system that helps to manage the experiments.
"""
def __init__(self, exp_manager, uri):
self.exp_manager = exp_manager
self.default_uri = default_uri
self.current_uri = current_uri
self.uri = uri
@contextmanager
def start(self, experiment_name):
    """Run an experiment for the duration of a ``with`` block.

    Starts the experiment through ``start_exp`` and yields its return value.
    The experiment is ended exactly once when the block exits, whether it
    finished normally or raised; an exception raised inside the block
    propagates to the caller instead of being swallowed.

    Parameters
    ----------
    experiment_name : str
        name of the experiment to start.
    """
    run = self.start_exp(experiment_name)
    try:
        yield run
    finally:
        # always close the experiment, including when the body raised
        self.end_exp()
def start_exp(self, experiment_name=None):
return self.exp_manager.start_exp(experiment_name, self.current_uri)
return self.exp_manager.start_exp(experiment_name, self.uri)
def end_exp(self):
self.exp_manager.end_exp()
@@ -33,8 +38,8 @@ class QlibRecorder:
def delete_exp(self, experiment_id):
self.exp_manager.delete_exp(experiment_id)
def get_uri(self, type):
return self.exp_manager.get_uri(type)
def get_uri(self):
return self.exp_manager.get_uri()
def get_recorder(self):
return self.exp_manager.active_recorder

View File

@@ -3,7 +3,7 @@
import mlflow
from pathlib import Path
from .recorder import MLflowRecorder
class Experiment:
"""
@@ -15,6 +15,19 @@ class Experiment:
self.id = None
self.recorders = list()
def create_recorder(self):
    """
    Create a recorder for this experiment.

    This base implementation only raises; subclasses are expected to
    override it.

    Returns
    -------
    A recorder instance.

    Raises
    ------
    NotImplementedError
        always, in this base implementation.
    """
    raise NotImplementedError(f"Please implement the `create_recorder` method.")
def search_records(self, **kwargs):
"""
Get a pandas DataFrame of records that fit the search criteria of the experiment.
@@ -36,15 +49,39 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def delete_recorder(self, rid):
    """
    Delete a recorder of this experiment.

    This base implementation only raises; subclasses are expected to
    override it.

    Parameters
    ----------
    rid : str
        the id of the recorder to be deleted.

    Raises
    ------
    NotImplementedError
        always, in this base implementation.
    """
    raise NotImplementedError(f"Please implement the `delete_recorder` method.")
class MLflowExperiment(Experiment):
    """
    Use mlflow to implement Experiment.
    """

    def create_recorder(self):
        """Create an `MLflowRecorder` bound to this experiment's id, track it and return it."""
        recorder = MLflowRecorder(self.id)
        self.recorders.append(recorder)
        return recorder

    def search_records(self, **kwargs):
        """
        Search the runs of this experiment through `mlflow.search_runs`.

        Parameters
        ----------
        filter_string : str, optional
            filter query; defaults to "" when missing or None.
        run_view_type : int, optional
            defaults to 1 when missing or None.
            NOTE(review): presumably mlflow's "active only" view type — confirm.
        max_results : int, optional
            defaults to 100000 when missing or None.
        order_by : optional
            passed through unchanged (None when missing).

        Returns
        -------
        Whatever `mlflow.search_runs` returns for this experiment.
        """
        filter_string = kwargs.get("filter_string")
        if filter_string is None:
            filter_string = ""
        run_view_type = kwargs.get("run_view_type")
        if run_view_type is None:
            run_view_type = 1
        max_results = kwargs.get("max_results")
        if max_results is None:
            max_results = 100000
        order_by = kwargs.get("order_by")
        # The experiment id lives in `self.id`, the same attribute
        # `create_recorder` passes to the recorder.
        return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by)

    def delete_recorder(self, rid):
        """Delete the mlflow run `rid` and stop tracking its recorder."""
        mlflow.delete_run(rid)
        # keep every recorder except the one that was just deleted
        self.recorders = [r for r in self.recorders if r.recorder_id != rid]

View File

@@ -6,8 +6,10 @@ import os
from pathlib import Path
from contextlib import contextmanager
from .exp import MLflowExperiment
from .record import MLflowRecorder
from .recorder import MLflowRecorder
from ..log import get_module_logger
logger = get_module_logger('workflow', 'Warning')
class ExpManager:
"""
@@ -16,7 +18,7 @@ class ExpManager:
"""
def __init__(self):
self.default_uri = None
self.uri = None
self.active_recorder = None # only one recorder can running each time
self.experiments = dict() # store the experiment name --> Experiment object
@@ -117,20 +119,18 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
def get_uri(self, type):
def get_uri(self):
"""
Get the default tracking URI or current URI.
Parameters
----------
type : str
the type of the tracking URI one wants to retrieve.
Returns
-------
The tracking URI string.
"""
raise NotImplementedError(f"Please implement the `create_exp` method.")
return self.uri
def get_recorder(self):
"""
@@ -143,7 +143,7 @@ class ExpManager:
-------
An Recorder object.
"""
raise NotImplementedError(f"Please implement the `get_recorder` method.")
return self.active_recorder
class MLflowExpManager(ExpManager):
@@ -153,17 +153,14 @@ class MLflowExpManager(ExpManager):
def __init__(self):
super(MLflowExpManager, self).__init__()
self.default_uri = None
self.current_uri = None
self.uri = None
def start_exp(self, experiment_name=None, uri=None):
# create experiment
experiment = self.__create_exp(experiment_name, uri)
# set up recorder
recorder = MLflowRecorder(experiment.id)
recorder = experiment.create_recorder()
self.active_recorder = recorder
# store the recorder
experiment.recorders.append(self.active_recorder)
# store the experiment
self.experiments[experiment_name] = experiment
@@ -178,15 +175,15 @@ class MLflowExpManager(ExpManager):
experiment = MLflowExperiment()
# set the tracking uri
if uri is None:
print(
logger.warning(
"No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory."
)
else:
self.current_uri = uri
mlflow.set_tracking_uri(self.current_uri)
self.uri = uri
mlflow.set_tracking_uri(self.uri)
# start the experiment
if experiment_name is None:
print("No experiment name provided. The default experiment name is set as `experiment`.")
logger.warning("No experiment name provided. The default experiment name is set as `experiment`.")
experiment_id = mlflow.create_experiment("experiment")
# set the active experiment
mlflow.set_experiment("experiment")
@@ -227,19 +224,8 @@ class MLflowExpManager(ExpManager):
if self.experiments[name].id == experiment_id:
return self.experiments[name]
else:
print("No valid experiment is found. Please make sure the id and name are correctly given.")
raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.")
def delete_exp(self, experiment_id):
mlflow.delete_experiment(experiment_id)
self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id}
def get_uri(self, type):
if uri == "default":
return self.default_uri
elif uri == "current":
return self.current_uri
else:
raise ValueError("Input type is not supported. Please choose type default or current to get the uri.")
def get_recorder(self):
return self.active_recorder

View File

@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from pathlib import Path
from ..contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from ..utils import init_instance_by_config, get_module_by_module_path
class RecordTemp:
    """Template for record generators; subclasses supply `generate` and `check`."""

    def __init__(self, *args, **kwargs):
        # The template itself keeps no state; arguments are accepted and ignored.
        return None

    def generate(self, **kwargs):
        """
        Produce records (for example IC or backtest results) and save them.

        Parameters
        ----------
        kwargs
            options defined by the concrete subclass.

        Return
        ------
        The generated records.
        """
        message = "Please implement the `generate` method."
        raise NotImplementedError(message)

    def check(self, **kwargs):
        """
        Verify that the records were generated and saved properly.

        Parameters
        ----------
        kwargs
            options defined by the concrete subclass.
        """
        message = "Please implement the `check` method."
        raise NotImplementedError(message)
# TODO: this can only be run under R's running experiment.
class SignalRecord(RecordTemp):
    """Record that stores a model's prediction as the artifact `pred.pkl`."""

    def __init__(self, model, dataset, recorder, **kwargs):
        """
        Parameters
        ----------
        model
            object whose `predict(dataset)` produces the prediction.
        dataset
            dataset handed to `model.predict`.
        recorder
            recorder used to save, load and check the artifact.
        """
        super(SignalRecord, self).__init__()
        self.model = model
        self.dataset = dataset
        self.recorder = recorder

    def generate(self, **kwargs):
        """Run the model on the dataset and save the prediction through the recorder."""
        # generate prediction
        pred = self.model.predict(self.dataset)
        self.recorder.save_object(pred, 'pred.pkl')

    def load(self):
        """
        Load the saved prediction.

        Raises
        ------
        Exception
            if the recorder fails to load `pred.pkl`; the original error is
            chained as the cause.
        """
        # try to load the saved object
        try:
            return self.recorder.load_object('pred.pkl')
        except Exception as err:
            raise Exception('Something went wrong when loading the saved object.') from err

    def check(self, **kwargs):
        """Return the recorder's answer on whether `pred.pkl` has been saved."""
        return self.recorder.check('pred.pkl')
# TODO
class SigAnaRecord(SignalRecord):
    """Placeholder for signal-analysis records; every operation is still a no-op."""

    def __init__(self, recorder, **kwargs):
        # No model or dataset is involved here; only the recorder is kept.
        super().__init__(model=None, dataset=None, recorder=recorder, **kwargs)

    def generate(self, **kwargs):
        """Not implemented yet; does nothing."""
        pass

    def load(self):
        """Not implemented yet; does nothing and returns None."""
        pass

    def check(self, **kwargs):
        """Not implemented yet; does nothing and returns None."""
        pass
class PortAnaRecord(SignalRecord):
    """Record that backtests a saved prediction and stores the portfolio analysis."""

    def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs):
        """
        Parameters
        ----------
        recorder
            recorder used to save, load and check artifacts.
        STRATEGY_CONFIG : dict
            define the strategy class as well as the kwargs.
        BACKTEST_CONFIG : dict
            define the backtest kwargs.
        """
        # This record only consumes an already saved prediction, so there is
        # no model or dataset to keep.
        super().__init__(model=None, dataset=None, recorder=recorder)
        self.STRATEGY_CONFIG = STRATEGY_CONFIG
        self.BACKTEST_CONFIG = BACKTEST_CONFIG
        module = get_module_by_module_path("qlib.contrib.strategy")
        self.strategy = init_instance_by_config(STRATEGY_CONFIG, module)
        # Artifact paths are relative to the recorder's artifact root, so this
        # must not be resolved against the local working directory.
        self.artifact_path = 'portfolio_analysis'

    def generate(self, **kwargs):
        """
        Backtest the saved prediction and save the report, positions and analysis.

        The prediction saved by the parent record must already exist. Three
        artifacts are stored under `self.artifact_path`: `report_normal.pkl`,
        `positions_normal.pkl` and `port_analysis.pkl`.
        """
        # check previously stored prediction results
        assert super().check(), "Make sure the parent process is completed and store the data properly."

        # custom strategy and get backtest
        pred_score = super().load()
        report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG)
        # `artifact_path` is passed by keyword: the third positional parameter
        # of `save_object` is `local_path`.
        self.recorder.save_object(report_normal, 'report_normal.pkl', artifact_path=self.artifact_path)
        self.recorder.save_object(positions_normal, 'positions_normal.pkl', artifact_path=self.artifact_path)

        # analysis
        analysis = dict()
        analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
        analysis["excess_return_with_cost"] = risk_analysis(
            report_normal["return"] - report_normal["bench"] - report_normal["cost"]
        )
        analysis_df = pd.concat(analysis)  # type: pd.DataFrame
        self.recorder.save_object(analysis_df, 'port_analysis.pkl', artifact_path=self.artifact_path)

    def load(self):
        """
        Load the saved portfolio analysis.

        Raises
        ------
        Exception
            if the recorder fails to load the artifact; the original error is
            chained as the cause.
        """
        # try to load the saved object
        try:
            return self.recorder.load_object(f"{self.artifact_path}/port_analysis.pkl")
        except Exception as err:
            raise Exception('Something went wrong when loading the saved object.') from err

    def check(self, **kwargs):
        """Return the recorder's answer on whether `port_analysis.pkl` has been saved."""
        return self.recorder.check('port_analysis.pkl', self.artifact_path)

View File

@@ -21,7 +21,7 @@ class Recorder:
def set_recorder_name(self, rname):
self.recorder_name = rname
def save_object(self, data, name, local_path=None):
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
"""
Save object such as prediction file or model checkpoints to the artifact URI.
@@ -33,10 +33,12 @@ class Recorder:
name of the file to be saved.
local_path : str
if provided, them save the file or directory to the artifact URI.
artifact_path=None : str
the relative path for the artifact to be stored in the URI.
"""
raise NotImplementedError(f"Please implement the `save_object` method.")
def save_objects(self, data_name_list, local_path=None):
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
"""
Save objects such as prediction file or model checkpoints to the artifact URI.
@@ -46,6 +48,8 @@ class Recorder:
list of (data, name) pairs
local_path : str
if provided, them save the file or directory to the artifact URI.
artifact_path=None : str
the relative path for the artifact to be stored in the URI.
"""
raise NotImplementedError(f"Please implement the `save_objects` method.")
@@ -162,6 +166,7 @@ class MLflowRecorder(Recorder):
# save the run id and artifact_uri
self.recorder_id = run.info.run_id
self.artifact_uri = run.info.artifact_uri
self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder
# set up file manager for saving objects
self.temp_dir = tempfile.mkdtemp()
self.fm = FileManager(Path(self.temp_dir).absolute())
@@ -171,27 +176,27 @@ class MLflowRecorder(Recorder):
mlflow.end_run()
shutil.rmtree(self.temp_dir)
def save_object(self, data, name, local_path=None):
def save_object(self, data=None, name=None, local_path=None, artifact_path=None):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
if local_path is None:
assert data is not None and name is not None, "Please provide data and name input."
self.fm.save_obj(data, name)
mlflow.log_artifact(self.fm.path / name)
self.fm.remove(name)
client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path)
else:
mlflow.log_artifact(local_path)
assert local_path is not None, "Please provide a valid local path for the "
client.log_artifact(self.recorder_id, local_path, artifact_path)
def save_objects(self, data_name_list, local_path=None):
def save_objects(self, data_name_list=None, local_path=None, artifact_path=None):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
if local_path is None:
assert data_name_list is not None, "Please provide data_name_list input."
self.fm.save_objs(data_name_list)
mlflow.log_artifacts(self.fm.path)
for obj, name in data_name_list:
self.fm.remove(name)
client.log_artifacts(self.recorder_id, self.fm.path, artifact_path)
else:
mlflow.log_artifacts(local_path)
client.log_artifacts(self.recorder_id, local_path, artifact_path)
def load_object(self, name):
client = mlflow.tracking.MlflowClient()
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
path = client.download_artifacts(self.recorder_id, name)
try:
with Path(path).open("rb") as f:
@@ -229,3 +234,11 @@ class MLflowRecorder(Recorder):
if self.artifact_uri is not None:
return self.artifact_uri
return mlflow.get_artifact_uri(artifact_path)
def check(self, name, path=None):
    """
    Check whether an artifact has been logged for this recorder's run.

    Parameters
    ----------
    name : str
        text looked for inside each listed artifact's path (substring match).
    path : str, optional
        artifact sub-directory to list; passed through to `list_artifacts`.

    Returns
    -------
    bool
        True if any listed artifact path contains `name`, otherwise False.
    """
    client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
    artifacts = client.list_artifacts(self.recorder_id, path)
    return any(name in artifact.path for artifact in artifacts)

View File

@@ -22,3 +22,4 @@ scikit_learn==0.23.2
torch==1.6.0
tqdm==4.49.0
yahooquery==2.2.7
mlflow==1.11.0

View File

@@ -50,6 +50,7 @@ REQUIRED = [
"matplotlib==3.1.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
"mlflow>=1.10.0",
"tqdm",
"loguru",
"lightgbm",