mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Black format
This commit is contained in:
@@ -18,8 +18,13 @@ from qlib.rl.data.pickle_styled import QlibIntradayBacktestData
|
||||
from qlib.rl.order_execution.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import (_convert_tick_str_to_int, _dataframe_append, _get_common_infra,
|
||||
_get_ticks_slice, _price_advantage)
|
||||
from qlib.rl.order_execution.utils import (
|
||||
_convert_tick_str_to_int,
|
||||
_dataframe_append,
|
||||
_get_common_infra,
|
||||
_get_ticks_slice,
|
||||
_price_advantage,
|
||||
)
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
@@ -108,22 +113,33 @@ class StateMaintainer:
|
||||
|
||||
if len(execute_result) > 0:
|
||||
exchange = inner_executor.trade_exchange
|
||||
market_price = np.array([exchange.get_deal_price(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
direction=execute_order.direction,
|
||||
method=None,
|
||||
)]).reshape(-1)
|
||||
market_volume = np.array([exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
)]).reshape(-1)
|
||||
market_price = np.array(
|
||||
[
|
||||
exchange.get_deal_price(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
direction=execute_order.direction,
|
||||
method=None,
|
||||
)
|
||||
]
|
||||
).reshape(-1)
|
||||
market_volume = np.array(
|
||||
[
|
||||
exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
)
|
||||
]
|
||||
).reshape(-1)
|
||||
|
||||
datetime_list = _get_ticks_slice(
|
||||
self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True,
|
||||
self._tick_index,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
include_end=True,
|
||||
)
|
||||
else:
|
||||
market_price = np.array([])
|
||||
|
||||
@@ -91,16 +91,16 @@ def test_simulator_first_step():
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp('2019-03-04 09:30:00')
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
|
||||
assert state.position == TOTAL_POSITION
|
||||
|
||||
AMOUNT = 300.0
|
||||
simulator.step(AMOUNT)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp('2019-03-04 10:00:00')
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
|
||||
assert state.position == TOTAL_POSITION - AMOUNT
|
||||
assert len(state.history_exec) == 30
|
||||
assert state.history_exec.index[0] == pd.Timestamp('2019-03-04 09:30:00')
|
||||
assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")
|
||||
|
||||
assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
|
||||
assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
|
||||
|
||||
@@ -18,7 +18,16 @@ from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.entries.test import backtest
|
||||
from qlib.rl.order_execution import SingleAssetOrderExecution, FullHistoryStateInterpreter, CurrentStepStateInterpreter, CategoricalActionInterpreter, TwapRelativeActionInterpreter, AllOne, Recurrent, PPO
|
||||
from qlib.rl.order_execution import (
|
||||
SingleAssetOrderExecution,
|
||||
FullHistoryStateInterpreter,
|
||||
CurrentStepStateInterpreter,
|
||||
CategoricalActionInterpreter,
|
||||
TwapRelativeActionInterpreter,
|
||||
AllOne,
|
||||
Recurrent,
|
||||
PPO,
|
||||
)
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
||||
Reference in New Issue
Block a user