1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Merge pull request #302 from D-X-Y/main

Update repr for dataset/workflow classes and add uri kwarg for QlibRecorder
This commit is contained in:
you-n-g
2021-03-08 14:01:53 +08:00
committed by GitHub
9 changed files with 134 additions and 46 deletions

2
.gitignore vendored
View File

@@ -34,3 +34,5 @@ tags
.pytest_cache/
.vscode/
*.swp

View File

@@ -1,5 +1,5 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
@@ -76,7 +76,7 @@ class DatasetH(Dataset):
- The processing is related to data split.
"""
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict):
"""
Parameters
----------
@@ -87,7 +87,7 @@ class DatasetH(Dataset):
"""
super().__init__(handler, segments)
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
def init(self, handler_kwargs: Optional[Dict] = None, segment_kwargs: Optional[Dict] = None):
"""
Initialize the DatasetH
@@ -124,7 +124,7 @@ class DatasetH(Dataset):
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
self.segments = segment_kwargs.copy()
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
"""
Setup the underlying data.
@@ -156,6 +156,11 @@ class DatasetH(Dataset):
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
def __repr__(self):
    """Return ``ClassName(handler=..., segments=...)`` for debugging output."""
    template = "{name}(handler={handler}, segments={segments})"
    fields = {
        "name": self.__class__.__name__,
        "handler": self.handler,
        "segments": self.segments,
    }
    return template.format(**fields)
def _prepare_seg(self, slc: slice, **kwargs):
"""
Give a slice, retrieve the according data
@@ -168,7 +173,7 @@ class DatasetH(Dataset):
def prepare(
self,
segments: Union[List[str], Tuple[str], str, slice],
segments: Union[List[Text], Tuple[Text], Text, slice],
col_set=DataHandler.CS_ALL,
data_key=DataHandlerLP.DK_I,
**kwargs,
@@ -178,7 +183,7 @@ class DatasetH(Dataset):
Parameters
----------
segments : Union[List[str], Tuple[str], str, slice]
segments : Union[List[Text], Tuple[Text], Text, slice]
Describe the scope of the data to be prepared
Here are some examples:

View File

@@ -35,7 +35,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
Any order of the index level can be suported(The order will implied in the data).
Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -47,8 +47,8 @@ class DataHandler(Serializable):
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
datetime instrument
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
"""

View File

@@ -1489,7 +1489,7 @@ OpsList = [
]
class OpsWrapper(object):
class OpsWrapper:
"""Ops Wrapper"""
def __init__(self):

View File

@@ -16,8 +16,11 @@ class QlibRecorder:
def __init__(self, exp_manager):
self.exp_manager = exp_manager
def __repr__(self):
    """Return ``ClassName(manager=...)``, embedding the wrapped experiment manager."""
    class_name = self.__class__.__name__
    manager = self.exp_manager
    return "{name}(manager={manager})".format(name=class_name, manager=manager)
@contextmanager
def start(self, experiment_name=None, recorder_name=None):
def start(self, experiment_name=None, recorder_name=None, uri=None):
"""
Method to start an experiment. This method can only be called within a Python `with` statement. Here is the example code:
@@ -34,8 +37,13 @@ class QlibRecorder:
name of the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
Therefore, the next time when users call this function in the same experiment,
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
"""
run = self.start_exp(experiment_name, recorder_name)
run = self.start_exp(experiment_name, recorder_name, uri)
try:
yield run
except Exception as e:
@@ -272,7 +280,7 @@ class QlibRecorder:
-------
The uri of current experiment manager.
"""
return self.exp_manager.get_uri()
return self.exp_manager.uri
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
"""

View File

@@ -23,7 +23,7 @@ class Experiment:
self.active_recorder = None # only one recorder can running each time
def __repr__(self):
return str(self.info)
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
def __str__(self):
return str(self.info)
@@ -173,11 +173,9 @@ class MLflowExperiment(Experiment):
self._uri = uri
self._default_name = None
self._default_rec_name = "mlflow_recorder"
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
def start(self, recorder_name=None):
# set the active experiment
mlflow.set_experiment(self.name)
logger.info(f"Experiment {self.id} starts running ...")
# set up recorder
recorder = self.create_recorder(recorder_name)
@@ -210,7 +208,6 @@ class MLflowExperiment(Experiment):
else:
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
if is_new:
mlflow.set_experiment(self.name)
self.active_recorder = recorder
# start the recorder
self.active_recorder.start_run()
@@ -239,7 +236,7 @@ class MLflowExperiment(Experiment):
), "Please input at least one of recorder id or name before retrieving recorder."
if recorder_id is not None:
try:
run = self.client.get_run(recorder_id)
run = self._client.get_run(recorder_id)
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
return recorder
except MlflowException:
@@ -260,7 +257,7 @@ class MLflowExperiment(Experiment):
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
order_by = kwargs.get("order_by")
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
def delete_recorder(self, recorder_id=None, recorder_name=None):
assert (
@@ -268,10 +265,10 @@ class MLflowExperiment(Experiment):
), "Please input a valid recorder id or name before deleting."
try:
if recorder_id is not None:
self.client.delete_run(recorder_id)
self._client.delete_run(recorder_id)
else:
recorder = self._get_recorder(recorder_name=recorder_name)
self.client.delete_run(recorder.id)
self._client.delete_run(recorder.id)
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
@@ -280,7 +277,7 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])

View File

@@ -7,8 +7,10 @@ from mlflow.entities import ViewType
import os
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Text
from .exp import MLflowExperiment, Experiment
from .recorder import Recorder, MLflowRecorder
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
@@ -20,12 +22,24 @@ class ExpManager:
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
"""
def __init__(self, uri, default_exp_name):
self.uri = uri
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
self._default_uri = uri
self._current_uri = None
self.default_exp_name = default_exp_name
self.active_experiment = None # only one experiment can active each time
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
def __repr__(self):
    """Return ``ClassName(default_uri=..., current_uri=...)`` for debugging output."""
    template = "{name}(default_uri={duri}, current_uri={curi})"
    class_name = self.__class__.__name__
    return template.format(name=class_name, duri=self._default_uri, curi=self._current_uri)
def start_exp(
self,
experiment_name: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
**kwargs,
):
"""
Start an experiment. This method includes first get_or_create an experiment, and then
set it to be active.
@@ -45,7 +59,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `start_exp` method.")
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
"""
End an active experiment.
@@ -58,7 +72,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `end_exp` method.")
def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
"""
Create an experiment.
@@ -203,7 +217,8 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `delete_exp` method.")
def get_uri(self):
@property
def uri(self):
"""
Get the default tracking URI or current URI.
@@ -211,7 +226,31 @@ class ExpManager:
-------
The tracking URI string.
"""
return self.uri
return self._current_uri or self._default_uri
def set_uri(self, uri: Optional[Text] = None):
    """
    Set the current tracking URI, then run the subclass hook ``_set_uri``.

    Parameters
    ----------
    uri : str
        The URI to use as the current one. When it is None, the default
        URI becomes the current one.
    """
    if uri is not None:
        # Temporarily use the given uri as the current uri.
        self._current_uri = uri
    else:
        logger.info("No tracking URI is provided. Use the default tracking URI.")
        self._current_uri = self._default_uri
    # Let the subclass react to the new current uri.
    self._set_uri()
def _set_uri(self):
    """
    Hook called by ``set_uri``; subclasses must override it.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    message = "Please implement the `_set_uri` method."
    raise NotImplementedError(message)
def list_experiments(self):
"""
@@ -229,37 +268,43 @@ class MLflowExpManager(ExpManager):
Use mlflow to implement ExpManager.
"""
def __init__(self, uri, default_exp_name):
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
super(MLflowExpManager, self).__init__(uri, default_exp_name)
self._client = None
def _set_uri(self):
    """Create a new MLflow client for ``self.uri``, store it, and log it."""
    new_client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
    self._client = new_client
    logger.info("{:}".format(new_client))
@property
def client(self):
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
if not hasattr(self, "_client"):
if self._client is None:
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
return self._client
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
# set the tracking uri
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
else:
self.uri = uri
# create experiment
def start_exp(
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
):
# Set the tracking uri
self.set_uri(uri)
# Create experiment
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
# set up active experiment
# Set up active experiment
self.active_experiment = experiment
# start the experiment
# Start the experiment
self.active_experiment.start(recorder_name)
return self.active_experiment
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
if self.active_experiment is not None:
self.active_experiment.end(recorder_status)
self.active_experiment = None
# When an experiment ends, we will release the current uri.
self._current_uri = None
def create_exp(self, experiment_name=None):
def create_exp(self, experiment_name: Optional[Text] = None):
assert experiment_name is not None
# init experiment
experiment_id = self.client.create_experiment(experiment_name)

View File

@@ -34,7 +34,7 @@ class Recorder:
self.status = Recorder.STATUS_S
def __repr__(self):
return str(self.info)
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
def __str__(self):
return str(self.info)

View File

@@ -96,7 +96,6 @@ port_analysis_config = {
}
# train
def train():
"""train model
@@ -111,6 +110,9 @@ def train():
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
# To test __repr__
print(dataset)
print(R)
# start exp
with R.start(experiment_name="workflow"):
@@ -119,6 +121,8 @@ def train():
# prediction
recorder = R.get_recorder()
# To test __repr__
print(recorder)
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()
@@ -133,6 +137,27 @@ def train():
return pred_score, {"ic": ic, "ric": ric}, rid
def fake_experiment():
    """Run a throwaway experiment under a custom uri and compare the reported uris.

    Returns
    -------
    pass_or_not_for_default_uri: bool
        Whether ``R.get_uri()`` gives the same value before and after the experiment.
    pass_or_not_for_current_uri: bool
        Whether ``R.get_uri()`` gives the custom uri while the experiment is running.
    temporary_exp_dir: str
        The custom uri that was passed to ``R.start``.
    """
    temp_uri = "file:./temp-test-exp-mag"
    uri_before = R.get_uri()
    # run the experiment under the temporary uri
    with R.start(experiment_name="fake_workflow_for_expm", uri=temp_uri):
        R.log_params(**flatten_dict(task))
        uri_during = R.get_uri()
    uri_after = R.get_uri()
    default_ok = uri_before == uri_after
    current_ok = temp_uri == uri_during
    return default_ok, current_ok, temp_uri
def backtest_analysis(pred, rid):
"""backtest and analysis
@@ -181,6 +206,12 @@ class TestAllFlow(TestAutoData):
"backtest failed",
)
def test_2_expmanager(self):
    """Check the uri flags from ``fake_experiment``, then delete its temporary directory."""
    pass_default, pass_current, uri_path = fake_experiment()
    self.assertTrue(pass_default, msg="default uri is incorrect")
    self.assertTrue(pass_current, msg="current uri is incorrect")
    # Remove the scheme as a prefix. `str.strip("file:")` would instead strip any
    # of the characters f, i, l, e and ':' from BOTH ends, which corrupts paths
    # that end in one of them (e.g. "file:./temp-file").
    scheme = "file:"
    local_path = uri_path[len(scheme):] if uri_path.startswith(scheme) else uri_path
    shutil.rmtree(str(Path(local_path).resolve()))
def suite():
_suite = unittest.TestSuite()