diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index fb73dd549..a1046e966 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -200,7 +200,7 @@ class Trainer: if ckpt_path is not None: _logger.info("Resuming states from %s", str(ckpt_path)) - self.load_state_dict(torch.load(ckpt_path)) + self.load_state_dict(torch.load(ckpt_path, weights_only=False)) else: self.initialize() diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py index 751fbd387..f842d9781 100644 --- a/tests/rl/test_trainer.py +++ b/tests/rl/test_trainer.py @@ -194,7 +194,7 @@ def test_trainer_checkpoint(): assert (output_dir / "002.pth").exists() assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth") - trainer.load_state_dict(torch.load(output_dir / "001.pth")) + trainer.load_state_dict(torch.load(output_dir / "001.pth", weights_only=False)) assert trainer.current_iter == 1 assert trainer.current_episode == 100