1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/tests/rl/test_logger.py
Yuge Zhang 9a40fd3cdc Qlib RL framework (stage 1) - single-asset order execution (#1076)
* rl init

* aux info

* Reward config

* update

* simple

* update saoe init

* update simulator and seed

* minor

* minor

* update sim

* checkpoint

* obs

* Update interpreter

* init qlib simulator

* checkpoint

* Refine codebase

* checkpoint

* checkpoint

* Add one test

* More tests

* Simulator checkpoint

* checkpoint

* First-step tested

* Checkpoint

* Update data_queue API

* Checkpoint

* Update test

* Move files

* Checkpoint

* Single-quote -> double-quote

* Fix finite env tests

* Tested with mypy

* pep-574

* No call for env done

* Update finite env docs

* Fix csv writer

* Refine tester

* Update logger

* Add another logger test

* Checkpoint

* Add network sanity test

* steps per episode is not correct

* Cleanup code, ready for PR

* Reformat with black

* Fix pylint for py37

* Fix lint

* Fix lint

* Fix flake

* update mypy command

* mypy

* Update exclude pattern

* Use pyproject.toml

* test

* .

* .

* Refactor pipeline

* .

* defaults run bash

* .

* Revert and skip follow_imports

* Fix toml issue

* fix mypy

* .

* .

* .

* Fix install

* Minor fix

* Fix test

* Fix test

* Remove requirements

* Revert

* fix tests

* Fix lint

* .

* .

* .

* .

* .

* update install from source command

* .

* Fix data download

* .

* .

* .

* .

* .

* .

* Fix py37

* Ignore tests on non-linux

* resolve comments

* fix tests

* resolve comments

* some typo

* style updates

* More comments

* fix dummy

* add warning

* Align precision in some system

* Added some impl notes

Co-authored-by: Young <afe.young@gmail.com>
2022-05-21 18:19:24 +08:00

157 lines
4.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from random import randint, choice
from pathlib import Path
import re
import gym
import numpy as np
import pandas as pd
from gym import spaces
from tianshou.data import Collector, Batch
from tianshou.policy import BasePolicy
from qlib.log import set_log_with_config
from qlib.config import C
from qlib.constant import INF
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.simulator import Simulator
from qlib.rl.utils.data_queue import DataQueue
from qlib.rl.utils.env_wrapper import InfoDict, EnvWrapper
from qlib.rl.utils.log import LogLevel, LogCollector, CsvWriter, ConsoleWriter
from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
    """Toy environment that emits a fixed set of log entries on every step.

    Each step logs a constant ``reward``, a random scalar ``a`` and a small
    data frame ``b``. A scalar ``c`` is logged only on the third and fourth
    step of an episode. The first three steps never finish the episode; from
    the fourth step on, ``done`` is a fair coin flip.
    """

    def __init__(self):
        self.observation_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Discrete(2)
        self.logger = LogCollector()

    def reset(self):
        """Zero the per-episode step counter and return the observation ``0``."""
        self.step_count = 0
        return 0

    def step(self, action: int):
        """Log this step's metrics and return ``(obs, reward, done, info)``.

        ``action`` is ignored; the observation is always ``1`` and the reward
        is always ``42.0``.
        """
        elapsed = self.step_count
        log = self.logger

        log.reset()
        log.add_scalar("reward", 42.0)
        log.add_scalar("a", randint(1, 10))
        log.add_array("b", pd.DataFrame({"a": [1, 2], "b": [3, 4]}))

        # The coin is only tossed once three steps have already been taken.
        done = choice([False, True]) if elapsed >= 3 else False

        # "c" shows up on the third and fourth step only.
        if elapsed in (2, 3):
            log.add_scalar("c", randint(11, 20))

        self.step_count = elapsed + 1
        return 1, 42.0, done, InfoDict(log=log.logs(), aux_info={})
class AnyPolicy(BasePolicy):
    """Stub policy that always chooses action ``1`` and never learns."""

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` holds one ``1`` per entry of ``batch``.

        ``state`` is accepted for signature compatibility and is not used.
        """
        # np.full builds the constant vector directly; unlike np.stack on a
        # list of Python scalars it also works when the batch is empty.
        return Batch(act=np.full(len(batch), 1))

    def learn(self, batch):
        """Do nothing: this policy has no parameters to update."""
def test_simple_env_logger(caplog):
    """Collect episodes on each vector-env flavour and check both log writers.

    For each of the ``dummy``, ``shmem`` and ``subproc`` environments, 30
    episodes of ``SimpleEnv`` are collected with a console writer and a CSV
    writer attached. The resulting CSV must have exactly the columns
    ``reward``, ``a`` and ``c`` and at least 30 rows. Afterwards every
    non-empty line captured by ``caplog`` must match the expected console
    format, and there must be at least three such lines.
    """
    set_log_with_config(C.logging_config)
    output_dir = Path(__file__).parent / ".output"

    for venv_cls_name in ["dummy", "shmem", "subproc"]:
        writer = ConsoleWriter()
        csv_writer = CsvWriter(output_dir)
        venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer])
        with venv.collector_guard():
            collector = Collector(AnyPolicy(), venv)
            collector.collect(n_episode=30)

        output_file = pd.read_csv(output_dir / "result.csv")
        # Only the scalar metrics are expected as columns; the array "b" is not.
        assert output_file.columns.tolist() == ["reward", "a", "c"]
        assert len(output_file) >= 30

    # Every decimal point is escaped so that "." cannot match an arbitrary
    # character. The parenthesised numbers are presumably running means
    # (a is drawn from 1..10, c from 11..20) -- confirm against ConsoleWriter.
    expected_line = re.compile(
        r".*reward 42\.0000 \(42\.0000\) a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)"
    )

    stripped = (raw.strip() for raw in caplog.text.splitlines())
    logged_lines = [line for line in stripped if line]
    for line in logged_lines:
        assert expected_line.match(line), line
    assert len(logged_lines) >= 3
class SimpleSimulator(Simulator[int, float, float]):
    """Simulator whose whole state is one float that grows by each action."""

    def __init__(self, initial: int, **kwargs) -> None:
        # Extra keyword arguments are accepted and ignored.
        self.initial = float(initial)

    def step(self, action: float) -> None:
        """Add ``action`` to the state and log two fixed scalars.

        ``test_a`` is logged as a torch tensor and ``test_b`` as a 0-d numpy
        array.
        """
        # torch is imported here rather than at module level.
        import torch

        self.initial += action

        logger = self.env.logger
        logger.add_scalar("test_a", torch.tensor(233.0))
        logger.add_scalar("test_b", np.array(200))

    def get_state(self) -> float:
        """Return the current accumulated value."""
        return self.initial

    def done(self) -> bool:
        """Report whether the fractional part of the state exceeds one half."""
        fractional_part = self.initial % 1
        return fractional_part > 0.5
class DummyStateInterpreter(StateInterpreter[float, float]):
    """Pass-through interpreter: the observation is the simulator state itself."""

    @property
    def observation_space(self) -> spaces.Box:
        """Scalar (shape ``()``) float32 box with lower bound 0 and no upper bound."""
        return spaces.Box(low=0, high=np.inf, shape=(), dtype=np.float32)

    def interpret(self, state: float) -> float:
        """Return ``state`` unchanged."""
        return state
class DummyActionInterpreter(ActionInterpreter[float, int, float]):
    """Turn a discrete policy action into a small float increment."""

    def interpret(self, state: float, action: int) -> float:
        """Return ``action / 100``; ``state`` is not used."""
        return action / 100

    @property
    def action_space(self) -> spaces.Discrete:
        """Discrete space with five actions (``Discrete(5)``)."""
        # The annotation matches what is actually returned: a Discrete space,
        # not a Box.
        return spaces.Discrete(5)
class RandomFivePolicy(BasePolicy):
    """Stub policy that draws each action at random from the integers below 5."""

    def forward(self, batch, state=None):
        """Return a ``Batch`` with one random action per entry of ``batch``.

        ``state`` is accepted for signature compatibility and is not used.
        """
        n_samples = len(batch)
        actions = np.random.randint(5, size=n_samples)
        return Batch(act=actions)

    def learn(self, batch):
        """Do nothing: this policy has no parameters to update."""
def test_logger_with_env_wrapper():
    """Run ``SimpleSimulator`` through ``EnvWrapper`` and verify the CSV log.

    Twenty initial states (0..19) are fed in order through a ``DataQueue`` to
    four ``shmem`` environments. The CSV written by ``CsvWriter`` must then
    contain one row per initial state, together with the custom scalars logged
    by the simulator and the ``steps_per_episode`` and ``reward`` columns.
    """
    output_dir = Path(__file__).parent / ".output"

    with DataQueue(list(range(20)), shuffle=False) as data_iterator:

        def env_wrapper_factory():
            return EnvWrapper(
                SimpleSimulator,
                DummyStateInterpreter(),
                DummyActionInterpreter(),
                data_iterator,
                logger=LogCollector(LogLevel.DEBUG),
            )

        # loglevel can be debug here because metrics can all dump into csv
        # otherwise, csv writer might crash
        csv_writer = CsvWriter(output_dir, loglevel=LogLevel.DEBUG)
        venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
        with venv.collector_guard():
            collector = Collector(RandomFivePolicy(), venv)
            collector.collect(n_episode=INF * len(venv))

    output_df = pd.read_csv(output_dir / "result.csv")
    assert len(output_df) == 20

    # obs has an increasing trend
    obs = output_df["obs"].to_numpy()
    first_half_total = obs[:10].sum()
    second_half_total = obs[10:].sum()
    assert first_half_total < second_half_total

    assert (output_df["test_a"] == 233).all()
    assert (output_df["test_b"] == 200).all()
    assert "steps_per_episode" in output_df and "reward" in output_df