1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

refactor TRA

This commit is contained in:
Dong Zhou
2021-07-21 13:19:07 +08:00
committed by you-n-g
parent 9303415666
commit 07655f2d5b
6 changed files with 1535 additions and 0 deletions

View File

@@ -0,0 +1,125 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 20
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
early_stop: 10
smooth_steps: 5
seed: 0
logdir: output/Alpha158/router
lamb: 1.0
rho: 1.0
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,118 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: TPE
model_config: &model_config
input_size: 158
hidden_size: 256
num_layers: 2
use_attn: True
dropout: 0.2
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
early_stop: 10
smooth_steps: 5
seed: 0
logdir: output/Alpha158_full/router
lamb: 1.0
rho: 1.0
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size:
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,119 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
num_states: &num_states 3
memory_mode: &memory_mode sample
tra_config: &tra_config
num_states: *num_states
hidden_size: 16
tau: 1.0
src_info: LR_TPE
model_config: &model_config
input_size: 6
hidden_size: 64
num_layers: 2
rnn_arch: LSTM
use_attn: True
dropout: 0.0
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TRAModel
module_path: qlib.contrib.model.pytorch_tra
kwargs:
tra_config: *tra_config
model_config: *model_config
lr: 1e-3
n_epochs: 100
max_steps_per_epoch: 100
early_stop: 10
smooth_steps: 5
logdir: output/Alpha360/router
seed: 0
lamb: 1.0
rho: 1.0
transport_method: router
memory_mode: *memory_mode
eval_train: False
eval_test: True
pretrain: False
init_state:
freeze_model: False
freeze_predictors: False
dataset:
class: MTSDatasetH
module_path: qlib.contrib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
seq_len: 60
horizon: 2
input_size: 6
num_states: *num_states
batch_size: 1024
n_samples:
memory_mode: *memory_mode
drop_last: True
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,349 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import torch
import warnings
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return x
def _create_ts_slices(index, seq_len):
"""
create time series slices from pandas index
Args:
index (pd.MultiIndex): pandas multiindex with <instrument, datetime> order
seq_len (int): sequence length
"""
assert isinstance(index, pd.MultiIndex), "unsupported index type"
assert seq_len > 0, "sequence length should be larger than 0"
assert index.is_monotonic_increasing, "index should be sorted"
# number of dates for each instrument
sample_count_by_insts = index.to_series().groupby(level=0).size().values
# start index for each instrument
start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1)
start_index_of_insts[0] = 0
# all the [start, stop) indices of features
# features between [start, stop) will be used to predict label at `stop - 1`
slices = []
for cur_loc, cur_cnt in zip(start_index_of_insts, sample_count_by_insts):
for stop in range(1, cur_cnt + 1):
end = cur_loc + stop
start = max(end - seq_len, 0)
slices.append(slice(start, end))
slices = np.array(slices, dtype="object")
assert len(slices) == len(index) # the i-th slice = index[i]
return slices
def _get_date_parse_fn(target):
"""get date parse function
This method is used to parse date arguments as target type.
Example:
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
else:
_fn = lambda x: x # '2021-01-01'
return _fn
def _maybe_padding(x, seq_len, zeros=None):
"""padding 2d <time * feature> data with zeros
Args:
x (np.ndarray): 2d data with shape <time * feature>
seq_len (int): target sequence length
zeros (np.ndarray): zeros with shape <seq_len * feature>
"""
assert seq_len > 0, "sequence length should be larger than 0"
if zeros is None:
zeros = np.zeros((seq_len, x.shape[1]), dtype=np.float32)
else:
assert len(zeros) >= seq_len, "zeros matrix is not large enough for padding"
if len(x) != seq_len: # padding zeros
x = np.concatenate([zeros[: seq_len - len(x), : x.shape[1]], x], axis=0)
return x
class MTSDatasetH(DatasetH):
"""Memory Augmented Time Series Dataset
Args:
handler (DataHandler): data handler
segments (dict): data split segments
seq_len (int): time series sequence length
horizon (int): label horizon
num_states (int): how many memory states to be added
memory_mode (str): memory mode (daily or sample)
batch_size (int): batch size (<0 will use daily sampling)
n_samples (int): number of samples in the same day
shuffle (bool): whether shuffle data
drop_last (bool): whether drop last batch < batch_size
input_size (int): reshape flatten rows as this input_size (backward compatibility)
"""
def __init__(
self,
handler,
segments,
seq_len=60,
horizon=0,
num_states=0,
memory_mode="sample",
batch_size=-1,
n_samples=None,
shuffle=True,
drop_last=False,
input_size=None,
**kwargs
):
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
assert memory_mode in ["sample", "daily"], "unsupported memory mode"
assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)"
assert batch_size != 0, "invalid batch size"
if batch_size > 0 and n_samples is not None:
warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)")
self.seq_len = seq_len
self.horizon = horizon
self.num_states = num_states
self.memory_mode = memory_mode
self.batch_size = batch_size
self.n_samples = n_samples
self.shuffle = shuffle
self.drop_last = drop_last
self.input_size = input_size
self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch
super().__init__(handler, segments, **kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
# pre-fetch data and change index to <code, date>
# NOTE: we will use inplace sort to reduce memory use
try:
df = self.handler._learn.copy() # use copy otherwise recorder will fail
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
except:
warnings.warn("cannot access `_learn`, will load raw data")
df = self.handler._data.copy()
df.index = df.index.swaplevel()
df.sort_index(inplace=True)
# convert to numpy
self._data = df["feature"].values.astype("float32")
np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor
self._label = df["label"].squeeze().values.astype("float32")
self._index = df.index
if self.input_size is not None and self.input_size != self._data.shape[1]:
warnings.warn("the data has different shape from input_size and the data will be reshaped")
assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`"
# create batch slices
self._batch_slices = _create_ts_slices(self._index, self.seq_len)
# create daily slices
daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date
for i, (code, date) in enumerate(self._index):
daily_slices[date].append(self._batch_slices[i])
self._daily_slices = np.array(list(daily_slices.values()), dtype="object")
self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index
# add memory (sample wise and daily)
if self.memory_mode == "sample":
self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32)
elif self.memory_mode == "daily":
self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32)
else:
raise ValueError(f"invalid memory_mode `{self.memory_mode}`")
# padding tensor
self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32)
def _prepare_seg(self, slc, **kwargs):
fn = _get_date_parse_fn(self._index[0][1])
start_date = fn(slc.start)
end_date = fn(slc.stop)
obj = copy.copy(self) # shallow copy
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
obj._data = self._data # reference (no copy)
obj._label = self._label
obj._index = self._index
obj._memory = self._memory
obj._zeros = self._zeros
# update index for this batch
date_index = self._index.get_level_values(1)
obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)]
mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date)
obj._daily_slices = self._daily_slices[mask]
obj._daily_index = self._daily_index[mask]
return obj
def restore_index(self, index):
return self._index[index]
def restore_daily_index(self, daily_index):
return pd.Index(self._daily_index.loc[daily_index])
def assign_data(self, index, vals):
if self.num_states == 0:
raise ValueError("cannot assign data as `num_states==0`")
if isinstance(vals, torch.Tensor):
vals = vals.detach().cpu().numpy()
# if isinstance(index, pd.Series):
# index = index.index # daily batch use Series to store index
self._memory[index] = vals
def clear_memory(self):
if self.num_states == 0:
raise ValueError("cannot clear memory as `num_states==0`")
self._memory[:] = 0
# TODO: better train/eval mode design
def train(self):
"""enable traning mode"""
self.batch_size, self.n_samples, self.drop_last, self.shuffle = self.params
def eval(self):
"""enable evaluation mode"""
self.batch_size = -1
self.n_samples = None
self.drop_last = False
self.shuffle = False
def _get_slices(self):
if self.batch_size < 0: # daily sampling
slices = self._daily_slices.copy()
batch_size = -1 * self.batch_size
else: # normal sampling
slices = self._batch_slices.copy()
batch_size = self.batch_size
return slices, batch_size
def __len__(self):
slices, batch_size = self._get_slices()
if self.drop_last:
return len(slices) // batch_size
return (len(slices) + batch_size - 1) // batch_size
def __iter__(self):
slices, batch_size = self._get_slices()
indices = np.arange(len(slices))
if self.shuffle:
np.random.shuffle(indices)
for i in range(len(indices))[::batch_size]:
if self.drop_last and i + batch_size > len(indices):
break
data = [] # store features
label = [] # store labels
index = [] # store index
state = [] # store memory states
daily_index = [] # store daily index
daily_count = [] # store number of samples for each day
for j in indices[i : i + batch_size]:
# normal sampling: self.batch_size > 0 => slices is a list => slices_subset is a slice
# daily sampling: self.batch_size < 0 => slices is a nested list => slices_subset is a list
slices_subset = slices[j]
# daily sampling
# each slices_subset contains a list of slices for multiple stocks
# NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0
if self.batch_size < 0:
# store daily index
idx = self._daily_index.index[j] # daily_index.index is the index of the original data
daily_index.append(idx)
# store daily memory if specified
# NOTE: daily memory always requires daily sampling (self.batch_size < 0)
if self.memory_mode == "daily":
slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0))
state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros))
# down-sample stocks and store count
if self.n_samples and 0 < self.n_samples < len(slices_subset): # intraday subsample
slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False)
daily_count.append(len(slices_subset))
# normal sampling
# each slices_subset is a single slice
# NOTE: normal sampling is used in train mode with self.batch_size > 0
else:
slices_subset = [slices_subset]
for slc in slices_subset:
# legacy support for Alpha360 data by `input_size`
if self.input_size:
data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T)
else:
data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros))
if self.memory_mode == "sample":
state.append(_maybe_padding(self._memory[slc][: -self.horizon], self.seq_len, self._zeros))
label.append(self._label[slc.stop - 1])
index.append(slc.stop - 1)
# end slices loop
# end indices batch loop
# concate
data = _to_tensor(np.stack(data))
state = _to_tensor(np.stack(state))
label = _to_tensor(np.stack(label))
index = np.array(index)
daily_index = np.array(daily_index)
daily_count = np.array(daily_count)
# yield -> generator
yield {
"data": data,
"label": label,
"state": state,
"index": index,
"daily_index": daily_index,
"daily_count": daily_count,
}
# end indice loop

View File

@@ -0,0 +1,820 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
except:
SummaryWriter = None
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.contrib.data.dataset import MTSDatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
"""
TRA Model
Args:
model_config (dict): model config (will be used by RNN or Transformer)
tra_config (dict): TRA config (will be used by TRA)
model_type (str): which backbone model to use (RNN/Transformer)
lr (float): learning rate
n_epochs (int): number of total epochs
early_stop (int): early stop when performance not improved at this step
smooth_steps (int): number of steps for parameter smoothing
max_steps_per_epoch (int): maximum number of steps in one epoch
lamb (float): regularization parameter
rho (float): exponential decay rate for `lamb`
seed (int): random seed
logdir (str): local log directory
eval_train (bool): whether evaluate train set between epochs
eval_test (bool): whether evaluate test set between epochs
pretrain (bool): whether pretrain the backbone model before training TRA.
Note that only TRA will be optimized after pretraining
init_state (str): model init state path
freeze_model (bool): whether freeze backbone model parameters
freeze_predictors (bool): whether freeze predictors parameters
transport_method (str): transport method, can be none/router/oracle
memory_mode (str): memory mode, the same argument for MTSDatasetH
"""
def __init__(
self,
model_config,
tra_config,
model_type="RNN",
lr=1e-3,
n_epochs=500,
early_stop=50,
smooth_steps=5,
max_steps_per_epoch=None,
lamb=0.0,
rho=0.99,
seed=0,
logdir=None,
eval_train=False,
eval_test=False,
pretrain=False,
init_state=None,
freeze_model=False,
freeze_predictors=False,
transport_method="none",
memory_mode="sample",
):
self.logger = get_module_logger("TRA")
assert memory_mode in ["sample", "daily"], "invalid memory mode"
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
assert transport_method == "none" or tra_config['num_states'] > 1, "optimal transport requires `num_states` > 1"
assert memory_mode != "daily" or tra_config['src_info'] == 'TPE', "daily transport can only support TPE as `src_info`"
if transport_method == "router" and not eval_train:
self.logger.warning("`eval_train` will be ignored when using TRA.router")
np.random.seed(seed)
torch.manual_seed(seed)
self.model_config = model_config
self.tra_config = tra_config
self.model_type = model_type
self.lr = lr
self.n_epochs = n_epochs
self.early_stop = early_stop
self.smooth_steps = smooth_steps
self.max_steps_per_epoch = max_steps_per_epoch
self.lamb = lamb
self.rho = rho
self.seed = seed
self.logdir = logdir
self.eval_train = eval_train
self.eval_test = eval_test
self.pretrain = pretrain
self.init_state = init_state
self.freeze_model = freeze_model
self.freeze_predictors = freeze_predictors
self.transport_method = transport_method
self.use_daily_transport = memory_mode == "daily"
self.transport_fn = transport_daily if self.use_daily_transport else transport_sample
self._writer = None
if self.logdir is not None:
if os.path.exists(self.logdir):
self.logger.warning(f"logdir {self.logdir} is not empty")
os.makedirs(self.logdir, exist_ok=True)
if SummaryWriter is not None:
self._writer = SummaryWriter(log_dir=self.logdir)
self._init_model()
def _init_model(self):
self.logger.info("init TRAModel...")
self.model = eval(self.model_type)(**self.model_config).to(device)
print(self.model)
self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
print(self.tra)
if self.init_state:
self.logger.warninging(f"load state dict from `init_state`")
state_dict = torch.load(self.init_state, map_location="cpu")
self.model.load_state_dict(state_dict["model"])
try:
self.tra.load_state_dict(state_dict["tra"])
except:
self.logger.warninging("cannot load tra model, will skip")
if self.freeze_model:
self.logger.warninging(f"freeze model parameters")
for param in self.model.parameters():
param.requires_grad_(False)
if self.freeze_predictors:
self.logger.warninging(f"freeze TRA.predictors parameters")
for param in self.tra.predictors.parameters():
param.requires_grad_(False)
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
self.fitted = False
self.global_step = -1
def train_epoch(self, epoch, data_set, is_pretrain=False):
self.model.train()
self.tra.train()
data_set.train()
max_steps = len(data_set)
if self.max_steps_per_epoch is not None:
if epoch == 0 and self.max_steps_per_epoch < max_steps:
self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
max_steps = min(self.max_steps_per_epoch, max_steps)
cur_step = 0
total_loss = 0
total_count = 0
for batch in tqdm(data_set, total=max_steps):
cur_step += 1
if cur_step > max_steps:
break
self.global_step += 1
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if not is_pretrain and self.transport_method != "none":
loss, pred, L, P = self.transport_fn(
all_preds, label, choice, prob, count, self.transport_method, training=True
)
data_set.assign_data(index, L) # save loss to memory
lamb = self.lamb * (self.rho ** self.global_step) # regularization decay
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
if self._writer is not None:
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
self._writer.add_scalar("training/lamb", lamb, self.global_step)
prob_mean = prob.mean(axis=0).detach()
self._writer.add_scalar("training/prob_max", prob_mean.max(), self.global_step)
self._writer.add_scalar("training/prob_min", prob_mean.min(), self.global_step)
P_mean = P.mean(axis=0).detach()
self._writer.add_scalar("training/P_max", P_mean.max(), self.global_step)
self._writer.add_scalar("training/P_min", P_mean.min(), self.global_step)
loss = loss - lamb * reg
else:
pred = all_preds.mean(dim=1)
loss = loss_fn(pred, label)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
if self._writer is not None:
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
total_loss += loss.item()
total_count += 1
total_loss /= total_count
if self._writer is not None:
self._writer.add_scalar("training/loss", total_loss, epoch)
return total_loss
def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
self.model.eval()
self.tra.eval()
data_set.eval()
preds = []
probs = []
metrics = []
for batch in tqdm(data_set):
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
with torch.no_grad():
hidden = self.model(data)
all_preds, choice, prob = self.tra(hidden, state)
if not is_pretrain and self.transport_method != "none":
loss, pred, L, P = self.transport_fn(all_preds, label, choice, prob, count, self.transport_method, training=False)
data_set.assign_data(index, L) # save loss to memory
else:
pred = all_preds.mean(dim=1)
X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
pred = pd.DataFrame(X, index=batch["index"], columns=columns)
metrics.append(evaluate(pred))
if return_pred:
preds.append(pred)
if prob is not None:
columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))
metrics = pd.DataFrame(metrics)
metrics = {
"MSE": metrics.MSE.mean(),
"MAE": metrics.MAE.mean(),
"IC": metrics.IC.mean(),
"ICIR": metrics.IC.mean() / metrics.IC.std(),
}
if self._writer is not None and epoch >= 0:
for key, value in metrics.items():
self._writer.add_scalar(prefix + "/" + key, value, epoch)
if return_pred:
preds = pd.concat(preds, axis=0)
preds.index = data_set.restore_index(preds.index)
preds.index = preds.index.swaplevel()
preds.sort_index(inplace=True)
if probs:
probs = pd.concat(probs, axis=0)
if self.use_daily_transport:
probs.index = data_set.restore_daily_index(probs.index)
else:
probs.index = data_set.restore_index(probs.index)
probs.index = probs.index.swaplevel()
probs.sort_index(inplace=True)
return metrics, preds, probs
def _fit(self, train_set, valid_set, test_set, evals_result, start_epoch=0, is_pretrain=True):
best_score = -1
best_epoch = 0
stop_rounds = 0
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
params_list = {
"model": collections.deque(maxlen=self.smooth_steps),
"tra": collections.deque(maxlen=self.smooth_steps),
}
# train
if not is_pretrain and self.transport_method == "router":
self.logger.info("init memory...")
self.test_epoch(-1, train_set)
for epoch in range(start_epoch, start_epoch + self.n_epochs):
self.logger.info("Epoch %d:", epoch)
self.logger.info("training...")
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
self.logger.info("evaluating...")
# average params for inference
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
self.model.load_state_dict(average_params(params_list["model"]))
self.tra.load_state_dict(average_params(params_list["tra"]))
# NOTE: during evaluating, the whole memory will be refreshed
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
train_set.clear_memory() # NOTE: clear the shared memory
train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
evals_result["train"].append(train_metrics)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
evals_result["valid"].append(valid_metrics)
self.logger.info("valid metrics: %s" % valid_metrics)
if self.eval_test:
test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
evals_result["test"].append(test_metrics)
self.logger.info("test metrics: %s" % test_metrics)
if valid_metrics["IC"] > best_score:
best_score = valid_metrics["IC"]
stop_rounds = 0
best_epoch = epoch
best_params = {
"model": copy.deepcopy(self.model.state_dict()),
"tra": copy.deepcopy(self.tra.state_dict()),
}
torch.save(best_params, self.logdir + "/model.bin")
else:
stop_rounds += 1
if stop_rounds >= self.early_stop:
self.logger.info("early stop @ %s" % epoch)
break
# restore parameters
self.model.load_state_dict(params_list["model"][-1])
self.tra.load_state_dict(params_list["tra"][-1])
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.model.load_state_dict(best_params["model"])
self.tra.load_state_dict(best_params["tra"])
return best_score, epoch
def fit(self, dataset, evals_result=dict()):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])
self.fitted = True
self.global_step = -1
evals_result["train"] = []
evals_result["valid"] = []
evals_result["test"] = []
epoch = 0
if self.pretrain:
self.logger.info("pretraining...")
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
_, epoch = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
self.logger.info("reset TRA")
self.tra.reset_parameters() # reset both router and predictors
self.optimizer = optim.Adam(self.tra.parameters(), lr=self.lr) # optimize TRA only
self.logger.info("training...")
best_score, _ = self._fit(train_set, valid_set, test_set, evals_result, start_epoch=epoch, is_pretrain=False)
self.logger.info("inference")
train_metrics, train_preds, train_probs = self.test_epoch(-1, train_set, return_pred=True)
self.logger.info("train metrics: %s" % train_metrics)
valid_metrics, valid_preds, valid_probs = self.test_epoch(-1, valid_set, return_pred=True)
self.logger.info("valid metrics: %s" % valid_metrics)
test_metrics, test_preds, test_probs = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % test_metrics)
if self.logdir:
self.logger.info("save model & pred to local directory")
pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
self.logdir + "/logs.csv", index=False
)
torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin")
train_preds.to_pickle(self.logdir + "/train_pred.pkl")
valid_preds.to_pickle(self.logdir + "/valid_pred.pkl")
test_preds.to_pickle(self.logdir + "/test_pred.pkl")
if len(train_probs):
train_probs.to_pickle(self.logdir + "/train_prob.pkl")
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
info = {
"config": {
"model_config": self.model_config,
"tra_config": self.tra_config,
"model_type": self.model_type,
"lr": self.lr,
"n_epochs": self.n_epochs,
"early_stop": self.early_stop,
"smooth_steps": self.smooth_steps,
"max_steps_per_epoch": self.max_steps_per_epoch,
"lamb": self.lamb,
"rho": self.rho,
"seed": self.seed,
"logdir": self.logdir,
"pretrain": self.pretrain,
"init_state": self.init_state,
"transport_method": self.transport_method,
"use_daily_transport": self.use_daily_transport,
},
"best_eval_metric": -best_score, # NOTE: -1 for minimize
"metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
}
with open(self.logdir + "/info.json", "w") as f:
json.dump(info, f)
def predict(self, dataset, segment="test"):
assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"
if not self.fitted:
raise ValueError("model is not fitted yet!")
test_set = dataset.prepare(segment)
metrics, preds, probs = self.test_epoch(-1, test_set, return_pred=True)
self.logger.info("test metrics: %s" % metrics)
return preds
class RNN(nn.Module):
"""RNN Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of hidden layers
rnn_arch (str): rnn architecture
use_attn (bool): whether use attention layer.
we use concat attention as https://github.com/fulifeng/Adv-AGRU/
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
rnn_arch="GRU",
use_attn=True,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn_arch = rnn_arch
self.use_attn = use_attn
self.input_proj = nn.Linear(input_size, hidden_size)
self.rnn = getattr(nn, rnn_arch)(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
)
if self.use_attn:
self.W = nn.Linear(hidden_size, hidden_size)
self.u = nn.Linear(hidden_size, 1, bias=False)
self.output_size = hidden_size * 2
else:
self.output_size = hidden_size
def forward(self, x):
x = self.input_proj(x)
rnn_out, last_out = self.rnn(x)
if self.rnn_arch == "LSTM":
last_out = last_out[0]
last_out = last_out.mean(dim=0)
if self.use_attn:
laten = self.W(rnn_out).tanh()
scores = self.u(laten).softmax(dim=1)
att_out = (rnn_out * scores).sum(dim=1).squeeze()
last_out = torch.cat([last_out, att_out], dim=1)
return last_out
class PositionalEncoding(nn.Module):
# reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0), :]
return self.dropout(x)
class Transformer(nn.Module):
"""Transformer Model
Args:
input_size (int): input size (# features)
hidden_size (int): hidden size
num_layers (int): number of transformer layers
num_heads (int): number of heads in transformer
dropout (float): dropout rate
"""
def __init__(
self,
input_size=16,
hidden_size=64,
num_layers=2,
num_heads=2,
dropout=0.0,
**kwargs,
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.num_heads = num_heads
self.input_proj = nn.Linear(input_size, hidden_size)
self.pe = PositionalEncoding(input_size, dropout)
layer = nn.TransformerEncoderLayer(
nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4
)
self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
self.output_size = hidden_size
def forward(self, x):
x = x.permute(1, 0, 2).contiguous() # the first dim need to be time
x = self.pe(x)
x = self.input_proj(x)
out = self.encoder(x)
return out[-1]
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction erros & latent representation as inputs,
then routes the input sample to a specific predictor for training & inference.
Args:
input_size (int): input size (RNN/Transformer's hidden size)
num_states (int): number of latent states (i.e., trading patterns)
If `num_states=1`, then TRA falls back to traditional methods
hidden_size (int): hidden size of the router
tau (float): gumbel softmax temperature
src_info (str): information for the router
"""
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
super().__init__()
assert src_info in ['LR', 'TPE', 'LR_TPE'], 'invalid `src_info`'
self.num_states = num_states
self.tau = tau
self.src_info = src_info
self.predictors = nn.Linear(input_size, num_states)
if self.num_states > 1:
if "TPE" in src_info:
self.router = nn.GRU(
input_size=num_states,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
)
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
else:
self.fc = nn.Linear(input_size, num_states)
def reset_parameters(self):
for child in self.children():
child.reset_parameters()
def forward(self, hidden, hist_loss):
preds = self.predictors(hidden)
if self.num_states == 1: # no need for router when having only one prediction
return preds.squeeze(-1), preds, None
if "TPE" in self.src_info:
out = self.router(hist_loss)[0][:, -1] # TPE
if "LR" in self.src_info:
out = torch.cat([hidden, out], dim=-1) # LR_TPE
else:
out = hidden # LR
out = self.fc(out)
choice = F.gumbel_softmax(out, dim=-1, tau=self.tau, hard=True)
prob = torch.softmax(out / self.tau, dim=-1)
return preds, choice, prob
def evaluate(pred):
pred = pred.rank(pct=True) # transform into percentiles
score = pred.score
label = pred.label
diff = score - label
MSE = (diff ** 2).mean()
MAE = (diff.abs()).mean()
IC = score.corr(label, method="spearman")
return {"MSE": MSE, "MAE": MAE, "IC": IC}
def average_params(params_list):
assert isinstance(params_list, (tuple, list, collections.deque))
n = len(params_list)
if n == 1:
return params_list[0]
new_params = collections.OrderedDict()
keys = None
for i, params in enumerate(params_list):
if keys is None:
keys = params.keys()
for k, v in params.items():
if k not in keys:
raise ValueError("the %d-th model has different params" % i)
if k not in new_params:
new_params[k] = v / n
else:
new_params[k] += v / n
return new_params
def shoot_infs(inp_tensor):
"""Replaces inf by maximum of tensor"""
mask_inf = torch.isinf(inp_tensor)
ind_inf = torch.nonzero(mask_inf, as_tuple=False)
if len(ind_inf) > 0:
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = 0
elif len(ind) == 1:
inp_tensor[ind[0]] = 0
m = torch.max(inp_tensor)
for ind in ind_inf:
if len(ind) == 2:
inp_tensor[ind[0], ind[1]] = m
elif len(ind) == 1:
inp_tensor[ind[0]] = m
return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.01):
# epsilon should be adjusted according to logits value's scale
with torch.no_grad():
Q = torch.exp(Q / epsilon)
Q = shoot_infs(Q)
for i in range(n_iters):
Q /= Q.sum(dim=0, keepdim=True)
Q /= Q.sum(dim=1, keepdim=True)
return Q
def loss_fn(pred, label):
mask = ~torch.isnan(label)
if len(pred.shape) == 2:
label = label[:, None]
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
def transport_sample(all_preds, label, choice, prob, count, transport_method, training=False):
"""
sample-wise transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [sample x states]
prob (torch.Tensor): router predicted probility, [sample x states]
count (list): sample counts for each day, empty list for sample-wise transport
transport_method (str): transportation method
training (bool): indicate training or inference
"""
assert all_preds.shape == choice.shape
assert len(all_preds) == len(label)
assert transport_method in ["oracle", "router"]
all_loss = (all_preds - label[:, None]).pow(2) # [sample x states]
all_loss[torch.isnan(label)] = 0.0
if transport_method == "router":
if training: # router training
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
else: # router inference
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
elif not training: # oracle inference: always choose the model with the smallest loss
pred = all_preds[range(len(all_preds)), all_loss.argmin(dim=-1)]
else: # oracle training: pred is not needed
pred = None
L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize
P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training
if pred is not None: # router training/inference & oracle inference loss
loss = loss_fn(pred, label)
else: # oracle training loss
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P
def transport_daily(all_preds, label, choice, prob, count, transport_method, training=False):
"""
daily transport
Args:
all_preds (torch.Tensor): predictions from all predictors, [sample x states]
label (torch.Tensor): label, [sample]
choice (torch.Tensor): gumbel softmax choice, [days x states]
prob (torch.Tensor): router predicted probility, [days x states]
count (list): sample counts for each day, [days]
transport_method (str): transportation method
training (bool): indicate training or inference
"""
assert len(prob) == len(count)
assert len(all_preds) == sum(count)
assert transport_method in ["oracle", "router"]
all_loss = [] # loss of all predictions
pred = [] # final predictions
start = 0
for i, cnt in enumerate(count):
slc = slice(start, start + cnt) # samples from the i-th day
start += cnt
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
all_loss.append(tloss)
if transport_method == "router":
if training: # router training
tpred = all_preds[slc] @ choice[i] # gumbel softmax
else: # router inference
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
elif not training: # oracle inference: always choose the model with the smallest loss
tpred = all_preds[slc][:, tloss.argmin(dim=-1)]
else: # oracle training: pred is not needed
tpred = None
if tpred is not None:
pred.append(tpred)
all_loss = torch.stack(all_loss, dim=0) # [days x states]
if pred:
pred = torch.cat(pred, dim=0) # [samples]
L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize
P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training
if len(pred): # router training/inference & oracle inference loss
loss = loss_fn(pred, label)
else: # oracle training loss
loss = (all_loss * P).sum(dim=1).mean()
return loss, pred, L, P

View File

@@ -199,6 +199,10 @@ class StaticDataLoader(DataLoader):
self.join = join
self._data = None
def __getstate__(self) -> dict:
# avoid pickling `self._data`
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
self._maybe_load_raw_data()
if instruments is None: