mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
Add a few comments
This commit is contained in:
@@ -16,7 +16,7 @@ from qlib.tests.data import GetData
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tianshou.data import Batch, Collector
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
@@ -51,7 +51,8 @@ def price_advantage(exec_price: float, baseline_price: float, direction: int) ->
|
||||
@dataclass
|
||||
class EpisodicState:
|
||||
"""
|
||||
A simplified data structure for RL-related components to process observations and rewards
|
||||
A simplified data structure used as the input to RL-related components for calculating observations and rewards.
|
||||
Some of the metric information is calculated on the fly in this class.
|
||||
"""
|
||||
# requirements
|
||||
stock_id: int
|
||||
@@ -181,6 +182,7 @@ class SingleOrderEnv(gym.Env):
|
||||
return self.observation.observation_space
|
||||
|
||||
def retrieve_backtest_data(self, field: str):
|
||||
# Retrieve backtest data for RL-specific use (including reward calculation)
|
||||
return D.features(
|
||||
[self.cur_order.stock_id],
|
||||
['$open', '$close', '$high', '$low', '$volume'],
|
||||
@@ -190,6 +192,7 @@ class SingleOrderEnv(gym.Env):
|
||||
)[field].to_numpy()
|
||||
|
||||
def initialize_state(self):
|
||||
# Synchronize state from the executor to EpisodicState
|
||||
self.executor.reset(start_time=self.cur_order.start_time, end_time=self.cur_order.end_time)
|
||||
state = EpisodicState(
|
||||
stock_id=self.cur_order.stock_id,
|
||||
@@ -207,6 +210,7 @@ class SingleOrderEnv(gym.Env):
|
||||
return state
|
||||
|
||||
def update_state(self, exec_vol):
|
||||
# Synchronize exec_vol to the executor, then synchronize the result back to EpisodicState
|
||||
calendar = self.executor.trade_calendar
|
||||
state = self.ep_state
|
||||
|
||||
@@ -273,6 +277,7 @@ class SingleOrderEnv(gym.Env):
|
||||
'ins': self.ep_state.stock_id,
|
||||
'date': self.ep_state.start_time,
|
||||
}
|
||||
# TODO: collect logs
|
||||
pprint(info)
|
||||
|
||||
return self.observation(self.ep_state), reward, self.ep_state.done, info
|
||||
@@ -327,13 +332,18 @@ def _main():
|
||||
)
|
||||
return SingleOrderEnv(
|
||||
observation, action, reward_fn,
|
||||
iter(DataLoader(QlibOrderDataset('rl.pkl'), batch_size=None, shuffle=True)), executor)
|
||||
iter(DataLoader(QlibOrderDataset('rl_orders'), batch_size=None, shuffle=True)), executor)
|
||||
|
||||
policy = DummyPolicy()
|
||||
|
||||
# This cannot be replaced with SubprocVectorEnv; doing so fails with:
|
||||
# File "/xxx/qlib/qlib/data/data.py", line 462, in dataset_processor
|
||||
# p = Pool(processes=workers)
|
||||
# AssertionError: daemonic processes are not allowed to have children
|
||||
envs = DummyVectorEnv([dummy_env for _ in range(4)])
|
||||
test_collector = Collector(policy, envs)
|
||||
policy.eval()
|
||||
# TODO: create a queue for all orders and make it auto-complete when all the orders are processed
|
||||
test_collector.collect(n_episode=10)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user