mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
Remove set_log_basic_config, refine count_parameters, rename root_uri as get_local_dir
This commit is contained in:
@@ -4,7 +4,19 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(models_or_parameters, unit="mb"):
|
||||
def count_parameters(models_or_parameters, unit="m"):
|
||||
"""
|
||||
This function is to obtain the storage size unit of a (or multiple) models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models_or_parameters : PyTorch model(s) or a list of parameters.
|
||||
unit : the storage size unit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The number of parameters of the given model(s) or parameters.
|
||||
"""
|
||||
if isinstance(models_or_parameters, nn.Module):
|
||||
counts = sum(v.numel() for v in models_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
@@ -13,12 +25,13 @@ def count_parameters(models_or_parameters, unit="mb"):
|
||||
return sum(count_parameters(x, unit) for x in models_or_parameters)
|
||||
else:
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
if unit.lower() == "mb":
|
||||
counts /= 1e6
|
||||
elif unit.lower() == "kb":
|
||||
counts /= 1e3
|
||||
elif unit.lower() == "gb":
|
||||
counts /= 1e9
|
||||
unit = unit.lower()
|
||||
if unit == "kb" or unit == "k":
|
||||
counts /= 2 ** 10
|
||||
elif unit == "mb" or unit == "m":
|
||||
counts /= 2 ** 20
|
||||
elif unit == "gb" or unit == "g":
|
||||
counts /= 2 ** 30
|
||||
elif unit is not None:
|
||||
raise ValueError("Unknow unit: {:}".format(unit))
|
||||
return counts
|
||||
|
||||
21
qlib/log.py
21
qlib/log.py
@@ -108,27 +108,6 @@ def set_log_with_config(log_config: Dict[Text, Any]):
|
||||
logging_config.dictConfig(log_config)
|
||||
|
||||
|
||||
def set_log_basic_config(filename: Optional[Text] = None, format: Optional[Text] = None, level: Optional[int] = None):
|
||||
"""
|
||||
Set the basic configuration for the logging system.
|
||||
See details at https://docs.python.org/3/library/logging.html#logging.basicConfig
|
||||
|
||||
:param filename: str or None
|
||||
The path to save the logs.
|
||||
:param format: the logging format
|
||||
:param level: int
|
||||
:return: Logger
|
||||
Logger object.
|
||||
"""
|
||||
if level is None:
|
||||
level = C.logging_level
|
||||
|
||||
if format is None:
|
||||
format = C.logging_config["formatters"]["logger_format"]["format"]
|
||||
|
||||
logging.basicConfig(filename=filename, format=format, level=level)
|
||||
|
||||
|
||||
class LogFilter(logging.Filter):
|
||||
def __init__(self, param=None):
|
||||
self.param = param
|
||||
|
||||
@@ -240,12 +240,18 @@ class MLflowRecorder(Recorder):
|
||||
def artifact_uri(self):
|
||||
return self._artifact_uri
|
||||
|
||||
@property
|
||||
def root_uri(self):
|
||||
start_str = "file:"
|
||||
def get_local_dir(self):
|
||||
"""
|
||||
This function will return the directory path of this recorder.
|
||||
"""
|
||||
if self.artifact_uri is not None:
|
||||
xpath = self.artifact_uri.strip(start_str)
|
||||
return (Path(xpath) / "..").resolve()
|
||||
local_file_prefix = "file:"
|
||||
if self.artifact_uri.startswith(local_file_prefix):
|
||||
xpath = self.artifact_uri.lstrip(local_file_prefix)
|
||||
return (Path(xpath) / "..").resolve()
|
||||
else:
|
||||
raise RuntimeError("This recorder is not saved in the local file system.")
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Please make sure the recorder has been created and started properly before getting artifact uri."
|
||||
|
||||
@@ -123,6 +123,8 @@ def train():
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
# To test get_local_dir
|
||||
print(recorder.get_local_dir())
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
Reference in New Issue
Block a user