1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 01:21:18 +08:00
Files
qlib/tests/rl/test_logger.py
Huoran Li 2752bdc92c Migrate NeuTrader to Qlib RL (#1169)
* Refine previous version RL codes

* Polish utils/__init__.py

* Draft

* Use | instead of Union

* Simulator & action interpreter

* Test passed

* Migrate to SAOEState & new qlib interpreter

* Black format

* Revert file_storage change

* Refactor file structure & renaming functions

* Enrich test cases

* Add QlibIntradayBacktestData

* Test interpreter

* Black format

* .

.

.

* Rename receive_execute_result()

* Use indicator to simplify state update

* Format code

* Modify data path

* Adjust file structure

* Minor change

* Add copyright message

* Format code

* Rename util functions

* Add CI

* Pylint issue

* Remove useless code to pass pylint

* Pass mypy

* Mypy issue

* mypy issue

* mypy issue

* Revert "mypy issue"

This reverts commit 8eb1b0174e.

* mypy issue

* mypy issue

* Fix the numpy version incompatible bug

* Fix a minor typing issue

* Try to skip python 3.7 test for qlib simulator

* Resolve PR comments by Yuge; solve several CI issues.

* Black issue

* Fix a low-level type error

* Change data name

* Resolve PR comments. Leave TODOs in the code base.

Co-authored-by: Young <afe.young@gmail.com>
2022-08-01 09:56:07 +08:00

163 lines
5.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import re
from typing import Any, Tuple
import gym
import numpy as np
import pandas as pd
from gym import spaces
from tianshou.data import Collector, Batch
from tianshou.policy import BasePolicy
from qlib.log import set_log_with_config
from qlib.config import C
from qlib.constant import INF
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.utils.data_queue import DataQueue
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
    """Toy environment with a 2-value observation/action space that logs metrics every step.

    Each ``step`` clears its ``LogCollector`` and records the scalars "reward" and "a",
    the array "b", and (only on the 3rd and 4th step of an episode) the scalar "c".
    The collected logs are returned inside the ``InfoDict``.
    """

    def __init__(self) -> None:
        self.logger = LogCollector()
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *args: Any, **kwargs: Any) -> int:
        """Start a new episode; the initial observation is always 0."""
        self.step_count = 0
        return 0

    def step(self, action: int) -> Tuple[int, float, bool, dict]:
        """Log this step's metrics and return ``(obs=1, reward=42.0, done, info)``."""
        log = self.logger
        log.reset()
        log.add_scalar("reward", 42.0)
        log.add_scalar("a", randint(1, 10))
        log.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))

        # The first three steps never terminate; from the 4th step on, termination is a coin flip.
        done = choice([False, True]) if self.step_count >= 3 else False

        # "c" only appears on the 3rd and 4th step (step_count 2 and 3).
        if 2 <= self.step_count <= 3:
            log.add_scalar("c", randint(11, 20))

        self.step_count += 1
        return 1, 42.0, done, InfoDict(log=log.logs(), aux_info={})

    def render(self, mode: str = "human") -> None:
        pass
class AnyPolicy(BasePolicy):
    """Trivial policy that always picks action 1 for every environment in the batch."""

    def forward(self, batch, state=None):
        # np.full builds the all-ones action vector directly. Unlike the previous
        # np.stack([1] * len(batch)), it does not raise ValueError on an empty batch.
        # Its default dtype is np.array(1).dtype, the same integer dtype np.stack produced.
        return Batch(act=np.full(len(batch), 1))

    def learn(self, batch):
        pass
def test_simple_env_logger(caplog):
    """Run SimpleEnv under each vectorized-env backend and check the console and CSV logs."""
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"
    # Every non-empty captured log line must mention reward, a and c. The parenthesised
    # values after "a" and "c" must fall in the 4.x-6.x and 14.x-16.x ranges respectively
    # (presumably running averages of randint(1, 10) / randint(11, 20) -- not verified here).
    console_line = re.compile(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)")

    for venv_cls_name in ["dummy", "shmem", "subproc"]:
        console_writer = ConsoleWriter()
        csv_writer = CsvWriter(output_dir)
        venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [console_writer, csv_writer])
        with venv.collector_guard():
            collector = Collector(AnyPolicy(), venv)
            collector.collect(n_episode=30)

        result = pd.read_csv(output_dir / "result.csv")
        # "b" is logged with add_array and is expected to be absent from the CSV columns.
        assert result.columns.tolist() == ["reward", "a", "c"]
        assert len(result) >= 30

    stripped = (raw.strip() for raw in caplog.text.splitlines())
    non_empty = [text for text in stripped if text]
    for text in non_empty:
        assert console_line.match(text)
    assert len(non_empty) >= 3
class SimpleSimulator(Simulator[int, float, float]):
    """Simulator whose float state starts at ``initial`` and accumulates every action."""

    def __init__(self, initial: int, **kwargs: Any) -> None:
        super().__init__(initial, **kwargs)
        self.initial = float(initial)

    def step(self, action: float) -> None:
        """Add ``action`` to the state, then log a torch scalar and a numpy 0-d array."""
        import torch  # imported locally; only needed when step() actually runs

        self.initial = self.initial + action
        env_logger = self.env.logger
        env_logger.add_scalar("test_a", torch.tensor(233.0))
        env_logger.add_scalar("test_b", np.array(200))

    def get_state(self) -> float:
        return self.initial

    def done(self) -> bool:
        """Finish once ``initial % 1`` (the fractional remainder) exceeds 0.5."""
        return self.initial % 1 > 0.5
class DummyStateInterpreter(StateInterpreter[float, float]):
    """Identity state interpreter: the simulator's float state is used as the observation."""

    @property
    def observation_space(self) -> spaces.Box:
        # A scalar (shape ``()``) float32 box bounded below by 0 and unbounded above.
        return spaces.Box(low=0, high=np.inf, shape=(), dtype=np.float32)

    def interpret(self, state: float) -> float:
        return state
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
    """Map a discrete action index to a small float increment."""

    def interpret(self, state: float, action: int) -> float:
        # Action k becomes an increment of k / 100 (0.00 .. 0.04 for the 5 actions below).
        return action / 100

    @property
    def action_space(self) -> spaces.Discrete:
        # Five discrete actions, 0..4. The return annotation matches the Discrete
        # space actually returned (it was wrongly annotated as spaces.Box).
        return spaces.Discrete(5)
class RandomFivePolicy(BasePolicy):
    """Policy that draws an independent uniform action from {0, ..., 4} for each env."""

    def forward(self, batch, state=None):
        n_envs = len(batch)
        sampled = np.random.randint(0, 5, size=n_envs)
        return Batch(act=sampled)

    def learn(self, batch):
        pass
def test_logger_with_env_wrapper():
    """Drive SimpleSimulator through EnvWrapper and check what the CSV writer records."""
    output_dir = Path(__file__).parent / ".output"

    with DataQueue(list(range(20)), shuffle=False) as data_iterator:

        def env_wrapper_factory():
            return EnvWrapper(
                SimpleSimulator,
                DummyStateInterpreter(),
                DummyActionInterpreter(),
                data_iterator,
                logger=LogCollector(LogLevel.DEBUG),
            )

        # DEBUG is acceptable for the CSV writer here because every metric in this test
        # can be dumped into the CSV; with other metrics the CSV writer might crash.
        csv_writer = CsvWriter(output_dir, loglevel=LogLevel.DEBUG)
        venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
        with venv.collector_guard():
            collector = Collector(RandomFivePolicy(), venv)
            # INF episodes: presumably collection stops once the DataQueue is exhausted.
            collector.collect(n_episode=INF * len(venv))

    result = pd.read_csv(output_dir / "result.csv")
    # One row is expected per item in the 20-element data queue.
    assert len(result) == 20
    # The observations trend upward: the first 10 rows sum to less than the last 10.
    observations = result["obs"].to_numpy()
    assert observations[:10].sum() < observations[10:].sum()
    # SimpleSimulator logs a torch scalar (233.0) and a numpy 0-d array (200).
    assert (result["test_a"] == 233).all()
    assert (result["test_b"] == 200).all()
    assert "steps_per_episode" in result and "reward" in result