1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

fix workflow bug (#882)

* fix workflow bug

* Fix output of pytorch NN

* Fix parameter bug
This commit is contained in:
you-n-g
2022-01-22 10:18:37 +08:00
committed by GitHub
parent d533219738
commit 01afd06e18
4 changed files with 15 additions and 9 deletions

View File

@@ -240,7 +240,7 @@ class DNNModelPytorch(Model):
R.log_metrics(val_loss=loss_val.val, step=step)
if verbose:
self.logger.info(
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
"[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
)
evals_result["train"].append(train_loss)
evals_result["valid"].append(loss_val.val)

View File

@@ -37,8 +37,8 @@ def _log_task_info(task_config: dict):
def _exe_task(task_config: dict):
rec = R.get_recorder()
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
reweighter: Reweighter = task_config.get("reweighter", None)
# model training
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)

View File

@@ -226,7 +226,9 @@ class QlibRecorder:
"""
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(
self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False
) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -291,6 +293,10 @@ class QlibRecorder:
create : boolean
an argument determines whether the method will automatically create a new experiment
according to user's specification if the experiment hasn't been created before.
start : bool
when start is True,
if the experiment has not started(not activated), it will start
It is designed for R.log_params to auto start experiments
Returns
-------
@@ -300,7 +306,7 @@ class QlibRecorder:
experiment_id=experiment_id,
experiment_name=experiment_name,
create=create,
start=False,
start=start,
)
def delete_exp(self, experiment_id=None, experiment_name=None):
@@ -542,7 +548,7 @@ class QlibRecorder:
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).log_params(**kwargs)
self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)
def log_metrics(self, step=None, **kwargs):
"""
@@ -567,7 +573,7 @@ class QlibRecorder:
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).log_metrics(step, **kwargs)
self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)
def set_tags(self, **kwargs):
"""
@@ -592,7 +598,7 @@ class QlibRecorder:
keyword argument:
name1=value1, name2=value2, ...
"""
self.get_exp().get_recorder(start=True).set_tags(**kwargs)
self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)
class RecorderWrapper(Wrapper):

View File

@@ -178,7 +178,7 @@ class ExpManager:
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
if is_new and start:
if self.active_experiment is None and start:
self.active_experiment = exp
# start the recorder
self.active_experiment.start()