mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* rl init * aux info * Reward config * update * simple * update saoe init * update simulator and seed * minor * minor * update sim * checkpoint * obs * Update interpreter * init qlib simulator * checkpoint * Refine codebase * checkpoint * checkpoint * Add one test * More tests * Simulator checkpoint * checkpoint * First-step tested * Checkpoint * Update data_queue API * Checkpoint * Update test * Move files * Checkpoint * Single-quote -> double-quote * Fix finite env tests * Tested with mypy * pep-574 * No call for env done * Update finite env docs * Fix csv writer * Refine tester * Update logger * Add another logger test * Checkpoint * Add network sanity test * steps per episode is not correct * Cleanup code, ready for PR * Reformat with black * Fix pylint for py37 * Fix lint * Fix lint * Fix flake * update mypy command * mypy * Update exclude pattern * Use pyproject.toml * test * . * . * Refactor pipeline * . * defaults run bash * . * Revert and skip follow_imports * Fix toml issue * fix mypy * . * . * . * Fix install * Minor fix * Fix test * Fix test * Remove requirements * Revert * fix tests * Fix lint * . * . * . * . * . * update install from source command * . * Fix data download * . * . * . * . * . * . * Fix py37 * Ignore tests on non-linux * resolve comments * fix tests * resolve comments * some typo * style updates * More comments * fix dummy * add warning * Align precision in some system * Added some impl notes Co-authored-by: Young <afe.young@gmail.com>
338 lines
12 KiB
Python
338 lines
12 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
"""
|
|
This is to support finite env in vector env.
|
|
See https://github.com/thu-ml/tianshou/issues/322 for details.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
import gym
|
|
import numpy as np
|
|
from typing import Any, Set, Callable, Type
|
|
|
|
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
|
|
|
from qlib.typehint import Literal
|
|
from .log import LogWriter
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "generate_nan_observation",
    "check_nan_observation",
    "FiniteVectorEnv",
    "FiniteDummyVectorEnv",
    "FiniteSubprocVectorEnv",
    "FiniteShmemVectorEnv",
    "FiniteEnvType",
    "vectorize_env",
]


# Accepted values for the ``env_type`` argument of :func:`vectorize_env`;
# each one is a key of the class mapping built inside that function.
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
|
|
|
|
|
def fill_invalid(obj):
    """Build a copy of ``obj`` in which every leaf is replaced by an "invalid" marker.

    The marker depends on the dtype of the leaf:

    - floating point: ``NaN``;
    - boolean: ``True`` (``np.iinfo`` does not support boolean dtypes);
    - anything else: ``np.iinfo(dtype).max``, which raises ``ValueError``
      for dtypes that are not integers.

    Parameters
    ----------
    obj
        A Python ``int`` / ``float`` / ``bool``, any object exposing ``dtype``
        (numpy arrays and numpy scalars in particular), or an arbitrarily
        nested ``dict`` / ``list`` / ``tuple`` of those.

    Returns
    -------
    An object with the same container structure as ``obj``. Every leaf is a
    numpy array (0-d for scalars) with the shape and dtype of the original leaf.

    Raises
    ------
    ValueError
        If ``obj`` (or one of its leaves) has an unsupported type.
    """
    if isinstance(obj, (int, float, bool)):
        # Promote Python scalars to 0-d arrays so that they get a dtype.
        return fill_invalid(np.array(obj))
    if hasattr(obj, "dtype"):
        if isinstance(obj, np.ndarray):
            if np.issubdtype(obj.dtype, np.floating):
                return np.full_like(obj, np.nan)
            if np.issubdtype(obj.dtype, np.bool_):
                # np.iinfo rejects boolean dtypes; use the largest boolean value instead.
                return np.full_like(obj, True)
            return np.full_like(obj, np.iinfo(obj.dtype).max)
        # dealing with corner cases that numpy number is not supported by tianshou's sharray
        return fill_invalid(np.array(obj))
    elif isinstance(obj, dict):
        return {k: fill_invalid(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [fill_invalid(v) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(fill_invalid(v) for v in obj)
    raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
|
|
|
|
|
def is_invalid(arr):
    """Tell whether ``arr`` consists solely of "invalid" marker values.

    A leaf is considered invalid when:

    - it has a floating dtype and every element is ``NaN``;
    - it has a boolean dtype and every element is ``True``
      (``np.iinfo`` does not support boolean dtypes);
    - it has any other dtype and every element equals ``np.iinfo(dtype).max``.

    Containers (``dict`` / ``list`` / ``tuple``) are invalid when all of their
    members are invalid, so an empty container is invalid. Python and numpy
    scalars are converted to 0-d arrays first. Any other object (for example
    ``None`` or a string) is reported as invalid.

    Parameters
    ----------
    arr
        The (possibly nested) value to inspect.

    Returns
    -------
    A truthy value (``bool`` or ``numpy.bool_``) if ``arr`` is entirely invalid.
    """
    if hasattr(arr, "dtype"):
        if np.issubdtype(arr.dtype, np.floating):
            return np.isnan(arr).all()
        if np.issubdtype(arr.dtype, np.bool_):
            # np.iinfo rejects boolean dtypes; the marker for booleans is True.
            return np.all(arr)
        return (np.iinfo(arr.dtype).max == arr).all()
    if isinstance(arr, dict):
        return all(is_invalid(o) for o in arr.values())
    if isinstance(arr, (list, tuple)):
        return all(is_invalid(o) for o in arr)
    if isinstance(arr, (int, float, bool, np.number)):
        return is_invalid(np.array(arr))
    # Anything that cannot carry a value is treated as invalid.
    return True
|
|
|
|
|
|
def generate_nan_observation(obs_space: gym.Space) -> Any:
    """Create the placeholder ("NaN") observation for an environment that received no seed.

    A sample is drawn from ``obs_space`` only to obtain the right structure,
    and is then passed through :func:`fill_invalid`.

    The observation is assumed to be complex enough to contain something
    like a float; the logic does not work otherwise.
    """
    return fill_invalid(obs_space.sample())
|
|
|
|
|
|
def check_nan_observation(obs: Any) -> bool:
    """Return whether ``obs`` looks like the output of :func:`generate_nan_observation`.

    The decision is delegated to :func:`is_invalid`.
    """
    looks_like_nan_obs = is_invalid(obs)
    return looks_like_nan_obs
|
|
|
|
|
|
class FiniteVectorEnv(BaseVectorEnv):
    """To allow the paralleled env workers consume a single DataQueue until it's exhausted.

    See `tianshou issue #322 <https://github.com/thu-ml/tianshou/issues/322>`_.

    The requirement is to make every possible seed (stored in :class:`qlib.rl.utils.DataQueue` in our case)
    consumed by exactly one environment. This is not possible by tianshou's native VectorEnv and Collector,
    because tianshou is unaware of this "exactly one" constraint, and might launch extra workers.

    Consider a corner case, where concurrency is 2, but there is only one seed in DataQueue.
    The reset of two workers must be both called according to the logic in collect.
    The returned results of two workers are collected, regardless of what they are.
    The problem is, one of the reset result must be invalid, or repeated,
    because there's only one seed in queue, and collector isn't aware of such situation.

    Luckily, we can hack the vector env, and make a protocol between single env and vector env.
    The single environment (should be :class:`qlib.rl.utils.EnvWrapper` in our case) is responsible for
    reading from queue, and generate a special observation when the queue is exhausted. The special obs
    is called "nan observation", because simply using none causes problems in shared-memory vector env.
    :class:`FiniteVectorEnv` then read the observations from all workers, and select those non-nan
    observation. It also maintains an ``_alive_env_ids`` to track which workers are still alive;
    workers that dropped out of this set are never called again.
    When all the environments are exhausted, it will raise StopIteration exception.

    The usage of this vector env in collector are two parts:

    1. If the data queue is finite (usually when inference), collector should collect "infinity" number of
       episodes, until the vector env exhausts by itself.
    2. If the data queue is infinite (usually in training), collector can set number of episodes / steps.
       In this case, data would be randomly ordered, and some repetitions wouldn't matter.

    One extra function of this vector env is that it has a logger that explicitly collects logs
    from child workers. See :class:`qlib.rl.utils.LogWriter`.

    Parameters
    ----------
    logger
        One log writer, or a list of them. Each one is notified on reset, on step,
        and when a guarded collect finishes.
    env_fns
        Factories of the single environments, forwarded to the parent vector env.
    kwargs
        Extra keyword arguments forwarded to the parent vector env.
    """

    def __init__(
        self, logger: LogWriter | list[LogWriter], env_fns: list[Callable[..., gym.Env]], **kwargs: Any
    ) -> None:
        super().__init__(env_fns, **kwargs)

        # Always keep the loggers as a list so that they can be iterated uniformly.
        self._logger: list[LogWriter] = logger if isinstance(logger, list) else [logger]
        self._alive_env_ids: Set[int] = set()
        self._reset_alive_envs()
        # Placeholders returned on behalf of the workers that are no longer alive.
        self._default_obs = self._default_info = self._default_rew = None
        # Becomes True once every worker is exhausted; the env is unusable afterwards.
        self._zombie = False

        self._collector_guarded: bool = False

    def _reset_alive_envs(self) -> None:
        """Mark every worker as alive, but only if none is currently alive."""
        if not self._alive_env_ids:
            # starting or running out
            self._alive_env_ids = set(range(self.env_num))

    # to workaround with tianshou's buffer and batch
    def _set_default_obs(self, obs) -> None:
        """Remember a copy of the first non-``None`` observation as the default."""
        if obs is not None and self._default_obs is None:
            self._default_obs = copy.deepcopy(obs)

    def _set_default_info(self, info) -> None:
        """Remember a copy of the first non-``None`` info as the default."""
        if info is not None and self._default_info is None:
            self._default_info = copy.deepcopy(info)

    def _set_default_rew(self, rew) -> None:
        """Remember a copy of the first non-``None`` reward as the default."""
        if rew is not None and self._default_rew is None:
            self._default_rew = copy.deepcopy(rew)

    def _get_default_obs(self):
        """Return a fresh copy of the default observation (``None`` if not yet set)."""
        return copy.deepcopy(self._default_obs)

    def _get_default_info(self):
        """Return a fresh copy of the default info (``None`` if not yet set)."""
        return copy.deepcopy(self._default_info)

    def _get_default_rew(self):
        """Return a fresh copy of the default reward (``None`` if not yet set)."""
        return copy.deepcopy(self._default_rew)

    # END

    @staticmethod
    def _postproc_env_obs(obs):
        """Map a "nan observation" (or ``None``) to ``None``; return other observations as is."""
        # reserved for shmem vector env to restore empty observation
        if obs is None or check_nan_observation(obs):
            return None
        return obs

    @contextmanager
    def collector_guard(self):
        """Guard the collector. Recommended to guard every collect.

        This guard is for two purposes.

        1. Catch and ignore the StopIteration exception, which is the stopping signal
           thrown by FiniteEnv to let tianshou know that ``collector.collect()`` should exit.
        2. Notify the loggers that the collect is done when it's done.

        The loggers are notified only when the guarded body finishes normally or with
        ``StopIteration``; any other exception propagates without notifying them.

        Examples
        --------
        >>> with finite_env.collector_guard():
        ...     collector.collect(n_episode=INF)
        """
        self._collector_guarded = True

        try:
            yield self
        except StopIteration:
            pass
        finally:
            self._collector_guarded = False

        # At last trigger the loggers
        for logger in self._logger:
            logger.on_env_all_done()

    def reset(self, id=None):
        """Reset the requested workers that are still alive.

        Workers returning a "nan observation" are removed from the alive set, and
        their slot in the result is filled with a copy of the default observation.

        Parameters
        ----------
        id
            Worker ids to reset, normalized by ``_wrap_id``.

        Returns
        -------
        The observations of all requested workers, stacked with ``np.stack``.

        Raises
        ------
        StopIteration
            When no worker is alive any more. The env becomes unusable afterwards.
        """
        assert not self._zombie

        # Check whether it's guarded by collector_guard()
        if not self._collector_guarded:
            warnings.warn(
                "Collector is not guarded by FiniteEnv. "
                "This may cause unexpected problems, like unexpected StopIteration exception, "
                "or missing logs.",
                RuntimeWarning,
            )

        id = self._wrap_id(id)
        self._reset_alive_envs()

        # ask super to reset alive envs and remap to current index
        request_id = [i for i in id if i in self._alive_env_ids]
        obs = [None] * len(id)
        id2idx = {i: k for k, i in enumerate(id)}
        if request_id:
            for i, o in zip(request_id, super().reset(request_id)):
                obs[id2idx[i]] = self._postproc_env_obs(o)

        # A worker without a valid observation is exhausted.
        for i, o in zip(id, obs):
            if o is None:
                self._alive_env_ids.discard(i)

        # logging
        for i, o in zip(id, obs):
            if i in self._alive_env_ids:
                for logger in self._logger:
                    # Hand over the observation of this very worker, not the whole batch.
                    logger.on_env_reset(i, o)

        # fill empty observation with default(fake) observation
        for o in obs:
            self._set_default_obs(o)
        for i, o in enumerate(obs):
            if o is None:
                obs[i] = self._get_default_obs()

        if not self._alive_env_ids:
            # Deliberately no automatic ``self.reset()`` here:
            # an exhausted env is not reusable.
            self._zombie = True
            raise StopIteration

        return np.stack(obs)

    def step(self, action, id=None):
        """Step the requested workers that are still alive.

        Slots of the workers that are not alive are filled with default (fake)
        observation, reward and info, and with ``done=False``.

        Parameters
        ----------
        action
            Actions, indexed in the same order as ``id``.
        id
            Worker ids to step, normalized by ``_wrap_id``.

        Returns
        -------
        A list ``[obs, rew, done, info]``, each stacked over the requested workers.
        """
        assert not self._zombie
        id = self._wrap_id(id)
        id2idx = {i: k for k, i in enumerate(id)}
        request_id = [i for i in id if i in self._alive_env_ids]
        result = [[None, None, False, None] for _ in range(len(id))]

        # ask super to step alive envs and remap to current index
        if request_id:
            valid_act = np.stack([action[id2idx[i]] for i in request_id])
            for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
                result[id2idx[i]] = list(r)
                result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])

        # logging
        for i, r in zip(id, result):
            if i in self._alive_env_ids:
                for logger in self._logger:
                    logger.on_env_step(i, *r)

        # fill empty observation/info with default(fake)
        for _, rew, _, info in result:
            self._set_default_info(info)
            self._set_default_rew(rew)
        for i, r in enumerate(result):
            if r[0] is None:
                result[i][0] = self._get_default_obs()
            if r[1] is None:
                result[i][1] = self._get_default_rew()
            if r[3] is None:
                result[i][3] = self._get_default_info()

        return list(map(np.stack, zip(*result)))
|
|
|
|
|
|
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
    """:class:`FiniteVectorEnv` combined with tianshou's ``DummyVectorEnv``."""
|
|
|
|
|
|
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
    """:class:`FiniteVectorEnv` combined with tianshou's ``SubprocVectorEnv``."""
|
|
|
|
|
|
class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
    """:class:`FiniteVectorEnv` combined with tianshou's ``ShmemVectorEnv``."""
|
|
|
|
|
|
def vectorize_env(
    env_factory: Callable[..., gym.Env],
    env_type: FiniteEnvType,
    concurrency: int,
    logger: LogWriter | list[LogWriter],
) -> FiniteVectorEnv:
    """Build a finite vector env out of ``concurrency`` copies of one env factory.

    Parameters
    ----------
    env_factory
        Callable that instantiates a single ``gym.Env``.
        The very same callable is handed to every concurrent worker.
    env_type
        One of ``"dummy"``, ``"subproc"`` or ``"shmem"``. See
        `parallelism in tianshou <https://tianshou.readthedocs.io/en/master/api/tianshou.env.html#vectorenv>`_.
    concurrency
        Number of concurrent environment workers.
    logger
        Log writer(s) passed to the vector env.

    Returns
    -------
    An instance of the :class:`FiniteVectorEnv` subclass selected by ``env_type``.

    Warnings
    --------
    Do not pass a lambda expression as ``env_factory``, as it may create incorrectly-shared instances.

    Don't do: ::

        vectorize_env(lambda: EnvWrapper(...), ...)

    Please do: ::

        def env_factory(): ...
        vectorize_env(env_factory, ...)
    """
    finite_env_classes: dict[str, Type[FiniteVectorEnv]] = {
        "dummy": FiniteDummyVectorEnv,
        "subproc": FiniteSubprocVectorEnv,
        "shmem": FiniteShmemVectorEnv,
    }
    selected_cls = finite_env_classes[env_type]

    # Every worker shares the same factory object.
    worker_factories = [env_factory] * concurrency
    return selected_cls(logger, worker_factories)
|