1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

Remove set_log_basic_config, refine count_parameters, rename root_uri as get_local_dir

This commit is contained in:
D-X-Y
2021-03-11 02:33:00 +00:00
parent e061443560
commit f6ed175070
4 changed files with 33 additions and 33 deletions

View File

@@ -4,7 +4,19 @@
import torch.nn as nn
def count_parameters(models_or_parameters, unit="mb"):
def count_parameters(models_or_parameters, unit="m"):
"""
This function is to obtain the storage size unit of a (or multiple) models.
Parameters
----------
models_or_parameters : PyTorch model(s) or a list of parameters.
unit : the storage size unit.
Returns
-------
The number of parameters of the given model(s) or parameters.
"""
if isinstance(models_or_parameters, nn.Module):
counts = sum(v.numel() for v in models_or_parameters.parameters())
elif isinstance(models_or_parameters, nn.Parameter):
@@ -13,12 +25,13 @@ def count_parameters(models_or_parameters, unit="mb"):
return sum(count_parameters(x, unit) for x in models_or_parameters)
else:
counts = sum(v.numel() for v in models_or_parameters)
if unit.lower() == "mb":
counts /= 1e6
elif unit.lower() == "kb":
counts /= 1e3
elif unit.lower() == "gb":
counts /= 1e9
unit = unit.lower()
if unit == "kb" or unit == "k":
counts /= 2 ** 10
elif unit == "mb" or unit == "m":
counts /= 2 ** 20
elif unit == "gb" or unit == "g":
counts /= 2 ** 30
elif unit is not None:
raise ValueError("Unknow unit: {:}".format(unit))
return counts

View File

@@ -108,27 +108,6 @@ def set_log_with_config(log_config: Dict[Text, Any]):
logging_config.dictConfig(log_config)
def set_log_basic_config(filename: Optional[Text] = None, format: Optional[Text] = None, level: Optional[int] = None):
"""
Set the basic configuration for the logging system.
See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
:param filename: str or None
The path to save the logs.
:param format: the logging format
:param level: int
:return: Logger
Logger object.
"""
if level is None:
level = C.logging_level
if format is None:
format = C.logging_config["formatters"]["logger_format"]["format"]
logging.basicConfig(filename=filename, format=format, level=level)
class LogFilter(logging.Filter):
def __init__(self, param=None):
self.param = param

View File

@@ -240,12 +240,18 @@ class MLflowRecorder(Recorder):
def artifact_uri(self):
return self._artifact_uri
@property
def root_uri(self):
start_str = "file:"
def get_local_dir(self):
"""
This function will return the directory path of this recorder.
"""
if self.artifact_uri is not None:
xpath = self.artifact_uri.strip(start_str)
return (Path(xpath) / "..").resolve()
local_file_prefix = "file:"
if self.artifact_uri.startswith(local_file_prefix):
xpath = self.artifact_uri.lstrip(local_file_prefix)
return (Path(xpath) / "..").resolve()
else:
raise RuntimeError("This recorder is not saved in the local file system.")
else:
raise Exception(
"Please make sure the recorder has been created and started properly before getting artifact uri."

View File

@@ -123,6 +123,8 @@ def train():
recorder = R.get_recorder()
# To test __repr__
print(recorder)
# To test get_local_dir
print(recorder.get_local_dir())
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()