1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/rl/test_finite_env.py
Linlang 3097dcc995 fix(security): use RestrictedUnpickler in load_instance (#2153)
* fix(security): enforce RestrictedUnpickler for load_instance to prevent unsafe pickle deserialization

* fix: lint error
2026-03-10 20:45:38 +08:00

249 lines
7.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import Counter
import gym
import numpy as np
from tianshou.data import Batch, Collector
from tianshou.policy import BasePolicy
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from qlib.rl.utils.finite_env import (
LogWriter,
FiniteDummyVectorEnv,
FiniteShmemVectorEnv,
FiniteSubprocVectorEnv,
check_nan_observation,
generate_nan_observation,
)
_test_space = gym.spaces.Dict(
{
"sensors": gym.spaces.Dict(
{
"position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
"velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
"front_cam": gym.spaces.Tuple(
(gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)))
),
"rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
}
),
"ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
"inner_state": gym.spaces.Dict(
{
"charge": gym.spaces.Discrete(100),
"system_checks": gym.spaces.MultiBinary(10),
"job_status": gym.spaces.Dict(
{
"task": gym.spaces.Discrete(5),
"progress": gym.spaces.Box(low=0, high=100, shape=()),
}
),
}
),
}
)
class FiniteEnv(gym.Env):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return self.current_sample
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
0,
1.0,
self.current_step >= self.step_count,
{"sample": self.current_sample, "action": action, "metric": 2.0},
)
class FiniteEnvWithComplexObs(FiniteEnv):
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
self.iterator = None
self.observation_space = gym.spaces.Discrete(255)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
if self.iterator is None:
self.iterator = iter(self.loader)
try:
self.current_sample, self.step_count = next(self.iterator)
self.current_step = 0
return _test_space.sample()
except StopIteration:
self.iterator = None
return generate_nan_observation(self.observation_space)
def step(self, action):
self.current_step += 1
assert self.current_step <= self.step_count
return (
_test_space.sample(),
1.0,
self.current_step >= self.step_count,
{"sample": _test_space.sample(), "action": action, "metric": 2.0},
)
class DummyDataset(Dataset):
def __init__(self, length):
self.length = length
self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
def __getitem__(self, index):
assert 0 <= index < self.length
return index, self.episodes[index]
def __len__(self):
return self.length
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def _finite_env_factory(dataset, num_replicas, rank, complex=False):
if complex:
return lambda: FiniteEnvWithComplexObs(dataset, num_replicas, rank)
return lambda: FiniteEnv(dataset, num_replicas, rank)
class MetricTracker(LogWriter):
def __init__(self, length):
super().__init__()
self.counter = Counter()
self.finished = set()
self.length = length
def on_env_step(self, env_id, obs, rew, done, info):
assert rew == 1.0
index = info["sample"]
if done:
# assert index not in self.finished
self.finished.add(index)
self.counter[index] += 1
def validate(self):
assert len(self.finished) == self.length
for k, v in self.counter.items():
assert v == k * 3 % 5 + 1
class DoNothingTracker(LogWriter):
def on_env_step(self, *args, **kwargs):
pass
def test_finite_dummy_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_shmem_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_finite_subproc_vector_env():
length = 100
dataset = DummyDataset(length)
envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(1):
envs._logger = [MetricTracker(length)]
try:
test_collector.collect(n_step=10**18)
except StopIteration:
envs._logger[0].validate()
def test_nan():
assert check_nan_observation(generate_nan_observation(_test_space))
assert not check_nan_observation(_test_space.sample())
def test_finite_dummy_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteDummyVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass
def test_finite_shmem_vector_env_complex():
length = 100
dataset = DummyDataset(length)
envs = FiniteShmemVectorEnv(
DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)]
)
envs._collector_guarded = True
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
try:
test_collector.collect(n_step=10**18)
except StopIteration:
pass