1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00
Files
qlib/tests/rl/test_logger.py
Yuge Zhang 25ecb1135f Qlib RL framework (stage 2) - trainer (#1125)
* checkpoint

(cherry picked from commit 1a8e0bd4671ee6d624a7d09bb198a273282cd050)

* Not a workable version

(cherry picked from commit 3498e185684cd5590d3ab97e0ab69eab8c1e0e3a)

* vessel

* ckpt

* .

* vessel

* .

* .

* checkpoint callback

* .

* cleanup

* logger

* .

* test

* .

* add test

* .

* .

* .

* .

* New reward

* Add train API

* fix mypy

* fix lint

* More comment

* 3.7 compat

* fix test

* fix test

* .

* Resolve comments

* fix typehint
2022-06-28 19:53:05 +08:00

157 lines
4.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import re
import gym
import numpy as np
import pandas as pd
from gym import spaces
from tianshou.data import Collector, Batch
from tianshou.policy import BasePolicy
from qlib.log import set_log_with_config
from qlib.config import C
from qlib.constant import INF
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.utils.data_queue import DataQueue
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
self.step_count = 0
return 0
def step(self, action: int):
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
self.logger.add_scalar("a", randint(1, 10))
self.logger.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
if self.step_count >= 3:
done = choice([False, True])
else:
done = False
if 2 <= self.step_count <= 3:
self.logger.add_scalar("c", randint(11, 20))
self.step_count += 1
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
def learn(self, batch):
pass
def test_simple_env_logger(caplog):
set_log_with_config(C.logging_config)
for venv_cls_name in ["dummy", "shmem", "subproc"]:
writer = ConsoleWriter()
csv_writer = CsvWriter(Path(__file__).parent / ".output")
venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
with venv.collector_guard():
collector = Collector(AnyPolicy(), venv)
collector.collect(n_episode=30)
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert output_file.columns.tolist() == ["reward", "a", "c"]
assert len(output_file) >= 30
line_counter = 0
for line in caplog.text.splitlines():
line = line.strip()
if line:
line_counter += 1
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
assert line_counter >= 3
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
self.initial = float(initial)
def step(self, action: float) -> None:
import torch
self.initial += action
self.env.logger.add_scalar("test_a", torch.tensor(233.0))
self.env.logger.add_scalar("test_b", np.array(200))
def get_state(self) -> float:
return self.initial
def done(self) -> bool:
return self.initial % 1 > 0.5
class DummyStateInterpreter(StateInterpreter[float, float]):
def interpret(self, state: float) -> float:
return state
@property
def observation_space(self) -> spaces.Box:
return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
def interpret(self, state: float, action: int) -> float:
return action / 100
@property
def action_space(self) -> spaces.Box:
return spaces.Discrete(5)
class RandomFivePolicy(BasePolicy):
def forward(self, batch, state=None):
return Batch(act=np.random.randint(5, size=len(batch)))
def learn(self, batch):
pass
def test_logger_with_env_wrapper():
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
env_wrapper_factory = lambda: EnvWrapper(
SimpleSimulator,
DummyStateInterpreter(),
DummyActionInterpreter(),
data_iterator,
logger=LogCollector(LogLevel.DEBUG),
)
# loglevel can be debug here because metrics can all dump into csv
# otherwise, csv writer might crash
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
with venv.collector_guard():
collector = Collector(RandomFivePolicy(), venv)
collector.collect(n_episode=INF * len(venv))
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
assert len(output_df) == 20
# obs has a increasing trend
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
assert (output_df["test_a"] == 233).all()
assert (output_df["test_b"] == 200).all()
assert "steps_per_episode" in output_df and "reward" in output_df