1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00
Files
qlib/tests/rl/test_finite_env.py
Yuge Zhang 9a40fd3cdc Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init

* aux info

* Reward config

* update

* simple

* update saoe init

* update simulator and seed

* minor

* minor

* update sim

* checkpoint

* obs

* Update interpreter

* init qlib simulator

* checkpoint

* Refine codebase

* checkpoint

* checkpoint

* Add one test

* More tests

* Simulator checkpoint

* checkpoint

* First-step tested

* Checkpoint

* Update data_queue API

* Checkpoint

* Update test

* Move files

* Checkpoint

* Single-quote -> double-quote

* Fix finite env tests

* Tested with mypy

* pep-574

* No call for env done

* Update finite env docs

* Fix csv writer

* Refine tester

* Update logger

* Add another logger test

* Checkpoint

* Add network sanity test

* steps per episode is not correct

* Cleanup code, ready for PR

* Reformat with black

* Fix pylint for py37

* Fix lint

* Fix lint

* Fix flake

* update mypy command

* mypy

* Update exclude pattern

* Use pyproject.toml

* test

* .

* .

* Refactor pipeline

* .

* defaults run bash

* .

* Revert and skip follow_imports

* Fix toml issue

* fix mypy

* .

* .

* .

* Fix install

* Minor fix

* Fix test

* Fix test

* Remove requirements

* Revert

* fix tests

* Fix lint

* .

* .

* .

* .

* .

* update install from source command

* .

* Fix data download

* .

* .

* .

* .

* .

* .

* Fix py37

* Ignore tests on non-linux

* resolve comments

* fix tests

* resolve comments

* Fix some typos

* style updates

* More comments

* fix dummy

* add warning

* Align precision on some systems

* Added some impl notes

Co-authored-by: Young <afe.young@gmail.com>
2022-05-21 18:19:24 +08:00

250 lines
7.7 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections import Counter
import gym
import numpy as np
from tianshou.data import Batch, Collector
from tianshou.policy import BasePolicy
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from qlib.rl.utils.finite_env import (
LogWriter,
FiniteDummyVectorEnv,
FiniteShmemVectorEnv,
FiniteSubprocVectorEnv,
check_nan_observation,
generate_nan_observation,
)
# Deeply nested observation space mixing Dict, Tuple, Box, Discrete, MultiDiscrete
# and MultiBinary members. It is used by ``test_nan`` and sampled by
# ``FiniteEnvWithComplexObs`` to produce structured observations.
_test_space = gym.spaces.Dict(
    {
        "sensors": gym.spaces.Dict(
            {
                "position": gym.spaces.Box(low=-100, high=100, shape=(3,)),
                "velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)),
                # A pair of identically shaped image-like boxes.
                "front_cam": gym.spaces.Tuple(
                    (
                        gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                        gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                    )
                ),
                "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            }
        ),
        "ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
        "inner_state": gym.spaces.Dict(
            {
                "charge": gym.spaces.Discrete(100),
                "system_checks": gym.spaces.MultiBinary(10),
                "job_status": gym.spaces.Dict(
                    {
                        "task": gym.spaces.Discrete(5),
                        # shape=() makes this a scalar (zero-dimensional) box.
                        "progress": gym.spaces.Box(low=0, high=100, shape=()),
                    }
                ),
            }
        ),
    }
)
class FiniteEnv(gym.Env):
    """Toy environment that walks through one ``DistributedSampler`` shard of a dataset.

    Each ``reset`` pulls the next ``(sample, step_count)`` pair from the loader
    and starts an episode that lasts ``step_count`` steps. When the loader has
    nothing left, ``reset`` returns a NaN observation instead and forgets the
    iterator, so a later ``reset`` starts a fresh pass over the loader.
    """

    def __init__(self, dataset, num_replicas, rank):
        """Store the shard description and build the loader that feeds ``reset``."""
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        shard_sampler = DistributedSampler(dataset, num_replicas, rank)
        # batch_size=None turns off automatic batching, so every item drawn from
        # the loader is a single dataset element rather than a collated batch.
        self.loader = DataLoader(dataset, sampler=shard_sampler, batch_size=None)
        # Created lazily by reset(); None means "no pass in progress".
        self.iterator = None
        self.observation_space = gym.spaces.Discrete(255)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        """Begin the next episode and return its observation.

        Returns the sample drawn from the loader, or the result of
        ``generate_nan_observation(self.observation_space)`` once the loader is
        exhausted.
        """
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            drawn = next(self.iterator)
        except StopIteration:
            # Shard used up: drop the iterator and signal it with a NaN observation.
            self.iterator = None
            return generate_nan_observation(self.observation_space)
        self.current_sample, self.step_count = drawn
        self.current_step = 0
        return self.current_sample

    def step(self, action):
        """Advance one step and return ``(obs, reward, done, info)``.

        The observation is always ``0`` and the reward always ``1.0``; ``done``
        becomes true once ``step_count`` steps have been taken. ``info`` carries
        the current sample, the received action and a constant metric.
        """
        self.current_step += 1
        # Stepping past the end of an episode is a caller error.
        assert self.current_step <= self.step_count
        episode_over = self.current_step >= self.step_count
        info = {"sample": self.current_sample, "action": action, "metric": 2.0}
        return 0, 1.0, episode_over, info
class FiniteEnvWithComplexObs(FiniteEnv):
    """Variant of ``FiniteEnv`` whose observations are samples of ``_test_space``.

    Episode bookkeeping (loader, step counting, exhaustion handling) is the same
    as in the parent class; only the returned observations and the ``"sample"``
    entry of ``info`` differ.

    NOTE(review): ``observation_space`` is still ``Discrete(255)`` although real
    observations come from ``_test_space``, and the exhausted-loader observation
    is generated from ``observation_space`` -- confirm this mismatch is intended.
    """

    def __init__(self, dataset, num_replicas, rank):
        """Set up exactly the same state as ``FiniteEnv.__init__``."""
        super().__init__(dataset, num_replicas, rank)

    def reset(self):
        """Begin the next episode and return a random ``_test_space`` sample.

        Once the loader is exhausted, the iterator is dropped and the result of
        ``generate_nan_observation(self.observation_space)`` is returned instead.
        """
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            drawn = next(self.iterator)
        except StopIteration:
            self.iterator = None
            return generate_nan_observation(self.observation_space)
        self.current_sample, self.step_count = drawn
        self.current_step = 0
        return _test_space.sample()

    def step(self, action):
        """Advance one step and return ``(obs, reward, done, info)``.

        Both the observation and ``info["sample"]`` are independent random
        samples of ``_test_space``; the reward is always ``1.0``.
        """
        self.current_step += 1
        assert self.current_step <= self.step_count
        observation = _test_space.sample()
        episode_over = self.current_step >= self.step_count
        info = {"sample": _test_space.sample(), "action": action, "metric": 2.0}
        return observation, 1.0, episode_over, info
class DummyDataset(Dataset):
    """Dataset of ``length`` items; item ``i`` is ``(i, 3 * i % 5 + 1)``.

    The second element is consumed by the test environments as the number of
    steps in the episode for sample ``i`` (a value between 1 and 5).
    """

    def __init__(self, length):
        """Precompute the episode length of every index in ``range(length)``."""
        self.length = length
        self.episodes = [3 * idx % 5 + 1 for idx in range(length)]

    def __getitem__(self, index):
        """Return ``(index, episode_length)``; ``index`` must lie in ``[0, length)``."""
        assert index >= 0 and index < self.length
        episode_length = self.episodes[index]
        return index, episode_length

    def __len__(self):
        """Return the number of items."""
        return self.length
class AnyPolicy(BasePolicy):
    """Trivial policy that always answers with action ``1`` and never learns."""

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` holds one ``1`` per element of ``batch``."""
        constant_actions = [1 for _ in range(len(batch))]
        return Batch(act=np.stack(constant_actions))

    def learn(self, batch):
        """Do nothing; there are no parameters to update."""
        return None
def _finite_env_factory(dataset, num_replicas, rank, complex=False):
    """Return a zero-argument callable that builds one environment.

    The callable constructs ``FiniteEnvWithComplexObs`` when ``complex`` is
    truthy and ``FiniteEnv`` otherwise, forwarding ``dataset``,
    ``num_replicas`` and ``rank`` unchanged.
    """
    # ``complex`` shadows the builtin, but callers pass it by keyword, so the
    # parameter name has to stay.
    env_cls = FiniteEnvWithComplexObs if complex else FiniteEnv
    return lambda: env_cls(dataset, num_replicas, rank)
class MetricTracker(LogWriter):
    """Log writer that counts steps per sample and records which samples finished.

    ``validate`` then checks the tallies against the episode lengths defined by
    ``DummyDataset`` (``index * 3 % 5 + 1`` steps for sample ``index``).
    """

    def __init__(self, length):
        """Start with empty tallies; ``length`` is the expected number of finished samples."""
        super().__init__()
        self.length = length
        self.counter = Counter()
        self.finished = set()

    def on_env_step(self, env_id, obs, rew, done, info):
        """Record one step for the sample found in ``info["sample"]``."""
        assert rew == 1.0
        sample_index = info["sample"]
        self.counter[sample_index] += 1
        if done:
            # NOTE: nothing here checks that a sample finishes only once;
            # adding to the set silently absorbs repeated completions.
            self.finished.add(sample_index)

    def validate(self):
        """Assert every sample finished and was stepped the expected number of times."""
        assert len(self.finished) == self.length
        for sample_index, step_total in self.counter.items():
            assert step_total == sample_index * 3 % 5 + 1
class DoNothingTracker(LogWriter):
    """Log writer whose step hook ignores everything it is given."""

    def on_env_step(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None
def test_finite_dummy_vector_env():
    """Drain five ``FiniteEnv`` workers through ``FiniteDummyVectorEnv`` and check the tallies.

    The collector is given an effectively unbounded step budget; collection is
    expected to end with ``StopIteration``. The tracker is validated after
    collection no matter how ``collect`` ended, so the test cannot pass without
    its assertions having run.
    """
    length = 100
    dataset = DummyDataset(length)
    envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
    # NOTE(review): presumably marks the env as already guarded so that the
    # StopIteration can be handled here by hand -- confirm against finite_env.
    envs._collector_guarded = True
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    # Install a fresh tracker and keep a handle on it for validation.
    tracker = MetricTracker(length)
    envs._logger = [tracker]
    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        pass
    # Validate outside the handler: it must run even if collect() returned normally.
    tracker.validate()
def test_finite_shmem_vector_env():
    """Drain five ``FiniteEnv`` workers through ``FiniteShmemVectorEnv`` and check the tallies.

    The collector is given an effectively unbounded step budget; collection is
    expected to end with ``StopIteration``. The tracker is validated after
    collection no matter how ``collect`` ended, so the test cannot pass without
    its assertions having run.
    """
    length = 100
    dataset = DummyDataset(length)
    envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
    # NOTE(review): presumably marks the env as already guarded so that the
    # StopIteration can be handled here by hand -- confirm against finite_env.
    envs._collector_guarded = True
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    # Install a fresh tracker and keep a handle on it for validation.
    tracker = MetricTracker(length)
    envs._logger = [tracker]
    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        pass
    # Validate outside the handler: it must run even if collect() returned normally.
    tracker.validate()
def test_finite_subproc_vector_env():
    """Drain five ``FiniteEnv`` workers through ``FiniteSubprocVectorEnv`` and check the tallies.

    The collector is given an effectively unbounded step budget; collection is
    expected to end with ``StopIteration``. The tracker is validated after
    collection no matter how ``collect`` ended, so the test cannot pass without
    its assertions having run.
    """
    length = 100
    dataset = DummyDataset(length)
    envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)])
    # NOTE(review): presumably marks the env as already guarded so that the
    # StopIteration can be handled here by hand -- confirm against finite_env.
    envs._collector_guarded = True
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    # Install a fresh tracker and keep a handle on it for validation.
    tracker = MetricTracker(length)
    envs._logger = [tracker]
    try:
        test_collector.collect(n_step=10**18)
    except StopIteration:
        pass
    # Validate outside the handler: it must run even if collect() returned normally.
    tracker.validate()
def test_nan():
    """A generated NaN observation is detected as NaN; an ordinary sample is not."""
    nan_obs = generate_nan_observation(_test_space)
    assert check_nan_observation(nan_obs)
    ordinary_obs = _test_space.sample()
    assert not check_nan_observation(ordinary_obs)
def test_finite_dummy_vector_env_complex():
    """Smoke test: ``FiniteDummyVectorEnv`` with nested-observation envs runs until ``StopIteration``.

    Nothing is validated afterwards; the test only requires that collection
    raises nothing other than ``StopIteration``.
    """
    dataset = DummyDataset(100)
    env_fns = [_finite_env_factory(dataset, 5, rank, complex=True) for rank in range(5)]
    envs = FiniteDummyVectorEnv(DoNothingTracker(), env_fns)
    envs._collector_guarded = True
    collector = Collector(AnyPolicy(), envs, exploration_noise=True)
    try:
        # Effectively unbounded budget; the run is expected to end via StopIteration.
        collector.collect(n_step=10**18)
    except StopIteration:
        pass
def test_finite_shmem_vector_env_complex():
    """Smoke test: ``FiniteShmemVectorEnv`` with nested-observation envs runs until ``StopIteration``.

    Nothing is validated afterwards; the test only requires that collection
    raises nothing other than ``StopIteration``.
    """
    dataset = DummyDataset(100)
    env_fns = [_finite_env_factory(dataset, 5, rank, complex=True) for rank in range(5)]
    envs = FiniteShmemVectorEnv(DoNothingTracker(), env_fns)
    envs._collector_guarded = True
    collector = Collector(AnyPolicy(), envs, exploration_noise=True)
    try:
        # Effectively unbounded budget; the run is expected to end via StopIteration.
        collector.collect(n_step=10**18)
    except StopIteration:
        pass