mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
Merge pull request #302 from D-X-Y/main
Update repr for dataset/workflow classes and add uri kwarg for QlibRecorder
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -34,3 +34,5 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
@@ -76,7 +76,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -87,7 +87,7 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
def init(self, handler_kwargs: Optional[Dict] = None, segment_kwargs: Optional[Dict] = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
@@ -124,7 +124,7 @@ class DatasetH(Dataset):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -156,6 +156,11 @@ class DatasetH(Dataset):
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
@@ -168,7 +173,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -178,7 +183,7 @@ class DatasetH(Dataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : Union[List[str], Tuple[str], str, slice]
|
||||
segments : Union[List[Text], Tuple[Text], Text, slice]
|
||||
Describe the scope of the data to be prepared
|
||||
Here are some examples:
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -47,8 +47,8 @@ class DataHandler(Serializable):
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -1489,7 +1489,7 @@ OpsList = [
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
class OpsWrapper:
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -16,8 +16,11 @@ class QlibRecorder:
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name=None, recorder_name=None):
|
||||
def start(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
"""
|
||||
Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
|
||||
|
||||
@@ -34,8 +37,13 @@ class QlibRecorder:
|
||||
name of the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when users call this function in the same experiment,
|
||||
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name)
|
||||
run = self.start_exp(experiment_name, recorder_name, uri)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -272,7 +280,7 @@ class QlibRecorder:
|
||||
-------
|
||||
The uri of current experiment manager.
|
||||
"""
|
||||
return self.exp_manager.get_uri()
|
||||
return self.exp_manager.uri
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,7 @@ class Experiment:
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -173,11 +173,9 @@ class MLflowExperiment(Experiment):
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.name)
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# set up recorder
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
@@ -210,7 +208,6 @@ class MLflowExperiment(Experiment):
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
mlflow.set_experiment(self.name)
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -239,7 +236,7 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input at least one of recorder id or name before retrieving recorder."
|
||||
if recorder_id is not None:
|
||||
try:
|
||||
run = self.client.get_run(recorder_id)
|
||||
run = self._client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException:
|
||||
@@ -260,7 +257,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -268,10 +265,10 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
self._client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
self._client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
@@ -280,7 +277,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -7,8 +7,10 @@ from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,12 +22,24 @@ class ExpManager:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
self.uri = uri
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._default_uri = uri
|
||||
self._current_uri = None
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
def __repr__(self):
|
||||
return "{name}(default_uri={duri}, current_uri={curi})".format(
|
||||
name=self.__class__.__name__, duri=self._default_uri, curi=self._current_uri
|
||||
)
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
@@ -45,7 +59,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
@@ -58,7 +72,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -203,7 +217,8 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_exp` method.")
|
||||
|
||||
def get_uri(self):
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
@@ -211,7 +226,31 @@ class ExpManager:
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self.uri
|
||||
return self._current_uri or self._default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
|
||||
"""
|
||||
Set the current tracking URI and the corresponding variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
self._current_uri = self._default_uri
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Customized features for subclasses.
|
||||
self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -229,37 +268,43 @@ class MLflowExpManager(ExpManager):
|
||||
Use mlflow to implement ExpManager.
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
def _set_uri(self):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
logger.info("{:}".format(self._client))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if not hasattr(self, "_client"):
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
def start_exp(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
# Set the tracking uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
|
||||
@@ -34,7 +34,7 @@ class Recorder:
|
||||
self.status = Recorder.STATUS_S
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
@@ -96,7 +96,6 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
|
||||
"""train model
|
||||
|
||||
@@ -111,6 +110,9 @@ def train():
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
@@ -119,6 +121,8 @@ def train():
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
@@ -133,6 +137,27 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
Returns
|
||||
-------
|
||||
pass_or_not_for_default_uri: bool
|
||||
pass_or_not_for_current_uri: bool
|
||||
temporary_exp_dir: str
|
||||
"""
|
||||
|
||||
# start exp
|
||||
default_uri = R.get_uri()
|
||||
current_uri = "file:./temp-test-exp-mag"
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
@@ -181,6 +206,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg="default uri is incorrect")
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user