From 09f51061e1ea3f04f3888218839d1841f1821770 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 8 Jul 2022 14:52:32 +0800 Subject: [PATCH] Black format --- qlib/backtest/decision.py | 4 ++- .../order_execution/from_neutrader/config.py | 4 +-- .../order_execution/from_neutrader/feature.py | 32 +++++++++-------- qlib/rl/order_execution/simulator_qlib.py | 35 +++++++++---------- .../tests/test_simulator_qlib.py | 9 ++--- 5 files changed, 44 insertions(+), 40 deletions(-) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 1772436dc..51c47005d 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -531,7 +531,9 @@ class TradeDecisionWO(BaseTradeDecision): Besides, the time_range is also included. """ - def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None): + def __init__( + self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None + ): super().__init__(strategy, trade_range=trade_range) self.order_list = order_list start, end = strategy.trade_calendar.get_step_time() diff --git a/qlib/rl/order_execution/from_neutrader/config.py b/qlib/rl/order_execution/from_neutrader/config.py index bb27f6f37..b2e556385 100644 --- a/qlib/rl/order_execution/from_neutrader/config.py +++ b/qlib/rl/order_execution/from_neutrader/config.py @@ -10,7 +10,7 @@ class ExchangeConfig: volume_threshold: dict open_cost: float = 0.0005 close_cost: float = 0.0015 - min_cost: float = 5. - trade_unit: Optional[float] = 100. + min_cost: float = 5.0 + trade_unit: Optional[float] = 100.0 cash_limit: Optional[Union[Path, float]] = None generate_report: bool = False diff --git a/qlib/rl/order_execution/from_neutrader/feature.py b/qlib/rl/order_execution/from_neutrader/feature.py index 81286b40c..40e62c8c0 100644 --- a/qlib/rl/order_execution/from_neutrader/feature.py +++ b/qlib/rl/order_execution/from_neutrader/feature.py @@ -34,10 +34,15 @@ class LRUCache: class DataWrapper: - - def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH, - columns_today: List[str], columns_yesterday: List[str], _internal: bool = False): - assert _internal, 'Init function of data wrapper is for internal use only.' + def __init__( + self, + feature_dataset: DatasetH, + backtest_dataset: DatasetH, + columns_today: List[str], + columns_yesterday: List[str], + _internal: bool = False, + ): + assert _internal, "Init function of data wrapper is for internal use only." self.feature_dataset = feature_dataset self.backtest_dataset = backtest_dataset @@ -76,8 +81,7 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None: qlib.init( region=REG_CN, auto_mount=False, - custom_ops=[DayLast, FFillNan, BFillNan, - Date, Select, IsNull, IsInf, Cut, DayCumsum], + custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum], expression_cache=None, calendar_provider={ "class": "LocalCalendarProvider", @@ -104,22 +108,22 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None: provider_uri=provider_uri_map, kernels=1, redis_port=-1, - clear_mem_cache=False # init_qlib will be called for multiple times. Keep the cache for improving performance + clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance ) # this won't work if it's put outside in case of multiprocessing if part is None: - feature_path = config.feature_root_dir / 'feature.pkl' - backtest_path = config.feature_root_dir / 'backtest.pkl' + feature_path = config.feature_root_dir / "feature.pkl" + backtest_path = config.feature_root_dir / "backtest.pkl" else: - feature_path = config.feature_root_dir / 'feature' / (part + '.pkl') - backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl') + feature_path = config.feature_root_dir / "feature" / (part + ".pkl") + backtest_path = config.feature_root_dir / "backtest" / (part + ".pkl") - with feature_path.open('rb') as f: + with feature_path.open("rb") as f: print(feature_path) feature_dataset = pickle.load(f) - with backtest_path.open('rb') as f: + with backtest_path.open("rb") as f: backtest_dataset = pickle.load(f) _dataset = DataWrapper( @@ -127,5 +131,5 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None: backtest_dataset, config.feature_columns_today, config.feature_columns_yesterday, - _internal=True + _internal=True, ) diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index c3a52631c..25999eb3b 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -31,17 +31,13 @@ def get_common_infra( ) -> CommonInfrastructure: # need to specify a range here for acceleration if cash_limit is None: - trade_account = Account( - init_cash=int(1e12), - benchmark_config={}, - pos_type='InfPosition' - ) + trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition") else: trade_account = Account( init_cash=cash_limit, benchmark_config={}, - pos_type='Position', - position_dict={code: {"amount": 1e12, "price": 1.} for code in codes} + pos_type="Position", + position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes}, ) exchange = get_exchange( @@ -55,7 +51,7 @@ def get_common_infra( start_time=trade_date, end_time=trade_date + pd.DateOffset(1), trade_unit=config.trade_unit, - volume_threshold=config.volume_threshold + volume_threshold=config.volume_threshold, ) return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) @@ -145,17 +141,16 @@ class StateMaintainer: if len(execute_result) > 0: exchange = inner_executor.trade_exchange minutes = _get_minutes(execute_result[0][0].start_time, execute_result[-1][0].start_time) - market_price = np.array([ - exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction) - for t in minutes - ]) + market_price = np.array( + [ + exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction) + for t in minutes + ] + ) market_volume = np.array([exchange.get_volume(execute_order.stock_id, t, t) for t in minutes]) datetime_list = _get_ticks_slice( - self._tick_index, - execute_result[0][0].start_time, - execute_result[-1][0].start_time, - include_end=True + self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True ) else: market_price = np.array([]) @@ -188,9 +183,11 @@ class StateMaintainer: self.history_steps = _dataframe_append( self.history_steps, - [self._metrics_collect( - execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol - )], + [ + self._metrics_collect( + execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol + ) + ], ) def _metrics_collect( diff --git a/qlib/rl/order_execution/tests/test_simulator_qlib.py b/qlib/rl/order_execution/tests/test_simulator_qlib.py index 7d7d376d7..914ec477e 100644 --- a/qlib/rl/order_execution/tests/test_simulator_qlib.py +++ b/qlib/rl/order_execution/tests/test_simulator_qlib.py @@ -29,11 +29,12 @@ qlib_config = QlibConfig( # fmt: on exchange_config = ExchangeConfig( - limit_threshold=('$ask == 0', '$bid == 0'), - deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'), + limit_threshold=("$ask == 0", "$bid == 0"), + deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"), volume_threshold={ - 'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"), - 'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1') + "all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"), + "buy": ("current", "$askV1"), + "sell": ("current", "$bidV1"), }, open_cost=0.0005, close_cost=0.0015,