diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py index 6cd2eb3e9..b7912b488 100644 --- a/qlib/rl/trainer/vessel.py +++ b/qlib/rl/trainer/vessel.py @@ -168,7 +168,9 @@ class TrainingVessel(TrainingVesselBase): self.policy.train() with vector_env.collector_guard(): - collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env))) + collector = Collector( + self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True + ) # Number of episodes collected in each training iteration can be overridden by fast dev run. if self.trainer.fast_dev_run is not None: