1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00

Black format

This commit is contained in:
Huoran Li
2022-07-08 14:52:32 +08:00
parent d8858ba445
commit 09f51061e1
5 changed files with 44 additions and 40 deletions

View File

@@ -531,7 +531,9 @@ class TradeDecisionWO(BaseTradeDecision):
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
def __init__(
self, order_list: List[Order], strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None
):
super().__init__(strategy, trade_range=trade_range)
self.order_list = order_list
start, end = strategy.trade_calendar.get_step_time()

View File

@@ -10,7 +10,7 @@ class ExchangeConfig:
volume_threshold: dict
open_cost: float = 0.0005
close_cost: float = 0.0015
min_cost: float = 5.
trade_unit: Optional[float] = 100.
min_cost: float = 5.0
trade_unit: Optional[float] = 100.0
cash_limit: Optional[Union[Path, float]] = None
generate_report: bool = False

View File

@@ -34,10 +34,15 @@ class LRUCache:
class DataWrapper:
def __init__(self, feature_dataset: DatasetH, backtest_dataset: DatasetH,
columns_today: List[str], columns_yesterday: List[str], _internal: bool = False):
assert _internal, 'Init function of data wrapper is for internal use only.'
def __init__(
self,
feature_dataset: DatasetH,
backtest_dataset: DatasetH,
columns_today: List[str],
columns_yesterday: List[str],
_internal: bool = False,
):
assert _internal, "Init function of data wrapper is for internal use only."
self.feature_dataset = feature_dataset
self.backtest_dataset = backtest_dataset
@@ -76,8 +81,7 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
qlib.init(
region=REG_CN,
auto_mount=False,
custom_ops=[DayLast, FFillNan, BFillNan,
Date, Select, IsNull, IsInf, Cut, DayCumsum],
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
expression_cache=None,
calendar_provider={
"class": "LocalCalendarProvider",
@@ -104,22 +108,22 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
provider_uri=provider_uri_map,
kernels=1,
redis_port=-1,
clear_mem_cache=False # init_qlib will be called for multiple times. Keep the cache for improving performance
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
)
# this won't work if it's put outside in case of multiprocessing
if part is None:
feature_path = config.feature_root_dir / 'feature.pkl'
backtest_path = config.feature_root_dir / 'backtest.pkl'
feature_path = config.feature_root_dir / "feature.pkl"
backtest_path = config.feature_root_dir / "backtest.pkl"
else:
feature_path = config.feature_root_dir / 'feature' / (part + '.pkl')
backtest_path = config.feature_root_dir / 'backtest' / (part + '.pkl')
feature_path = config.feature_root_dir / "feature" / (part + ".pkl")
backtest_path = config.feature_root_dir / "backtest" / (part + ".pkl")
with feature_path.open('rb') as f:
with feature_path.open("rb") as f:
print(feature_path)
feature_dataset = pickle.load(f)
with backtest_path.open('rb') as f:
with backtest_path.open("rb") as f:
backtest_dataset = pickle.load(f)
_dataset = DataWrapper(
@@ -127,5 +131,5 @@ def init_qlib(config: QlibConfig, part: Optional[str] = None) -> None:
backtest_dataset,
config.feature_columns_today,
config.feature_columns_yesterday,
_internal=True
_internal=True,
)

View File

@@ -31,17 +31,13 @@ def get_common_infra(
) -> CommonInfrastructure:
# need to specify a range here for acceleration
if cash_limit is None:
trade_account = Account(
init_cash=int(1e12),
benchmark_config={},
pos_type='InfPosition'
)
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
else:
trade_account = Account(
init_cash=cash_limit,
benchmark_config={},
pos_type='Position',
position_dict={code: {"amount": 1e12, "price": 1.} for code in codes}
pos_type="Position",
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
)
exchange = get_exchange(
@@ -55,7 +51,7 @@ def get_common_infra(
start_time=trade_date,
end_time=trade_date + pd.DateOffset(1),
trade_unit=config.trade_unit,
volume_threshold=config.volume_threshold
volume_threshold=config.volume_threshold,
)
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
@@ -145,17 +141,16 @@ class StateMaintainer:
if len(execute_result) > 0:
exchange = inner_executor.trade_exchange
minutes = _get_minutes(execute_result[0][0].start_time, execute_result[-1][0].start_time)
market_price = np.array([
exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction)
for t in minutes
])
market_price = np.array(
[
exchange.get_deal_price(execute_order.stock_id, t, t, direction=execute_order.direction)
for t in minutes
]
)
market_volume = np.array([exchange.get_volume(execute_order.stock_id, t, t) for t in minutes])
datetime_list = _get_ticks_slice(
self._tick_index,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
include_end=True
self._tick_index, execute_result[0][0].start_time, execute_result[-1][0].start_time, include_end=True
)
else:
market_price = np.array([])
@@ -188,9 +183,11 @@ class StateMaintainer:
self.history_steps = _dataframe_append(
self.history_steps,
[self._metrics_collect(
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
)],
[
self._metrics_collect(
execute_order, execute_order.start_time, market_volume, market_price, exec_vol.sum(), exec_vol
)
],
)
def _metrics_collect(

View File

@@ -29,11 +29,12 @@ qlib_config = QlibConfig(
# fmt: on
exchange_config = ExchangeConfig(
limit_threshold=('$ask == 0', '$bid == 0'),
deal_price=('If($ask == 0, $bid, $ask)', 'If($bid == 0, $ask, $bid)'),
limit_threshold=("$ask == 0", "$bid == 0"),
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
volume_threshold={
'all': ('cum', "0.2 * DayCumsum($volume, '9:45', '14:44')"),
'buy': ('current', '$askV1'), 'sell': ('current', '$bidV1')
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"),
"buy": ("current", "$askV1"),
"sell": ("current", "$bidV1"),
},
open_cost=0.0005,
close_cost=0.0015,