From b85a5c224e95b4cd207e605c6ad80b54a2af98c6 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 12 Mar 2025 14:11:42 +0800 Subject: [PATCH] fixed pytest error in CI --- qlib/rl/trainer/trainer.py | 2 +- tests/rl/test_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index fb73dd549..a1046e966 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -200,7 +200,7 @@ class Trainer: if ckpt_path is not None: _logger.info("Resuming states from %s", str(ckpt_path)) - self.load_state_dict(torch.load(ckpt_path)) + self.load_state_dict(torch.load(ckpt_path, weights_only=False)) else: self.initialize() diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py index 751fbd387..f842d9781 100644 --- a/tests/rl/test_trainer.py +++ b/tests/rl/test_trainer.py @@ -194,7 +194,7 @@ def test_trainer_checkpoint(): assert (output_dir / "002.pth").exists() assert os.readlink(output_dir / "latest.pth") == str(output_dir / "002.pth") - trainer.load_state_dict(torch.load(output_dir / "001.pth")) + trainer.load_state_dict(torch.load(output_dir / "001.pth", weights_only=False)) assert trainer.current_iter == 1 assert trainer.current_episode == 100