diff --git a/examples/nested_decision_execution/rl_dummy.py b/examples/nested_decision_execution/rl_dummy.py index cd0961f66..c42e28be4 100644 --- a/examples/nested_decision_execution/rl_dummy.py +++ b/examples/nested_decision_execution/rl_dummy.py @@ -319,6 +319,7 @@ class RLStrategy(BaseStrategy): self.policy = policy # TODO: how to get inner frequency and trade len + # This should be no longer required when PA is provided by qlib. self.inner_frequency = "day" self.inner_trade_len = 1 @@ -432,6 +433,12 @@ class RlWorkflow(NestedDecisonExecutionWorkflow): min_cost=5 ) common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange) + executor = NestedExecutor( + time_per_step="week", + inner_executor=SimulatorExecutor(time_per_step="day", verbose=True), + inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()), + common_infra=common_infra + ) strategy = init_instance_by_config({ "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy", @@ -442,12 +449,6 @@ class RlWorkflow(NestedDecisonExecutionWorkflow): "n_drop": 5, }, }, common_infra=common_infra) - executor = NestedExecutor( - time_per_step="week", - inner_executor=SimulatorExecutor(time_per_step="day", verbose=True), - inner_strategy=RLStrategy(Observation("day"), Action(), DummyPolicy()), - common_infra=common_infra - ) report_dict = backtest_func(trade_start_time, trade_end_time, strategy, executor) print(report_dict) @@ -463,7 +464,7 @@ class QlibOrderDataset(Dataset): def __len__(self): return len(self.orders) - def __getitem__(self, index): + def __getitem__(self, index) -> Order: return self.orders[index] @@ -535,7 +536,7 @@ class Action: def validate(self, action: Any) -> bool: return self.action_space.contains(action) - def to_volume(self, action: Any, ep_state: EpisodicState): + def to_volume(self, action: Any, ep_state: EpisodicState) -> Any: exec_vol = ep_state.position / self.denominator * action if ep_state.cur_step + 1 >= ep_state.num_step: exec_vol = ep_state.position