mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
fix workflow bug (#882)
* fix workflow bug * Fix output of pytorch NN * Fix parameter bug
This commit is contained in:
@@ -240,7 +240,7 @@ class DNNModelPytorch(Model):
|
||||
R.log_metrics(val_loss=loss_val.val, step=step)
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
|
||||
"[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
|
||||
)
|
||||
evals_result["train"].append(train_loss)
|
||||
evals_result["valid"].append(loss_val.val)
|
||||
|
||||
@@ -37,8 +37,8 @@ def _log_task_info(task_config: dict):
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
model: Model = init_instance_by_config(task_config["model"], accept_types=Model)
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset)
|
||||
reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
# model training
|
||||
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
|
||||
|
||||
@@ -226,7 +226,9 @@ class QlibRecorder:
|
||||
"""
|
||||
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
def get_exp(
|
||||
self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False
|
||||
) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -291,6 +293,10 @@ class QlibRecorder:
|
||||
create : boolean
|
||||
an argument determines whether the method will automatically create a new experiment
|
||||
according to user's specification if the experiment hasn't been created before.
|
||||
start : bool
|
||||
when start is True,
|
||||
if the experiment has not started(not activated), it will start
|
||||
It is designed for R.log_params to auto start experiments
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -300,7 +306,7 @@ class QlibRecorder:
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
create=create,
|
||||
start=False,
|
||||
start=start,
|
||||
)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
@@ -542,7 +548,7 @@ class QlibRecorder:
|
||||
keyword argument:
|
||||
name1=value1, name2=value2, ...
|
||||
"""
|
||||
self.get_exp().get_recorder(start=True).log_params(**kwargs)
|
||||
self.get_exp(start=True).get_recorder(start=True).log_params(**kwargs)
|
||||
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
"""
|
||||
@@ -567,7 +573,7 @@ class QlibRecorder:
|
||||
keyword argument:
|
||||
name1=value1, name2=value2, ...
|
||||
"""
|
||||
self.get_exp().get_recorder(start=True).log_metrics(step, **kwargs)
|
||||
self.get_exp(start=True).get_recorder(start=True).log_metrics(step, **kwargs)
|
||||
|
||||
def set_tags(self, **kwargs):
|
||||
"""
|
||||
@@ -592,7 +598,7 @@ class QlibRecorder:
|
||||
keyword argument:
|
||||
name1=value1, name2=value2, ...
|
||||
"""
|
||||
self.get_exp().get_recorder(start=True).set_tags(**kwargs)
|
||||
self.get_exp(start=True).get_recorder(start=True).set_tags(**kwargs)
|
||||
|
||||
|
||||
class RecorderWrapper(Wrapper):
|
||||
|
||||
@@ -178,7 +178,7 @@ class ExpManager:
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new and start:
|
||||
if self.active_experiment is None and start:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
self.active_experiment.start()
|
||||
|
||||
Reference in New Issue
Block a user