1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge branch 'main' of github.com:you-n-g/qlib into main

This commit is contained in:
Young
2020-11-10 01:46:01 +00:00
6 changed files with 23 additions and 19 deletions

View File

@@ -82,7 +82,7 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# exp manager module
# set up QlibRecorder
module = get_module_by_module_path("qlib.workflow.expm")
exp_manager = init_instance_by_config(C["exp_manager"], module)
qr = QlibRecorder(exp_manager)

View File

@@ -43,26 +43,26 @@ class QlibRecorder:
def get_uri(self):
return self.exp_manager.get_uri()
def get_recorder(self):
return self.exp_manager.active_recorder
def get_recorder(self, recorder_id=None, recorder_name=None):
return self.exp_manager.active_experiment.get_recorder(recorder_id, recorder_name)
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs)
self.exp_manager.active_experiment.active_recorder.save_objects(local_path, artifact_path, **kwargs)
def load_object(self, name):
return self.exp_manager.active_recorder.load_object(name)
return self.exp_manager.active_experiment.active_recorder.load_object(name)
def log_params(self, **kwargs):
self.exp_manager.active_recorder.log_params(**kwargs)
self.exp_manager.active_experiment.active_recorder.log_params(**kwargs)
def log_metrics(self, step=None, **kwargs):
self.exp_manager.active_recorder.log_metrics(step, **kwargs)
self.exp_manager.active_experiment.active_recorder.log_metrics(step, **kwargs)
def set_tags(self, **kwargs):
self.exp_manager.active_recorder.set_tags(**kwargs)
self.exp_manager.active_experiment.active_recorder.set_tags(**kwargs)
def delete_tag(self, *key):
self.exp_manager.active_recorder.delete_tag(*key)
self.exp_manager.active_experiment.active_recorder.delete_tag(*key)
# global record

View File

@@ -4,7 +4,9 @@
import mlflow
from pathlib import Path
from .recorder import MLflowRecorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
class Experiment:
"""

View File

@@ -184,10 +184,12 @@ class MLflowExpManager(ExpManager):
else:
if experiment_name not in self.experiments:
if mlflow.get_experiment_by_name(experiment_name) is not None:
raise Exception(
"The experiment has already been created before. Please pick another name or delete the files under uri."
logger.info(
"The experiment has already been created before. Try to resume the experiment..."
)
experiment_id = mlflow.create_experiment(experiment_name)
experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
else:
experiment_id = mlflow.create_experiment(experiment_name)
else:
experiment_id = self.experiments[experiment_name].id
experiment = self.experiments[experiment_name]

View File

@@ -76,7 +76,7 @@ class SignalRecord(RecordTemp):
def generate(self, **kwargs):
# generate prediciton
pred = self.model.predict(self.dataset)
self.recorder.save_object(pred, "pred.pkl")
self.recorder.save_objects(data=pred, name="pred.pkl")
def load(self):
# try to load the saved object
@@ -132,9 +132,9 @@ class PortAnaRecord(SignalRecord):
assert super().check(), "Make sure the parent process is completed and store the data properly."
# custom strategy and get backtest
pred_score = super().load()
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG)
self.recorder.save_object(report_normal, "report_normal.pkl", self.artifact_path)
self.recorder.save_object(positions_normal, "positions_normal.pkl", self.artifact_path)
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
self.recorder.save_objects(data=report_normal, name="report_normal.pkl", artifact_path=self.artifact_path)
self.recorder.save_objects(data=positions_normal, name="positions_normal.pkl", artifact_path=self.artifact_path)
# analysis
analysis = dict()
@@ -143,7 +143,7 @@ class PortAnaRecord(SignalRecord):
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
self.recorder.save_object(pred, "port_analysis.pkl", self.artifact_path)
self.recorder.save_objects(data=analysis_df, name="port_analysis.pkl", artifact_path=self.artifact_path)
def load(self):
# try to load the saved object

View File

@@ -201,7 +201,7 @@ class MLflowRecorder(Recorder):
def load_object(self, name):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
path = client.download_artifacts(self.recorder_id, name)
path = client.download_artifacts(self.id, name)
try:
with Path(path).open("rb") as f:
f.seek(0)
@@ -242,5 +242,5 @@ class MLflowRecorder(Recorder):
def list_artifacts(self, artifact_path=None):
client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
artifacts = client.list_artifacts(self.id, path)
artifacts = client.list_artifacts(self.id, artifact_path)
return artifacts