mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 20:11:08 +08:00
Black format
This commit is contained in:
@@ -531,7 +531,9 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
|
||||
def __init__(
|
||||
self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None
|
||||
):
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = order_list
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
|
||||
@@ -10,7 +10,7 @@ class ExchangeConfig:
|
||||
volume_threshold: dict
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.
|
||||
trade_unit: Optional[float] = 100.
|
||||
min_cost: float = 5.0
|
||||
trade_unit: Optional[float] = 100.0
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
||||
|
||||
@@ -34,10 +34,15 @@ class LRUCache:
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
|
||||
def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH,
|
||||
columns_today: List[str], columns_yesterday: List[str], _internal: bool = False):
|
||||
assert _internal, 'Init function of data wrapper is for internal use only.'
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
@@ -76,8 +81,7 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan,
|
||||
Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
expression_cache=None,
|
||||
calendar_provider={
|
||||
"class": "LocalCalendarProvider",
|
||||
@@ -104,22 +108,22 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
provider_uri=provider_uri_map,
|
||||
kernels=1,
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
||||
|
||||
# this won't work if it's put outside in case of multiprocessing
|
||||
|
||||
if part is None:
|
||||
feature_path = config.feature_root_dir / 'feature.pkl'
|
||||
backtest_path = config.feature_root_dir / 'backtest.pkl'
|
||||
feature_path = config.feature_root_dir / "feature.pkl"
|
||||
backtest_path = config.feature_root_dir / "backtest.pkl"
|
||||
else:
|
||||
feature_path = config.feature_root_dir / 'feature' / (part + '.pkl')
|
||||
backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl')
|
||||
feature_path = config.feature_root_dir / "feature" / (part + ".pkl")
|
||||
backtest_path = config.feature_root_dir / "backtest" / (part + ".pkl")
|
||||
|
||||
with feature_path.open('rb') as f:
|
||||
with feature_path.open("rb") as f:
|
||||
print(feature_path)
|
||||
feature_dataset = pickle.load(f)
|
||||
with backtest_path.open('rb') as f:
|
||||
with backtest_path.open("rb") as f:
|
||||
backtest_dataset = pickle.load(f)
|
||||
|
||||
_dataset = DataWrapper(
|
||||
@@ -127,5 +131,5 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
|
||||
backtest_dataset,
|
||||
config.feature_columns_today,
|
||||
config.feature_columns_yesterday,
|
||||
_internal=True
|
||||
_internal=True,
|
||||
)
|
||||
|
||||
@@ -31,17 +31,13 @@ def get_common_infra(
|
||||
) -> CommonInfrastructure:
|
||||
# need to specify a range here for acceleration
|
||||
if cash_limit is None:
|
||||
trade_account = Account(
|
||||
init_cash=int(1e12),
|
||||
benchmark_config={},
|
||||
pos_type='InfPosition'
|
||||
)
|
||||
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
|
||||
else:
|
||||
trade_account = Account(
|
||||
init_cash=cash_limit,
|
||||
benchmark_config={},
|
||||
pos_type='Position',
|
||||
position_dict={code: {"amount": 1e12, "price": 1.} for code in codes}
|
||||
pos_type="Position",
|
||||
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
|
||||
)
|
||||
|
||||
exchange = get_exchange(
|
||||
@@ -55,7 +51,7 @@ def get_common_infra(
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold
|
||||
volume_threshold=config.volume_threshold,
|
||||
)
|
||||
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
@@ -145,17 +141,16 @@ class StateMaintainer:
|
||||
if len(execute_result) > 0:
|
||||
exchange = inner_executor.trade_exchange
|
||||
minutes = _get_minutes(execute_result[0][0].start_time, execute_result[-1][0].start_time)
|
||||
market_price = np.array([
|
||||
exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction)
|
||||
for t in minutes
|
||||
])
|
||||
market_price = np.array(
|
||||
[
|
||||
exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction)
|
||||
for t in minutes
|
||||
]
|
||||
)
|
||||
market_volume = np.array([exchange.get_volume(execute_order.stock_id, t, t) for t in minutes])
|
||||
|
||||
datetime_list = _get_ticks_slice(
|
||||
self._tick_index,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
include_end=True
|
||||
self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True
|
||||
)
|
||||
else:
|
||||
market_price = np.array([])
|
||||
@@ -188,9 +183,11 @@ class StateMaintainer:
|
||||
|
||||
self.history_steps = _dataframe_append(
|
||||
self.history_steps,
|
||||
[self._metrics_collect(
|
||||
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
|
||||
)],
|
||||
[
|
||||
self._metrics_collect(
|
||||
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _metrics_collect(
|
||||
|
||||
@@ -29,11 +29,12 @@ qlib_config = QlibConfig(
|
||||
# fmt: on
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=('$ask == 0', '$bid == 0'),
|
||||
deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'),
|
||||
limit_threshold=("$ask == 0", "$bid == 0"),
|
||||
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
|
||||
volume_threshold={
|
||||
'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"),
|
||||
'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1')
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
"sell": ("current", "$bidV1"),
|
||||
},
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
|
||||
Reference in New Issue
Block a user