mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add A New Baseline: ADD (#704)
This commit is contained in:
@@ -298,6 +298,7 @@ Here is a list of models built on `Qlib`.
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
3
examples/benchmarks/ADD/README.md
Normal file
3
examples/benchmarks/ADD/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# AdaRNN
|
||||
* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).
|
||||
|
||||
4
examples/benchmarks/ADD/requirements.txt
Normal file
4
examples/benchmarks/ADD/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
94
examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml
Normal file
94
examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml
Normal file
@@ -0,0 +1,94 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ADD
|
||||
module_path: qlib.contrib.model.pytorch_add
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.1
|
||||
dec_dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 5000
|
||||
metric: ic
|
||||
base_model: GRU
|
||||
gamma: 0.1
|
||||
gamma_clip: 0.2
|
||||
optimizer: adam
|
||||
mu: 0.2
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -56,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
|
||||
| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
|
||||
| AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
|
||||
@@ -31,8 +31,9 @@ try:
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
from .pytorch_tcn import TCN
|
||||
from .pytorch_add import ADD
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN)
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
598
qlib/contrib/model/pytorch_add.py
Normal file
598
qlib/contrib/model/pytorch_add.py
Normal file
@@ -0,0 +1,598 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from qlib.contrib.model.pytorch_gru import GRUModel
|
||||
from qlib.contrib.model.pytorch_lstm import LSTMModel
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import CSRankNorm
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class ADD(Model):
|
||||
"""ADD Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
dec_dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="mse",
|
||||
batch_size=5000,
|
||||
early_stop=20,
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
gamma=0.1,
|
||||
gamma_clip=0.4,
|
||||
mu=0.05,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADD")
|
||||
self.logger.info("ADD pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.dec_dropout = dec_dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.gamma = gamma
|
||||
self.gamma_clip = gamma_clip
|
||||
self.mu = mu
|
||||
|
||||
self.logger.info(
|
||||
"ADD parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\ndec_dropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\ngamma : {}"
|
||||
"\ngamma_clip : {}"
|
||||
"\nmu : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
dec_dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
base_model,
|
||||
model_path,
|
||||
gamma,
|
||||
gamma_clip,
|
||||
mu,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.ADD_model = ADDModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
dec_dropout=self.dec_dropout,
|
||||
base_model=self.base_model,
|
||||
gamma=self.gamma,
|
||||
gamma_clip=self.gamma_clip,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ADD_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ADD_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ADD_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.ADD_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.ADD_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def loss_pre_excess(self, pred_excess, label_excess, record=None):
|
||||
mask = ~torch.isnan(label_excess)
|
||||
pre_excess_loss = F.mse_loss(pred_excess[mask], label_excess[mask])
|
||||
if record is not None:
|
||||
record["pre_excess_loss"] = pre_excess_loss.item()
|
||||
return pre_excess_loss
|
||||
|
||||
def loss_pre_market(self, pred_market, label_market, record=None):
|
||||
pre_market_loss = F.cross_entropy(pred_market, label_market)
|
||||
if record is not None:
|
||||
record["pre_market_loss"] = pre_market_loss.item()
|
||||
return pre_market_loss
|
||||
|
||||
def loss_pre(self, pred_excess, label_excess, pred_market, label_market, record=None):
|
||||
pre_loss = self.loss_pre_excess(pred_excess, label_excess, record) + self.loss_pre_market(
|
||||
pred_market, label_market, record
|
||||
)
|
||||
if record is not None:
|
||||
record["pre_loss"] = pre_loss.item()
|
||||
return pre_loss
|
||||
|
||||
def loss_adv_excess(self, adv_excess, label_excess, record=None):
|
||||
mask = ~torch.isnan(label_excess)
|
||||
adv_excess_loss = F.mse_loss(adv_excess.squeeze()[mask], label_excess[mask])
|
||||
if record is not None:
|
||||
record["adv_excess_loss"] = adv_excess_loss.item()
|
||||
return adv_excess_loss
|
||||
|
||||
def loss_adv_market(self, adv_market, label_market, record=None):
|
||||
adv_market_loss = F.cross_entropy(adv_market, label_market)
|
||||
if record is not None:
|
||||
record["adv_market_loss"] = adv_market_loss.item()
|
||||
return adv_market_loss
|
||||
|
||||
def loss_adv(self, adv_excess, label_excess, adv_market, label_market, record=None):
|
||||
adv_loss = self.loss_adv_excess(adv_excess, label_excess, record) + self.loss_adv_market(
|
||||
adv_market, label_market, record
|
||||
)
|
||||
if record is not None:
|
||||
record["adv_loss"] = adv_loss.item()
|
||||
return adv_loss
|
||||
|
||||
def loss_fn(self, x, preds, label_excess, label_market, record=None):
|
||||
loss = (
|
||||
self.loss_pre(preds["excess"], label_excess, preds["market"], label_market, record)
|
||||
+ self.loss_adv(preds["adv_excess"], label_excess, preds["adv_market"], label_market, record)
|
||||
+ self.mu * self.loss_rec(x, preds["reconstructed_feature"], record)
|
||||
)
|
||||
if record is not None:
|
||||
record["loss"] = loss.item()
|
||||
return loss
|
||||
|
||||
def loss_rec(self, x, rec_x, record=None):
|
||||
x = x.reshape(len(x), self.d_feat, -1)
|
||||
x = x.permute(0, 2, 1)
|
||||
rec_loss = F.mse_loss(x, rec_x)
|
||||
if record is not None:
|
||||
record["rec_loss"] = rec_loss.item()
|
||||
return rec_loss
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def cal_ic_metrics(self, pred, label):
|
||||
metrics = {}
|
||||
metrics["mse"] = -F.mse_loss(pred, label).item()
|
||||
metrics["loss"] = metrics["mse"]
|
||||
pred = pd.Series(pred.cpu().detach().numpy())
|
||||
label = pd.Series(label.cpu().detach().numpy())
|
||||
metrics["ic"] = pred.corr(label)
|
||||
metrics["ric"] = pred.corr(label, method="spearman")
|
||||
return metrics
|
||||
|
||||
def test_epoch(self, data_x, data_y, data_m):
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
m_values = np.squeeze(data_m.values.astype(int))
|
||||
self.ADD_model.eval()
|
||||
|
||||
metrics_list = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
label_excess = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
label_market = torch.from_numpy(m_values[batch]).long().to(self.device)
|
||||
|
||||
metrics = {}
|
||||
preds = self.ADD_model(feature)
|
||||
self.loss_fn(feature, preds, label_excess, label_market, metrics)
|
||||
metrics.update(self.cal_ic_metrics(preds["excess"], label_excess))
|
||||
metrics_list.append(metrics)
|
||||
metrics = {}
|
||||
keys = metrics_list[0].keys()
|
||||
for k in keys:
|
||||
vs = [m[k] for m in metrics_list]
|
||||
metrics[k] = sum(vs) / len(vs)
|
||||
|
||||
return metrics
|
||||
|
||||
def train_epoch(self, x_train_values, y_train_values, m_train_values):
|
||||
self.ADD_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
cur_step = 1
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
batch = indices[i : i + self.batch_size]
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
label_excess = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
label_market = torch.from_numpy(m_train_values[batch]).long().to(self.device)
|
||||
|
||||
preds = self.ADD_model(feature)
|
||||
|
||||
loss = self.loss_fn(feature, preds, label_excess, label_market)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.ADD_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
cur_step += 1
|
||||
|
||||
def log_metrics(self, mode, metrics):
|
||||
metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
|
||||
metrics = ", ".join(metrics)
|
||||
self.logger.info(metrics)
|
||||
|
||||
def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid):
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
m_train_values = np.squeeze(m_train.values.astype(int))
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train_values, y_train_values, m_train_values)
|
||||
self.logger.info("evaluating...")
|
||||
train_metrics = self.test_epoch(x_train, y_train, m_train)
|
||||
valid_metrics = self.test_epoch(x_valid, y_valid, m_valid)
|
||||
self.log_metrics("train", train_metrics)
|
||||
self.log_metrics("valid", valid_metrics)
|
||||
|
||||
if self.metric in valid_metrics:
|
||||
val_score = valid_metrics[self.metric]
|
||||
else:
|
||||
raise ValueError("unknown metric name `%s`" % self.metric)
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.ADD_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.ADD_model.before_adv_excess.step_alpha()
|
||||
self.ADD_model.before_adv_market.step_alpha()
|
||||
self.logger.info("bootstrap_fit best score: {:.6f} @ {}".format(best_score, best_epoch))
|
||||
self.ADD_model.load_state_dict(best_param)
|
||||
return best_score
|
||||
|
||||
def gen_market_label(self, df, raw_label):
|
||||
market_label = raw_label.groupby("datetime").mean().squeeze()
|
||||
bins = [-np.inf, self.lo, self.hi, np.inf]
|
||||
market_label = pd.cut(market_label, bins, labels=False)
|
||||
market_label.name = ("market_return", "market_return")
|
||||
df = df.join(market_label)
|
||||
return df
|
||||
|
||||
def fit_thresh(self, train_label):
|
||||
market_label = train_label.groupby("datetime").mean().squeeze()
|
||||
self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3])
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
label_train, label_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["label"],
|
||||
data_key=DataHandlerLP.DK_R,
|
||||
)
|
||||
self.fit_thresh(label_train)
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
df_train = self.gen_market_label(df_train, label_train)
|
||||
df_valid = self.gen_market_label(df_valid, label_valid)
|
||||
|
||||
x_train, y_train, m_train = df_train["feature"], df_train["label"], df_train["market_return"]
|
||||
x_valid, y_valid, m_valid = df_valid["feature"], df_valid["label"], df_valid["market_return"]
|
||||
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
# load pretrained base_model
|
||||
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.ADD_model.enc_excess.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.ADD_model.enc_excess.load_state_dict(model_dict)
|
||||
model_dict = self.ADD_model.enc_market.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.ADD_model.enc_market.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
self.bootstrap_fit(x_train, y_train, m_train, x_valid, y_valid, m_valid)
|
||||
|
||||
best_param = copy.deepcopy(self.ADD_model.state_dict())
|
||||
save_path = get_or_create_path(save_path)
|
||||
torch.save(best_param, save_path)
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ADD_model.eval()
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.ADD_model(x_batch)
|
||||
pred = pred["excess"].detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
r = pd.Series(np.concatenate(preds), index=index)
|
||||
return r
|
||||
|
||||
|
||||
class ADDModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
dec_dropout=0.5,
|
||||
base_model="GRU",
|
||||
gamma=0.1,
|
||||
gamma_clip=0.4,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_feat = d_feat
|
||||
self.base_model = base_model
|
||||
if base_model == "GRU":
|
||||
self.enc_excess, self.enc_market = [
|
||||
nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
elif base_model == "LSTM":
|
||||
self.enc_excess, self.enc_market = [
|
||||
nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
self.dec = Decoder(d_feat, 2 * hidden_size, num_layers, dec_dropout, base_model)
|
||||
|
||||
ctx_size = hidden_size * num_layers
|
||||
self.pred_excess, self.adv_excess = [
|
||||
nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 1))
|
||||
for _ in range(2)
|
||||
]
|
||||
self.adv_market, self.pred_market = [
|
||||
nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 3))
|
||||
for _ in range(2)
|
||||
]
|
||||
self.before_adv_market, self.before_adv_excess = [RevGrad(gamma, gamma_clip) for _ in range(2)]
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(len(x), self.d_feat, -1)
|
||||
N = x.shape[0]
|
||||
T = x.shape[-1]
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
out, hidden_excess = self.enc_excess(x)
|
||||
out, hidden_market = self.enc_market(x)
|
||||
if self.base_model == "LSTM":
|
||||
feature_excess = hidden_excess[0].permute(1, 0, 2).reshape(N, -1)
|
||||
feature_market = hidden_market[0].permute(1, 0, 2).reshape(N, -1)
|
||||
else:
|
||||
feature_excess = hidden_excess.permute(1, 0, 2).reshape(N, -1)
|
||||
feature_market = hidden_market.permute(1, 0, 2).reshape(N, -1)
|
||||
predicts = {}
|
||||
predicts["excess"] = self.pred_excess(feature_excess).squeeze(1)
|
||||
predicts["market"] = self.pred_market(feature_market)
|
||||
predicts["adv_market"] = self.adv_market(self.before_adv_market(feature_excess))
|
||||
predicts["adv_excess"] = self.adv_excess(self.before_adv_excess(feature_market).squeeze(1))
|
||||
if self.base_model == "LSTM":
|
||||
hidden = [torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2)]
|
||||
else:
|
||||
hidden = torch.cat([hidden_excess, hidden_market], -1)
|
||||
x = torch.zeros_like(x[:, 1, :])
|
||||
reconstructed_feature = []
|
||||
for i in range(T):
|
||||
x, hidden = self.dec(x, hidden)
|
||||
reconstructed_feature.append(x)
|
||||
reconstructed_feature = torch.stack(reconstructed_feature, 1)
|
||||
predicts["reconstructed_feature"] = reconstructed_feature
|
||||
return predicts
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model="GRU"):
|
||||
super().__init__()
|
||||
self.base_model = base_model
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.fc = nn.Linear(hidden_size, d_feat)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
x = x.unsqueeze(1)
|
||||
output, hidden = self.rnn(x, hidden)
|
||||
output = output.squeeze(1)
|
||||
pred = self.fc(output)
|
||||
return pred, hidden
|
||||
|
||||
|
||||
class RevGradFunc(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, alpha_):
|
||||
ctx.save_for_backward(input_, alpha_)
|
||||
output = input_
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output): # pragma: no cover
|
||||
grad_input = None
|
||||
_, alpha_ = ctx.saved_tensors
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = -grad_output * alpha_
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class RevGrad(nn.Module):
|
||||
def __init__(self, gamma=0.1, gamma_clip=0.4, *args, **kwargs):
|
||||
"""
|
||||
A gradient reversal layer.
|
||||
This layer has no parameters, and simply reverses the gradient
|
||||
in the backward pass.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.gamma = gamma
|
||||
self.gamma_clip = torch.tensor(float(gamma_clip), requires_grad=False)
|
||||
self._alpha = torch.tensor(0, requires_grad=False)
|
||||
self._p = 0
|
||||
|
||||
def step_alpha(self):
|
||||
self._p += 1
|
||||
self._alpha = min(
|
||||
self.gamma_clip, torch.tensor(2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False)
|
||||
)
|
||||
|
||||
def forward(self, input_):
|
||||
return RevGradFunc.apply(input_, self._alpha)
|
||||
Reference in New Issue
Block a user