mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Migrate NeuTrader to Qlib RL (#1169)
* Refine the previous version of the RL code
* Polish utils/__init__.py
* Draft
* Use | instead of Union
* Simulator & action interpreter
* Test passed
* Migrate to SAOEState & new qlib interpreter
* Black format
* . Revert file_storage change
* Refactor file structure & renaming functions
* Enrich test cases
* Add QlibIntradayBacktestData
* Test interpreter
* Black format
* .
.
.
* Rename receive_execute_result()
* Use indicator to simplify state update
* Format code
* Modify data path
* Adjust file structure
* Minor change
* Add copyright message
* Format code
* Rename util functions
* Add CI
* Pylint issue
* Remove useless code to pass pylint
* Pass mypy
* Mypy issue
* mypy issue
* mypy issue
* Revert "mypy issue"
This reverts commit 8eb1b0174e.
* mypy issue
* mypy issue
* Fix the numpy version incompatibility bug
* Fix a minor typing issue
* Try to skip python 3.7 test for qlib simulator
* Resolve PR comments by Yuge; solve several CI issues.
* Black issue
* Fix a low-level type error
* Change data name
* Resolve PR comments. Leave TODOs in the code base.
Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -5,6 +5,8 @@ from random import randint, choice
|
||||
from pathlib import Path
|
||||
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env[int, int]):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.logger = LogCollector()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, *args: Any, **kwargs: Any) -> int:
|
||||
self.step_count = 0
|
||||
return 0
|
||||
|
||||
def step(self, action: int):
|
||||
def step(self, action: int) -> Tuple[int, float, bool, dict]:
|
||||
self.logger.reset()
|
||||
|
||||
self.logger.add_scalar("reward", 42.0)
|
||||
@@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
|
||||
|
||||
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
|
||||
def forward(self, batch, state=None):
|
||||
@@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
|
||||
|
||||
|
||||
class SimpleSimulator(Simulator[int, float, float]):
|
||||
def __init__(self, initial: int, **kwargs) -> None:
|
||||
def __init__(self, initial: int, **kwargs: Any) -> None:
|
||||
super(SimpleSimulator, self).__init__(initial, **kwargs)
|
||||
self.initial = float(initial)
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
|
||||
177
tests/rl/test_qlib_simulator.py
Normal file
177
tests/rl/test_qlib_simulator.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
|
||||
|
||||
# Order amount shared by the tests in this module; used as the `amount` of the test order.
TOTAL_POSITION = 2100.0
|
||||
|
||||
# Skip marker: decorated tests are skipped when the interpreter is older than Python 3.8.
python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
|
||||
|
||||
|
||||
def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:
    """Return True when ``a`` and ``b`` differ by at most ``epsilon``.

    This is a purely absolute-tolerance comparison (``abs(a - b) <= epsilon``);
    no relative scaling is applied, so the tolerance is the same regardless of
    the magnitude of the operands.
    """
    return abs(a - b) <= epsilon
|
||||
|
||||
|
||||
def get_order() -> Order:
    """Build the order shared by the tests in this module.

    Returns:
        A BUY order for ``SH600000`` of ``TOTAL_POSITION`` units, with a
        trading window from 2019-03-04 09:30:00 to 2019-03-04 14:29:00.
    """
    return Order(
        stock_id="SH600000",
        amount=TOTAL_POSITION,
        direction=OrderDir.BUY,
        start_time=pd.Timestamp("2019-03-04 09:30:00"),
        end_time=pd.Timestamp("2019-03-04 14:29:00"),
    )
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
    """Create a ``SingleAssetOrderExecutionQlib`` simulator for ``order``.

    The simulator is configured with 30-minute steps. Each step is delegated
    to a ``NestedExecutor`` that runs a ``TWAPStrategy`` on top of a 1-minute
    ``SimulatorExecutor``. Data locations are resolved relative to this test
    file (``<tests>/.data/rl/qlib_simulator``).

    Args:
        order: The order to be executed by the simulator.

    Returns:
        The configured simulator instance.
    """

    def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
        # Factory handed to the simulator: the outer executor uses the step size
        # it is given, while the innermost executor always trades at 1-minute
        # granularity.
        return NestedExecutor(
            time_per_step=time_per_step,
            inner_strategy=TWAPStrategy(),
            inner_executor=SimulatorExecutor(
                time_per_step="1min",
                verbose=False,
                trade_type=SimulatorExecutor.TT_SERIAL,
                generate_report=False,
                common_infra=common_infra,
                track_data=True,
            ),
            common_infra=common_infra,
            track_data=True,
        )

    DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"

    # fmt: off
    qlib_config = {
        "provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
        "provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
        "feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
        "feature_columns_today": [
            "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
            "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
        ],
        "feature_columns_yesterday": [
            "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
            "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
        ],
    }
    # fmt: on

    # NOTE(review): the two-element tuples below presumably hold the (buy, sell)
    # expressions -- confirm against ExchangeConfig.
    exchange_config = ExchangeConfig(
        limit_threshold=("$ask == 0", "$bid == 0"),
        deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
        volume_threshold={
            "all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
            "buy": ("current", "$askV1"),
            "sell": ("current", "$bidV1"),
        },
        open_cost=0.0005,
        close_cost=0.0015,
        min_cost=5.0,
        trade_unit=None,
        cash_limit=None,
        generate_report=False,
    )

    return SingleAssetOrderExecutionQlib(
        order=order,
        time_per_step="30min",
        qlib_config=qlib_config,
        inner_executor_fn=_inner_executor_fn,
        exchange_config=exchange_config,
    )
|
||||
|
||||
|
||||
@python_version_request
def test_simulator_first_step() -> None:
    """Check the simulator state before and after a single 30-minute step."""
    order = get_order()
    simulator = get_simulator(order)
    state = simulator.get_state()
    assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
    assert state.position == TOTAL_POSITION

    AMOUNT = 300.0
    simulator.step(AMOUNT)
    state = simulator.get_state()
    assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
    assert state.position == TOTAL_POSITION - AMOUNT
    # One 30-minute step run through the 1-minute inner executor is expected
    # to yield 30 execution records.
    assert len(state.history_exec) == 30
    assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")

    assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
    assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
    # The step amount is expected to be split evenly across the 30 records.
    assert (state.history_exec["amount"] == AMOUNT / 30).all()
    assert (state.history_exec["deal_amount"] == AMOUNT / 30).all()
    assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
    assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
    assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
    # assert state.history_exec["ffr"].iloc[0] == 1 / 60  # FIXME

    assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
    assert state.history_steps["amount"].iloc[0] == AMOUNT
    assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
    assert state.history_steps["ffr"].iloc[0] == 1.0
    # "pa" is compared against the trade price's deviation from the TWAP price,
    # scaled by 10000; the sign is flipped for non-SELL orders.
    assert is_close(
        state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
        (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
    )
|
||||
|
||||
|
||||
@python_version_request
def test_simulator_stop_twap() -> None:
    """Execute the whole order in equal slices and check the final metrics."""
    order = get_order()
    simulator = get_simulator(order)
    NUM_STEPS = 7
    for _ in range(NUM_STEPS):
        simulator.step(TOTAL_POSITION / NUM_STEPS)

    # 30 execution records are expected per step (see the length assertion below).
    HISTORY_STEP_LENGTH = 30 * NUM_STEPS
    state = simulator.get_state()
    assert len(state.history_exec) == HISTORY_STEP_LENGTH

    assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()
    assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)
    assert is_close(state.history_steps["position"].iloc[-1], 0.0)
    assert is_close(state.position, 0.0)
    assert is_close(state.metrics["ffr"], 1.0)

    # With equal slices every step, the achieved trade price is expected to
    # match the market price, giving a zero "pa" metric.
    assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean())
    assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
    assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
    assert is_close(state.metrics["pa"], 0.0)

    assert simulator.done()
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_interpreter() -> None:
|
||||
NUM_EXECUTION = 3
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
position_history = []
|
||||
for i in range(NUM_STEPS):
|
||||
simulator.step(interpreter_action(state, 1))
|
||||
state = simulator.get_state()
|
||||
position_history.append(state.position)
|
||||
|
||||
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
|
||||
|
||||
|
||||
# Allow running the tests in this module directly, without pytest.
if __name__ == "__main__":
    test_simulator_first_step()
    test_simulator_stop_twap()
    test_interpreter()
|
||||
@@ -9,7 +9,6 @@ from typing import NamedTuple
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
|
||||
@@ -17,8 +16,8 @@ from qlib.backtest import Order
|
||||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
@@ -38,7 +37,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
|
||||
|
||||
|
||||
def test_pickle_data_inspect():
|
||||
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
assert len(data) == 390
|
||||
|
||||
data = pickle_styled.load_intraday_processed_data(
|
||||
|
||||
Reference in New Issue
Block a user