1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

Refine example

This commit is contained in:
Yuge Zhang
2021-06-07 10:56:12 +08:00
parent a06fa2bc44
commit 76be5d50e5

View File

@@ -319,6 +319,7 @@ class RLStrategy(BaseStrategy):
self.policy = policy
# TODO: how to get inner frequency and trade len
# This should be no longer required when PA is provided by qlib.
self.inner_frequency = "day"
self.inner_trade_len = 1
@@ -432,6 +433,12 @@ class RlWorkflow(NestedDecisonExecutionWorkflow):
min_cost=5
)
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
executor = NestedExecutor(
time_per_step="week",
inner_executor=SimulatorExecutor(time_per_step="day", verbose=True),
inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()),
common_infra=common_infra
)
strategy = init_instance_by_config({
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
@@ -442,12 +449,6 @@ class RlWorkflow(NestedDecisonExecutionWorkflow):
"n_drop": 5,
},
}, common_infra=common_infra)
executor = NestedExecutor(
time_per_step="week",
inner_executor=SimulatorExecutor(time_per_step="day", verbose=True),
inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()),
common_infra=common_infra
)
report_dict = backtest_func(trade_start_time, trade_end_time, strategy, executor)
print(report_dict)
@@ -463,7 +464,7 @@ class QlibOrderDataset(Dataset):
def __len__(self):
return len(self.orders)
def __getitem__(self, index):
def __getitem__(self, index) -> Order:
return self.orders[index]
@@ -535,7 +536,7 @@ class Action:
def validate(self, action: Any) -> bool:
return self.action_space.contains(action)
def to_volume(self, action: Any, ep_state: EpisodicState):
def to_volume(self, action: Any, ep_state: EpisodicState) -> Any:
exec_vol = ep_state.position / self.denominator * action
if ep_state.cur_step + 1 >= ep_state.num_step:
exec_vol = ep_state.position