mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
* rl init * aux info * Reward config * update * simple * update saoe init * update simulator and seed * minor * minor * update sim * checkpoint * obs * Update interpreter * init qlib simulator * checkpoint * Refine codebase * checkpoint * checkpoint * Add one test * More tests * Simulator checkpoint * checkpoint * First-step tested * Checkpoint * Update data_queue API * Checkpoint * Update test * Move files * Checkpoint * Single-quote -> double-quote * Fix finite env tests * Tested with mypy * pep-574 * No call for env done * Update finite env docs * Fix csv writer * Refine tester * Update logger * Add another logger test * Checkpoint * Add network sanity test * steps per episode is not correct * Cleanup code, ready for PR * Reformat with black * Fix pylint for py37 * Fix lint * Fix lint * Fix flake * update mypy command * mypy * Update exclude pattern * Use pyproject.toml * test * . * . * Refactor pipeline * . * defaults run bash * . * Revert and skip follow_imports * Fix toml issue * fix mypy * . * . * . * Fix install * Minor fix * Fix test * Fix test * Remove requirements * Revert * fix tests * Fix lint * . * . * . * . * . * update install from source command * . * Fix data download * . * . * . * . * . * . * Fix py37 * Ignore tests on non-linux * resolve comments * fix tests * resolve comments * some typo * style updates * More comments * fix dummy * add warning * Align precision in some system * Added some impl notes Co-authored-by: Young <afe.young@gmail.com>
119 lines
4.1 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import cast
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from tianshou.data import Batch
|
|
|
|
from qlib.typehint import Literal
|
|
from .interpreter import FullHistoryObs
|
|
|
|
# Public API of this module: only the ``Recurrent`` network is exported by ``from ... import *``.
__all__ = ["Recurrent"]
|
|
|
|
|
|
class Recurrent(nn.Module):
    """The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.

    At every timestep the input of the policy network is divided into two parts:
    the public variables and the private variables, which are handled by ``raw_rnn``
    and ``pri_rnn`` in this network, respectively.

    One minor difference is that, in this implementation, we don't assume the direction to be fixed.
    Thus, another ``dire_fc`` is added to produce an extra direction-related feature.

    Subclasses may add feature branches by overriding ``_init_extra_branches`` and
    ``_source_features``; ``num_sources`` must equal the number of features returned,
    because it sizes the input of the final ``fc`` head.
    """

    def __init__(
        self,
        obs_space: FullHistoryObs,
        hidden_dim: int = 64,
        output_dim: int = 32,
        rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
        rnn_num_layers: int = 1,
    ):
        """
        Parameters
        ----------
        obs_space
            Observation space. Only ``obs_space["data_processed"].shape[-1]`` is read; it is the
            input width of ``raw_fc``.
        hidden_dim
            Width of every intermediate feature and of the RNN hidden states.
        output_dim
            Width of the output feature produced by ``forward``.
        rnn_type
            Recurrent module used by every RNN branch: ``"rnn"``, ``"lstm"`` or ``"gru"``.
        rnn_num_layers
            Number of stacked layers in each RNN.

        Raises
        ------
        KeyError
            If ``rnn_type`` is not one of the supported names.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Market data, private state and direction. Subclasses adding branches must update this.
        self.num_sources = 3

        rnn_classes = {"rnn": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}

        self.rnn_class = rnn_classes[rnn_type]
        self.rnn_layers = rnn_num_layers

        # NOTE: module creation order is kept as-is so that parameter initialisation
        # (RNG consumption) and ``state_dict`` layout stay unchanged.
        self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
        # NOTE(review): ``prev_rnn`` is not used by this class's forward pass; it is presumably kept
        # for subclasses and for checkpoint compatibility -- confirm before removing.
        self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)
        self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers)

        self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU())
        # Private input per step: (normalized position, normalized step index).
        self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())
        # Direction input: (acquiring, 1 - acquiring).
        self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU())

        # Subclass hook. It runs before ``fc`` is built, so it may change ``num_sources``.
        self._init_extra_branches()

        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * self.num_sources, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        )

    def _init_extra_branches(self):
        """Hook for subclasses to register additional modules; the base class adds none."""
        pass

    def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Compute one feature tensor per input source.

        Parameters
        ----------
        obs
            Observation dict with the keys listed in ``forward``.
        device
            Device on which the helper tensors (padding, indices, step ranges) are created.

        Returns
        -------
        ``(sources, data_out)``: ``sources`` holds the market-data, private and direction
        features, each ``[N, hidden_dim]`` for the ``[N]``-shaped index inputs documented in
        ``forward``. ``data_out`` is the full ``raw_rnn`` output over the zero-padded market
        data, ``[N, T + 1, hidden_dim]``.
        """
        bs, _, data_dim = obs["data_processed"].size()
        # Prepend one all-zero tick, so that index ``t`` of the padded sequence has seen exactly
        # the first ``t`` real ticks.
        data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
        cur_step = obs["cur_step"].long()
        cur_tick = obs["cur_tick"].long()
        bs_indices = torch.arange(bs, device=device)

        position = obs["position_history"] / obs["target"].unsqueeze(-1)  # [bs, num_step]
        steps = (
            torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
            / obs["num_step"].unsqueeze(-1).float()
        )  # [bs, num_step]
        priv = torch.stack((position.float(), steps), -1)  # [bs, num_step, 2]

        data_in = self.raw_fc(data)
        data_out, _ = self.raw_rnn(data_in)
        # Because of the zero padding, the output at ``cur_tick`` has consumed
        # ``data_processed[:, :cur_tick]``, i.e. everything up to the last tick before ``cur_tick``.
        data_out_slice = data_out[bs_indices, cur_tick]

        priv_in = self.pri_fc(priv)
        priv_out = self.pri_rnn(priv_in)[0]
        priv_out = priv_out[bs_indices, cur_step]

        sources = [data_out_slice, priv_out]

        # Cast before subtracting: PyTorch rejects ``1 - x`` for bool tensors, and the values are
        # unchanged for 0/1 integer inputs.
        acquiring = obs["acquiring"].float()
        dir_out = self.dire_fc(torch.stack((acquiring, 1 - acquiring), -1))
        sources.append(dir_out)

        return sources, data_out

    def forward(self, batch: Batch) -> torch.Tensor:
        """
        Input should be a dict (at least) containing:

        - data_processed: [N, T, C]
        - cur_step: [N] (int)
        - cur_tick: [N] (int)
        - position_history: [N, S] (S is number of steps)
        - target: [N]
        - num_step: [N] (int)
        - acquiring: [N] (0 or 1; a bool tensor is also accepted)

        Returns a tensor of shape [N, output_dim].
        """
        inp = cast(FullHistoryObs, batch)
        device = inp["data_processed"].device

        sources, _ = self._source_features(inp, device)
        assert len(sources) == self.num_sources, f"expected {self.num_sources} sources, got {len(sources)}"

        out = torch.cat(sources, -1)
        return self.fc(out)
|