From 7d466890c0227834348e863fe7ab2751bae45812 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 19 Jul 2022 15:20:19 +0800 Subject: [PATCH] Black format --- qlib/rl/order_execution/simulator_qlib.py | 48 ++++++++++++------- .../tests/test_simulator_qlib.py | 6 +-- tests/rl/test_saoe_simple.py | 11 ++++- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index b27fe55ff..30512c709 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -18,8 +18,13 @@ from qlib.rl.data.pickle_styled import QlibIntradayBacktestData from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig from qlib.rl.order_execution.from_neutrader.feature import init_qlib from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState -from qlib.rl.order_execution.utils import (_convert_tick_str_to_int, _dataframe_append, _get_common_infra, - _get_ticks_slice, _price_advantage) +from qlib.rl.order_execution.utils import ( + _convert_tick_str_to_int, + _dataframe_append, + _get_common_infra, + _get_ticks_slice, + _price_advantage, +) from qlib.rl.simulator import Simulator from qlib.strategy.base import BaseStrategy @@ -108,22 +113,33 @@ class StateMaintainer: if len(execute_result) > 0: exchange = inner_executor.trade_exchange - market_price = np.array([exchange.get_deal_price( - execute_order.stock_id, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - direction=execute_order.direction, - method=None, - )]).reshape(-1) - market_volume = np.array([exchange.get_volume( - execute_order.stock_id, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - method=None, - )]).reshape(-1) + market_price = np.array( + [ + exchange.get_deal_price( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + direction=execute_order.direction, + method=None, + ) + ] + ).reshape(-1) + market_volume = np.array( + [ + exchange.get_volume( + execute_order.stock_id, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + method=None, + ) + ] + ).reshape(-1) datetime_list = _get_ticks_slice( - self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True, + self._tick_index, + execute_result[0][0].start_time, + execute_result[-1][0].start_time, + include_end=True, ) else: market_price = np.array([]) diff --git a/qlib/rl/order_execution/tests/test_simulator_qlib.py b/qlib/rl/order_execution/tests/test_simulator_qlib.py index ece369600..2bb63bb8a 100644 --- a/qlib/rl/order_execution/tests/test_simulator_qlib.py +++ b/qlib/rl/order_execution/tests/test_simulator_qlib.py @@ -91,16 +91,16 @@ def test_simulator_first_step(): order = get_order() simulator = get_simulator(order) state = simulator.get_state() - assert state.cur_time == pd.Timestamp('2019-03-04 09:30:00') + assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00") assert state.position == TOTAL_POSITION AMOUNT = 300.0 simulator.step(AMOUNT) state = simulator.get_state() - assert state.cur_time == pd.Timestamp('2019-03-04 10:00:00') + assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00") assert state.position == TOTAL_POSITION - AMOUNT assert len(state.history_exec) == 30 - assert state.history_exec.index[0] == pd.Timestamp('2019-03-04 09:30:00') + assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00") assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812) assert is_close(state.history_exec["market_price"].iloc[0], 149.566483) diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 51ba13b34..c683937d4 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -18,7 +18,16 @@ from qlib.config import C from qlib.log import set_log_with_config from qlib.rl.data import pickle_styled from qlib.rl.entries.test import backtest -from qlib.rl.order_execution import SingleAssetOrderExecution, FullHistoryStateInterpreter, CurrentStepStateInterpreter, CategoricalActionInterpreter, TwapRelativeActionInterpreter, AllOne, Recurrent, PPO +from qlib.rl.order_execution import ( + SingleAssetOrderExecution, + FullHistoryStateInterpreter, + CurrentStepStateInterpreter, + CategoricalActionInterpreter, + TwapRelativeActionInterpreter, + AllOne, + Recurrent, + PPO, +) from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")