diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index d150466c1..f6c659533 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -352,7 +352,8 @@ class TRAModel(Model): "model": copy.deepcopy(self.model.state_dict()), "tra": copy.deepcopy(self.tra.state_dict()), } - torch.save(best_params, self.logdir + "/model.bin") + if self.logdir is not None: + torch.save(best_params, self.logdir + "/model.bin") else: stop_rounds += 1 if stop_rounds >= self.early_stop: