diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 73da5e19f..13c4bc7a0 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,6 +11,7 @@ from datetime import datetime from qlib.utils.exceptions import LoadObjectError from ..utils.objm import FileManager from ..log import get_module_logger +from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository logger = get_module_logger("workflow", logging.INFO) @@ -335,7 +336,11 @@ class MLflowRecorder(Recorder): path = self.client.download_artifacts(self.id, name) with Path(path).open("rb") as f: data = pickle.load(f) - os.remove(path) + ar = self.client._tracking_client._get_artifact_repo(self.id) + if isinstance(ar, AzureBlobArtifactRepository): + # for saving disk space + # For safety, only remove redundant file for specific ArtifactRepository + shutil.rmtree(Path(path).absolute().parent) return data except Exception as e: raise LoadObjectError(message=str(e))