diff --git a/qlib/contrib/model/pytorch_utils.py b/qlib/contrib/model/pytorch_utils.py index e7a8e8d67..1148a596a 100644 --- a/qlib/contrib/model/pytorch_utils.py +++ b/qlib/contrib/model/pytorch_utils.py @@ -4,7 +4,19 @@ import torch.nn as nn -def count_parameters(models_or_parameters, unit="mb"): +def count_parameters(models_or_parameters, unit="m"): + """ + This function is to obtain the storage size unit of a (or multiple) models. + + Parameters + ---------- + models_or_parameters : PyTorch model(s) or a list of parameters. + unit : the storage size unit. + + Returns + ------- + The number of parameters of the given model(s) or parameters. + """ if isinstance(models_or_parameters, nn.Module): counts = sum(v.numel() for v in models_or_parameters.parameters()) elif isinstance(models_or_parameters, nn.Parameter): @@ -13,12 +25,13 @@ def count_parameters(models_or_parameters, unit="mb"): return sum(count_parameters(x, unit) for x in models_or_parameters) else: counts = sum(v.numel() for v in models_or_parameters) - if unit.lower() == "mb": - counts /= 1e6 - elif unit.lower() == "kb": - counts /= 1e3 - elif unit.lower() == "gb": - counts /= 1e9 + unit = unit.lower() + if unit == "kb" or unit == "k": + counts /= 2 ** 10 + elif unit == "mb" or unit == "m": + counts /= 2 ** 20 + elif unit == "gb" or unit == "g": + counts /= 2 ** 30 elif unit is not None: raise ValueError("Unknow unit: {:}".format(unit)) return counts diff --git a/qlib/log.py b/qlib/log.py index 78f12eb09..126acb9d2 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -108,27 +108,6 @@ def set_log_with_config(log_config: Dict[Text, Any]): logging_config.dictConfig(log_config) -def set_log_basic_config(filename: Optional[Text] = None, format: Optional[Text] = None, level: Optional[int] = None): - """ - Set the basic configuration for the logging system. - See details at https://docs.python.org/3/library/logging.html#logging.basicConfig - - :param filename: str or None - The path to save the logs. - :param format: the logging format - :param level: int - :return: Logger - Logger object. - """ - if level is None: - level = C.logging_level - - if format is None: - format = C.logging_config["formatters"]["logger_format"]["format"] - - logging.basicConfig(filename=filename, format=format, level=level) - - class LogFilter(logging.Filter): def __init__(self, param=None): self.param = param diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 519b69710..bc0a9ef77 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -240,12 +240,18 @@ class MLflowRecorder(Recorder): def artifact_uri(self): return self._artifact_uri - @property - def root_uri(self): - start_str = "file:" + def get_local_dir(self): + """ + This function will return the directory path of this recorder. + """ if self.artifact_uri is not None: - xpath = self.artifact_uri.strip(start_str) - return (Path(xpath) / "..").resolve() + local_file_prefix = "file:" + if self.artifact_uri.startswith(local_file_prefix): + xpath = self.artifact_uri.lstrip(local_file_prefix) + return (Path(xpath) / "..").resolve() + else: + raise RuntimeError("This recorder is not saved in the local file system.") + else: raise Exception( "Please make sure the recorder has been created and started properly before getting artifact uri." diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index d9d684697..fbf15d29a 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -123,6 +123,8 @@ def train(): recorder = R.get_recorder() # To test __repr__ print(recorder) + # To test get_local_dir + print(recorder.get_local_dir()) rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate()