1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Migrate NeuTrader to Qlib RL (#1169)

* Refine previous version RL codes

* Polish utils/__init__.py

* Draft

* Use | instead of Union

* Simulator & action interpreter

* Test passed

* Migrate to SAOEState & new qlib interpreter

* Black format

* . Revert file_storage change

* Refactor file structure & renaming functions

* Enrich test cases

* Add QlibIntradayBacktestData

* Test interpreter

* Black format

* .

.

.

* Rename receive_execute_result()

* Use indicator to simplify state update

* Format code

* Modify data path

* Adjust file structure

* Minor change

* Add copyright message

* Format code

* Rename util functions

* Add CI

* Pylint issue

* Remove useless code to pass pylint

* Pass mypy

* Mypy issue

* mypy issue

* mypy issue

* Revert "mypy issue"

This reverts commit 8eb1b0174e.

* mypy issue

* mypy issue

* Fix the numpy version incompatible bug

* Fix a minor typing issue

* Try to skip python 3.7 test for qlib simulator

* Resolve PR comments by Yuge; solve several CI issues.

* Black issue

* Fix a low-level type error

* Change data name

* Resolve PR comments. Leave TODOs in the code base.

Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Huoran Li
2022-08-01 09:56:07 +08:00
committed by GitHub
parent 687edd79d0
commit 2752bdc92c
35 changed files with 1305 additions and 257 deletions

View File

@@ -5,6 +5,8 @@ from random import randint, choice
from pathlib import Path
import re
from typing import Any, Tuple
import gym
import numpy as np
import pandas as pd
@@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
def __init__(self) -> None:
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
def reset(self, *args: Any, **kwargs: Any) -> int:
self.step_count = 0
return 0
def step(self, action: int):
def step(self, action: int) -> Tuple[int, float, bool, dict]:
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
@@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
def render(self, mode: str = "human") -> None:
pass
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
@@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
def __init__(self, initial: int, **kwargs: Any) -> None:
super(SimpleSimulator, self).__init__(initial, **kwargs)
self.initial = float(initial)
def step(self, action: float) -> None:

View File

@@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import pandas as pd
import pytest
from qlib.backtest.decision import Order, OrderDir
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
from qlib.backtest.utils import CommonInfrastructure
from qlib.contrib.strategy import TWAPStrategy
from qlib.rl.order_execution import CategoricalActionInterpreter
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
TOTAL_POSITION = 2100.0
python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:
return abs(a - b) <= epsilon
def get_order() -> Order:
return Order(
stock_id="SH600000",
amount=TOTAL_POSITION,
direction=OrderDir.BUY,
start_time=pd.Timestamp("2019-03-04 09:30:00"),
end_time=pd.Timestamp("2019-03-04 14:29:00"),
)
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
return NestedExecutor(
time_per_step=time_per_step,
inner_strategy=TWAPStrategy(),
inner_executor=SimulatorExecutor(
time_per_step="1min",
verbose=False,
trade_type=SimulatorExecutor.TT_SERIAL,
generate_report=False,
common_infra=common_infra,
track_data=True,
),
common_infra=common_infra,
track_data=True,
)
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"
# fmt: off
qlib_config = {
"provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
"provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
"feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
"feature_columns_today": [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
],
"feature_columns_yesterday": [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
],
}
# fmt: on
exchange_config = ExchangeConfig(
limit_threshold=("$ask == 0", "$bid == 0"),
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
volume_threshold={
"all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
"buy": ("current", "$askV1"),
"sell": ("current", "$bidV1"),
},
open_cost=0.0005,
close_cost=0.0015,
min_cost=5.0,
trade_unit=None,
cash_limit=None,
generate_report=False,
)
return SingleAssetOrderExecutionQlib(
order=order,
time_per_step="30min",
qlib_config=qlib_config,
inner_executor_fn=_inner_executor_fn,
exchange_config=exchange_config,
)
@python_version_request
def test_simulator_first_step():
order = get_order()
simulator = get_simulator(order)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
assert state.position == TOTAL_POSITION
AMOUNT = 300.0
simulator.step(AMOUNT)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
assert state.position == TOTAL_POSITION - AMOUNT
assert len(state.history_exec) == 30
assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")
assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
assert (state.history_exec["amount"] == AMOUNT / 30).all()
assert (state.history_exec["deal_amount"] == AMOUNT / 30).all()
assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
# assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME
assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
assert state.history_steps["amount"].iloc[0] == AMOUNT
assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
assert state.history_steps["ffr"].iloc[0] == 1.0
assert is_close(
state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
(state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
)
@python_version_request
def test_simulator_stop_twap() -> None:
order = get_order()
simulator = get_simulator(order)
NUM_STEPS = 7
for i in range(NUM_STEPS):
simulator.step(TOTAL_POSITION / NUM_STEPS)
HISTORY_STEP_LENGTH = 30 * NUM_STEPS
state = simulator.get_state()
assert len(state.history_exec) == HISTORY_STEP_LENGTH
assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()
assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)
assert is_close(state.history_steps["position"].iloc[-1], 0.0)
assert is_close(state.position, 0.0)
assert is_close(state.metrics["ffr"], 1.0)
assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean())
assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
assert is_close(state.metrics["pa"], 0.0)
assert simulator.done()
@python_version_request
def test_interpreter() -> None:
NUM_EXECUTION = 3
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
NUM_STEPS = 7
state = simulator.get_state()
position_history = []
for i in range(NUM_STEPS):
simulator.step(interpreter_action(state, 1))
state = simulator.get_state()
position_history.append(state.position)
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
if __name__ == "__main__":
test_simulator_first_step()
test_simulator_stop_twap()
test_interpreter()

View File

@@ -9,7 +9,6 @@ from typing import NamedTuple
import numpy as np
import pandas as pd
import pytest
import torch
from tianshou.data import Batch
@@ -17,8 +16,8 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.trainer import backtest, train
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
@@ -38,7 +37,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
def test_pickle_data_inspect():
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(