mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* update python version * fix: Correct selector handling and add time filtering in storage.py * fix: convert index and columns to list in repr methods * feat: Add Makefile for managing project prerequisites * feat: Add Cython extensions for rolling and expanding operations * resolve install error * fix lint error * fix lint error * fix lint error * fix lint error * fix lint error * update build package * update makefile * update ci yaml * fix docs build error * fix ubuntu install error * fix docs build error * fix install error * fix install error * fix install error * fix install error * fix pylint error * fix pylint error * fix pylint error * fix pylint error * fix pylint error E1123 * fix pylint error R0917 * fix pytest error * fix pytest error * fix pytest error * update code * update code * fix ci error * fix pylint error * fix black error * fix pytest error * fix CI error * fix CI error * add python version to CI * add python version to CI * add python version to CI * fix pylint error * fix pytest general nn error * fix CI error * optimize code * add coments * Extended macos version * remove build package --------- Co-authored-by: Young <afe.young@gmail.com>
789 lines
27 KiB
Python
789 lines
27 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
import os
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
import copy
|
|
from typing import Text, Union
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.autograd import Function
|
|
from qlib.contrib.model.pytorch_utils import count_parameters
|
|
from qlib.data.dataset import DatasetH
|
|
from qlib.data.dataset.handler import DataHandlerLP
|
|
from qlib.log import get_module_logger
|
|
from qlib.model.base import Model
|
|
from qlib.utils import get_or_create_path
|
|
|
|
|
|
class ADARNN(Model):
    """ADARNN Model

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    hidden_size : int
        hidden size of every GRU layer
    num_layers : int
        number of stacked GRU layers
    dropout : float
        dropout value forwarded to the GRU layers
    n_epochs : int
        maximum number of training epochs
    pre_epoch : int
        number of initial epochs trained with ``forward_pre_train`` before
        switching to ``forward_Boosting``
    dw : float
        weight of the transfer loss in the total training loss
    loss_type : str
        transfer loss name, forwarded to the network as ``trans_loss``
    len_seq : int
        number of time steps per sample
    len_win : int
        window size forwarded to ``forward_pre_train``
    lr : float
        learning rate
    metric: str
        the evaluation metric used in early stop (a key of ``calc_all_metrics``)
    batch_size : int
        batch size for training and inference
    early_stop : int
        number of non-improving epochs tolerated before stopping
    loss : str
        stored and logged only; the regression loss used in training is MSE
    optimizer : str
        optimizer name ("adam" or "gd")
    n_splits : int
        number of consecutive day ranges the training data is split into
    GPU : int
        the GPU ID used for training; a negative value selects the CPU
    seed : int or None
        random seed for numpy and torch
    """

    def __init__(
        self,
        d_feat=6,
        hidden_size=64,
        num_layers=2,
        dropout=0.0,
        n_epochs=200,
        pre_epoch=40,
        dw=0.5,
        loss_type="cosine",
        len_seq=60,
        len_win=0,
        lr=0.001,
        metric="mse",
        batch_size=2000,
        early_stop=20,
        loss="mse",
        optimizer="adam",
        n_splits=2,
        GPU=0,
        seed=None,
        **_,
    ):
        # Set logger.
        self.logger = get_module_logger("ADARNN")
        self.logger.info("ADARNN pytorch version...")
        # NOTE(review): restricting the visible devices and then indexing with the
        # same id ("cuda:%d" % GPU) looks inconsistent for GPU > 0 -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)

        # set hyper-parameters.
        self.d_feat = d_feat
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.n_epochs = n_epochs
        self.pre_epoch = pre_epoch
        self.dw = dw
        self.loss_type = loss_type
        self.len_seq = len_seq
        self.len_win = len_win
        self.lr = lr
        self.metric = metric
        self.batch_size = batch_size
        self.early_stop = early_stop
        self.optimizer = optimizer.lower()
        self.loss = loss
        self.n_splits = n_splits
        self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
        self.seed = seed

        self.logger.info(
            "ADARNN parameters setting:"
            "\nd_feat : {}"
            "\nhidden_size : {}"
            "\nnum_layers : {}"
            "\ndropout : {}"
            "\nn_epochs : {}"
            "\nlr : {}"
            "\nmetric : {}"
            "\nbatch_size : {}"
            "\nearly_stop : {}"
            "\noptimizer : {}"
            "\nloss_type : {}"
            "\nvisible_GPU : {}"
            "\nuse_GPU : {}"
            "\nseed : {}".format(
                d_feat,
                hidden_size,
                num_layers,
                dropout,
                n_epochs,
                lr,
                metric,
                batch_size,
                early_stop,
                optimizer.lower(),
                loss,
                GPU,
                self.use_gpu,
                seed,
            )
        )

        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        n_hiddens = [hidden_size for _ in range(num_layers)]
        self.model = AdaRNN(
            use_bottleneck=False,
            bottleneck_width=64,
            n_input=d_feat,
            n_hiddens=n_hiddens,
            n_output=1,
            dropout=dropout,
            model_type="AdaRNN",
            len_seq=len_seq,
            trans_loss=loss_type,
            # Keep the network's own device choice in sync with this wrapper.
            GPU=GPU,
        )
        self.logger.info("model:\n{:}".format(self.model))
        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.model)))

        if optimizer.lower() == "adam":
            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        elif optimizer.lower() == "gd":
            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

        self.fitted = False
        self.model.to(self.device)

    @property
    def use_gpu(self):
        """Whether the selected device is anything other than the CPU."""
        return self.device != torch.device("cpu")

    def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
        """Train the network for one epoch over all pairs of data splits.

        Parameters
        ----------
        train_loader_list : list
            one loader per split; they are iterated in lockstep with ``zip``
        epoch : int
            current epoch; epochs below ``pre_epoch`` use ``forward_pre_train``,
            later ones use ``forward_Boosting``
        dist_old : torch.Tensor or None
            distance matrix returned by the previous boosting epoch
        weight_mat : torch.Tensor or None
            per-layer / per-step weights of the transfer loss

        Returns
        -------
        tuple
            ``(weight_mat, dist_mat)``; ``dist_mat`` is ``None`` during pre-training
        """
        self.model.train()
        criterion = nn.MSELoss()
        dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device)
        out_weight_list = None
        for data_all in zip(*train_loader_list):
            list_feat = []
            list_label = []
            for data in data_all:
                list_feat.append(data[0].to(self.device).float())
                list_label.append(data[1].to(self.device).float())

            index = get_index(len(data_all) - 1)
            # Source and target halves must have the same batch size; skip
            # lockstep batches (typically the last ones) where they differ.
            if any(list_feat[s1].shape[0] != list_feat[s2].shape[0] for s1, s2 in index):
                continue

            total_loss = torch.zeros(1).to(self.device)
            for s1, s2 in index:
                feature_s = list_feat[s1]
                feature_t = list_feat[s2]
                label_reg_s = list_label[s1]
                label_reg_t = list_label[s2]
                feature_all = torch.cat((feature_s, feature_t), 0)

                if epoch < self.pre_epoch:
                    pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train(
                        feature_all, len_win=self.len_win
                    )
                else:
                    pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat)
                    # Only the values are needed later (the weight update detaches
                    # them anyway); detaching avoids keeping every batch's graph alive.
                    dist_mat = dist_mat + dist.detach()
                pred_s = pred_all[0 : feature_s.size(0)]
                pred_t = pred_all[feature_s.size(0) :]

                loss_s = criterion(pred_s, label_reg_s)
                loss_t = criterion(pred_t, label_reg_t)

                total_loss = total_loss + loss_s + loss_t + self.dw * loss_transfer
            self.train_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
            self.train_optimizer.step()

        if epoch >= self.pre_epoch:
            # The first boosting epoch only collects distances; updates start after it.
            if epoch > self.pre_epoch and weight_mat is not None and dist_old is not None:
                weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat)
            return weight_mat, dist_mat

        if out_weight_list is None:
            # No batch was trained (empty loaders or every batch skipped), so there
            # are no gate weights to convert; keep whatever weights we had.
            self.logger.warning("no batch was trained in epoch %d; transfer weights are unchanged", epoch)
            return weight_mat, None
        return self.transform_type(out_weight_list), None

    @staticmethod
    def calc_all_metrics(pred):
        """pred is a pandas dataframe that has two attributes: score (pred) and label (real)

        Returns a dict with ``ic``, ``icir``, ``ric``, ``ricir``, ``mse`` and ``loss``.
        ``mse`` (and ``loss``) is the *negated* mean squared error so that, like the
        IC metrics, a larger value is better for early stopping.
        """
        res = {}
        ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
        rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
        res["ic"] = ic.mean()
        res["icir"] = ic.mean() / ic.std()
        res["ric"] = rank_ic.mean()
        res["ricir"] = rank_ic.mean() / rank_ic.std()
        # Square the residual: without it positive and negative errors cancel out.
        res["mse"] = -((pred["label"] - pred["score"]) ** 2).mean()
        res["loss"] = res["mse"]
        return res

    def test_epoch(self, df):
        """Run inference on ``df["feature"]`` and score it against ``df["label"]``."""
        self.model.eval()
        preds = self.infer(df["feature"])
        label = df["label"].squeeze()
        preds = pd.DataFrame({"label": label, "score": preds}, index=df.index)
        metrics = self.calc_all_metrics(preds)
        return metrics

    def log_metrics(self, mode, metrics):
        """Log every metric as ``name/mode: value`` on a single line."""
        metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
        metrics = ", ".join(metrics)
        self.logger.info(metrics)

    def fit(
        self,
        dataset: DatasetH,
        evals_result=None,
        save_path=None,
    ):
        """Train with early stopping and keep the parameters of the best epoch.

        Parameters
        ----------
        dataset : DatasetH
            must provide "train" and "valid" segments
        evals_result : dict or None
            filled in place with the per-epoch "train" and "valid" scores
        save_path : str or None
            where the best parameters are saved (resolved by ``get_or_create_path``)

        Returns
        -------
        float
            the best validation score
        """
        # A fresh dict per call; a shared default would leak results between fits.
        if evals_result is None:
            evals_result = {}

        df_train, df_valid = dataset.prepare(
            ["train", "valid"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        # splits = ['2011-06-30']
        days = df_train.index.get_level_values(level=0).unique()
        train_splits = np.array_split(days, self.n_splits)
        train_splits = [df_train[s[0] : s[-1]] for s in train_splits]
        # NOTE: the loader reshapes features with its own default 6 x 60 layout
        # rather than self.d_feat / self.len_seq.
        train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits]

        save_path = get_or_create_path(save_path)
        stop_steps = 0
        evals_result["train"] = []
        evals_result["valid"] = []

        # train
        self.logger.info("training...")
        self.fitted = True
        best_score = -np.inf
        best_epoch = 0
        # Fallback so the code after the loop works even if no epoch ever improves
        # the score (n_epochs == 0, or NaN validation scores).
        best_param = copy.deepcopy(self.model.state_dict())
        weight_mat, dist_mat = None, None

        for step in range(self.n_epochs):
            self.logger.info("Epoch%d:", step)
            self.logger.info("training...")
            weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat)
            self.logger.info("evaluating...")
            train_metrics = self.test_epoch(df_train)
            valid_metrics = self.test_epoch(df_valid)
            self.log_metrics("train: ", train_metrics)
            self.log_metrics("valid: ", valid_metrics)

            valid_score = valid_metrics[self.metric]
            train_score = train_metrics[self.metric]
            evals_result["train"].append(train_score)
            evals_result["valid"].append(valid_score)
            if valid_score > best_score:
                best_score = valid_score
                stop_steps = 0
                best_epoch = step
                best_param = copy.deepcopy(self.model.state_dict())
            else:
                stop_steps += 1
                if stop_steps >= self.early_stop:
                    self.logger.info("early stop")
                    break

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.model.load_state_dict(best_param)
        torch.save(best_param, save_path)

        if self.use_gpu:
            torch.cuda.empty_cache()
        return best_score

    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
        """Predict scores for ``segment`` of ``dataset``.

        Raises
        ------
        ValueError
            if the model has not been fitted
        """
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
        return self.infer(x_test)

    def infer(self, x_test):
        """Run the network batch by batch and return a Series aligned with ``x_test.index``."""
        index = x_test.index
        self.model.eval()
        x_values = x_test.values
        sample_num = x_values.shape[0]
        # Flat rows are laid out feature-major; bring them to (sample, step, feature).
        x_values = x_values.reshape(sample_num, self.d_feat, -1).transpose(0, 2, 1)
        preds = []

        for begin in range(0, sample_num, self.batch_size):
            end = min(begin + self.batch_size, sample_num)
            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)

            with torch.no_grad():
                pred = self.model.predict(x_batch).detach().cpu().numpy()

            # A trailing batch of one sample can come back 0-dimensional, which
            # np.concatenate rejects; always store a flat vector.
            preds.append(pred.reshape(-1))

        return pd.Series(np.concatenate(preds), index=index)

    def transform_type(self, init_weight):
        """Copy the per-layer gate weights into a (num_layers, len_seq) tensor without gradients."""
        weight = torch.ones(self.num_layers, self.len_seq).to(self.device)
        for i in range(self.num_layers):
            weight[i] = init_weight[i].detach().reshape(-1)[: self.len_seq]
        return weight
|
|
|
|
|
|
class data_loader(Dataset):
    """Torch dataset over a frame with "feature" and "label" column groups.

    Parameters
    ----------
    df : pandas.DataFrame
        frame whose ``df["feature"]`` rows hold ``d_feat * len_seq`` values laid out
        feature-major, and whose ``df["label"]`` holds one value per row
    d_feat : int
        number of features per time step (default 6, the previously fixed value)
    len_seq : int
        number of time steps per sample (default 60, the previously fixed value)
    """

    def __init__(self, df, d_feat=6, len_seq=60):
        self.df_index = df.index
        # (row, feature, step) -> (row, step, feature), the layout the GRU consumes.
        self.df_feature = torch.tensor(
            df["feature"].values.reshape(-1, d_feat, len_seq).transpose(0, 2, 1), dtype=torch.float32
        )
        self.df_label_reg = torch.tensor(df["label"].values.reshape(-1), dtype=torch.float32)

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.df_feature[index], self.df_label_reg[index]

    def __len__(self):
        """Number of samples."""
        return len(self.df_feature)
|
|
|
|
|
|
def get_stock_loader(df, batch_size, shuffle=True):
    """Wrap ``df`` in a ``data_loader`` dataset and return a ``DataLoader`` over it."""
    dataset = data_loader(df)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
|
|
def get_index(num_domain=2):
    """Return every ordered pair ``(i, j)`` with ``0 <= i < j <= num_domain``."""
    return [(lo, hi) for lo in range(num_domain) for hi in range(lo + 1, num_domain + 1)]
|
|
|
|
|
|
class AdaRNN(nn.Module):
    """
    Stacked-GRU regressor with a per-layer, per-time-step transfer loss between the
    first ("source") and second ("target") half of each batch.

    model_type:  'Boosting', 'AdaRNN'
        'AdaRNN' additionally builds gate and batch-norm layers that produce the
        per-step weights used during pre-training.

    Parameters
    ----------
    use_bottleneck : bool
        add a bottleneck MLP before the output layer (only when exactly ``True``)
    bottleneck_width : int
        width of the bottleneck layers
    n_input : int
        number of input features per time step
    n_hiddens : list of int or None
        hidden size of each GRU layer; ``None`` means ``[64, 64]``
    n_output : int
        output dimension
    dropout : float
        dropout value passed to every GRU
    len_seq : int
        number of time steps per sample
    trans_loss : str
        loss type passed to ``TransferLoss``
    GPU : int
        device id used for ``self.device``; negative selects the CPU
    """

    def __init__(
        self,
        use_bottleneck=False,
        bottleneck_width=256,
        n_input=128,
        n_hiddens=None,
        n_output=6,
        dropout=0.0,
        len_seq=9,
        model_type="AdaRNN",
        trans_loss="mmd",
        GPU=0,
    ):
        super(AdaRNN, self).__init__()
        # Copy so the layer sizes cannot be changed from outside through a shared list.
        n_hiddens = [64, 64] if n_hiddens is None else list(n_hiddens)
        self.use_bottleneck = use_bottleneck
        self.n_input = n_input
        self.num_layers = len(n_hiddens)
        self.hiddens = n_hiddens
        self.n_output = n_output
        self.model_type = model_type
        self.trans_loss = trans_loss
        self.len_seq = len_seq
        self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")
        in_size = self.n_input

        features = nn.ModuleList()
        for hidden in n_hiddens:
            rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout)
            features.append(rnn)
            in_size = hidden
        self.features = nn.Sequential(*features)

        if use_bottleneck is True:  # finance
            self.bottleneck = nn.Sequential(
                nn.Linear(n_hiddens[-1], bottleneck_width),
                nn.Linear(bottleneck_width, bottleneck_width),
                nn.BatchNorm1d(bottleneck_width),
                nn.ReLU(),
                nn.Dropout(),
            )
            self.bottleneck[0].weight.data.normal_(0, 0.005)
            self.bottleneck[0].bias.data.fill_(0.1)
            self.bottleneck[1].weight.data.normal_(0, 0.005)
            self.bottleneck[1].bias.data.fill_(0.1)
            self.fc = nn.Linear(bottleneck_width, n_output)
            torch.nn.init.xavier_normal_(self.fc.weight)
        else:
            self.fc_out = nn.Linear(n_hiddens[-1], self.n_output)

        if self.model_type == "AdaRNN":
            gate = nn.ModuleList()
            for hidden in self.hiddens:
                # The gate sees source and target hidden states of all steps, concatenated.
                gate.append(nn.Linear(len_seq * hidden * 2, len_seq))
            self.gate = gate

            bnlst = nn.ModuleList()
            for _ in self.hiddens:
                bnlst.append(nn.BatchNorm1d(len_seq))
            self.bn_lst = bnlst
            self.softmax = torch.nn.Softmax(dim=0)
            self.init_layers()

    def init_layers(self):
        """Initialise the gate layers; a no-op for models built without gates."""
        if self.model_type != "AdaRNN":
            return
        for gate in self.gate:
            gate.weight.data.normal_(0, 0.05)
            gate.bias.data.fill_(0.0)

    @staticmethod
    def _gpu_index(device):
        """Translate a tensor's device into the integer id ``TransferLoss`` expects (-1 for CPU)."""
        if device.type != "cuda":
            return -1
        return 0 if device.index is None else device.index

    def _regress(self, fea):
        """Map the last time step of ``fea`` to the output, dropping a trailing size-1 dim.

        Only the last dimension is squeezed so that a batch of one sample keeps its
        batch dimension.
        """
        last_step = fea[:, -1, :]
        if self.use_bottleneck is True:
            return self.fc(self.bottleneck(last_step)).squeeze(-1)
        return self.fc_out(last_step).squeeze(-1)

    def forward_pre_train(self, x, len_win=0):
        """Forward pass for the pre-training phase.

        ``x`` stacks the source batch on top of the target batch along dim 0.

        Returns
        -------
        tuple
            ``(prediction, transfer_loss, out_weight_list)`` where
            ``out_weight_list`` holds one gate-weight vector per layer
            (``None`` when the model has no gates)
        """
        fea, out_list_all, out_weight_list = self.gru_features(x)  # fea: [2N,L,H]
        fc_out = self._regress(fea)  # [2N,]

        out_list_s, out_list_t = self.get_features(out_list_all)
        # Accumulate on the input's device instead of a constructor-time guess.
        loss_transfer = torch.zeros((1,), device=x.device)
        gpu = self._gpu_index(x.device)
        h_start = 0
        for i, feat_s in enumerate(out_list_s):
            feat_t = out_list_t[i]
            criterion_transfer = TransferLoss(loss_type=self.trans_loss, input_dim=feat_s.shape[2], GPU=gpu)
            for j in range(h_start, self.len_seq, 1):
                i_start = max(j - len_win, 0)
                i_end = min(j + len_win, self.len_seq - 1)
                # NOTE(review): the non-AdaRNN weight evaluates as
                # (1 / (len_seq - h_start)) * (2 * len_win + 1); kept as in the
                # original -- confirm whether a division by both factors was meant.
                weight = (
                    out_weight_list[i][j]
                    if self.model_type == "AdaRNN"
                    else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
                )
                for k in range(i_start, i_end + 1):
                    loss_transfer = loss_transfer + weight * criterion_transfer.compute(
                        feat_s[:, j, :], feat_t[:, k, :]
                    )
        return fc_out, loss_transfer, out_weight_list

    def gru_features(self, x, predict=False):
        """Run the GRU stack.

        Returns
        -------
        tuple
            ``(last_layer_output, per_layer_outputs, gate_weights)``; the gate
            weights are only computed for ``model_type == "AdaRNN"`` when
            ``predict`` is ``False``
        """
        x_input = x
        out = None
        out_lis = []
        out_weight_list = [] if (self.model_type == "AdaRNN") else None
        for i in range(self.num_layers):
            out, _ = self.features[i](x_input.float())
            x_input = out
            out_lis.append(out)
            if self.model_type == "AdaRNN" and predict is False:
                out_gate = self.process_gate_weight(x_input, i)
                out_weight_list.append(out_gate)
        return out, out_lis, out_weight_list

    def process_gate_weight(self, out, index):
        """Compute the softmax-normalised per-step weights of layer ``index``.

        The two halves of the batch are concatenated along the hidden dimension,
        flattened, passed through the layer's gate and batch norm, averaged over
        the batch and normalised over the time steps.
        """
        half = out.shape[0] // 2
        x_s = out[0:half]
        x_t = out[half : out.shape[0]]
        x_all = torch.cat((x_s, x_t), 2)
        x_all = x_all.view(x_all.shape[0], -1)
        weight = torch.sigmoid(self.bn_lst[index](self.gate[index](x_all.float())))
        weight = torch.mean(weight, dim=0)
        # Already one value per step; no squeeze, so len_seq == 1 stays indexable.
        return self.softmax(weight)

    @staticmethod
    def get_features(output_list):
        """Split every tensor in ``output_list`` into its first and second half along dim 0."""
        fea_list_src, fea_list_tar = [], []
        for fea in output_list:
            half = fea.size(0) // 2
            fea_list_src.append(fea[0:half])
            fea_list_tar.append(fea[half:])
        return fea_list_src, fea_list_tar

    # For Boosting-based
    def forward_Boosting(self, x, weight_mat=None):
        """Forward pass for the boosting phase.

        Returns
        -------
        tuple
            ``(prediction, transfer_loss, dist_mat, weight)`` where ``dist_mat[i, j]``
            is the transfer loss of layer ``i`` at step ``j`` and ``weight`` is the
            weight matrix that was applied (uniform when ``weight_mat`` is ``None``)
        """
        fea, out_list_all, _ = self.gru_features(x)
        fc_out = self._regress(fea)

        out_list_s, out_list_t = self.get_features(out_list_all)
        device = x.device
        loss_transfer = torch.zeros((1,), device=device)
        if weight_mat is None:
            weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(device)
        else:
            weight = weight_mat
        dist_mat = torch.zeros(self.num_layers, self.len_seq, device=device)
        gpu = self._gpu_index(device)
        for i, feat_s in enumerate(out_list_s):
            feat_t = out_list_t[i]
            criterion_transfer = TransferLoss(loss_type=self.trans_loss, input_dim=feat_s.shape[2], GPU=gpu)
            for j in range(self.len_seq):
                loss_trans = criterion_transfer.compute(feat_s[:, j, :], feat_t[:, j, :])
                loss_transfer = loss_transfer + weight[i, j] * loss_trans
                dist_mat[i, j] = loss_trans
        return fc_out, loss_transfer, dist_mat, weight

    # For Boosting-based
    def update_weight_Boosting(self, weight_mat, dist_old, dist_new):
        """Increase the weight of positions whose distance grew, then L1-normalise each row.

        The input ``weight_mat`` is left untouched; a new tensor is returned.
        """
        epsilon = 1e-5
        dist_old = dist_old.detach()
        dist_new = dist_new.detach()
        grew = dist_new > dist_old + epsilon
        boosted = weight_mat * (1 + torch.sigmoid(dist_new - dist_old))
        weight_mat = torch.where(grew, boosted, weight_mat)
        weight_norm = torch.norm(weight_mat, dim=1, p=1, keepdim=True)
        return weight_mat / weight_norm

    def predict(self, x):
        """Predict without computing gate weights or transfer losses."""
        fea = self.gru_features(x, predict=True)[0]
        return self._regress(fea)
|
|
|
|
|
|
class TransferLoss:
    """Dispatch to one of several distribution-distance losses by name."""

    def __init__(self, loss_type="cosine", input_dim=512, GPU=0):
        """
        Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine(cos), kl, js, mine, adv, pairwise

        ``input_dim`` is the feature dimension handed to the "mine" and "adv" losses.
        ``GPU`` selects ``self.device`` (negative means CPU).
        """
        self.loss_type = loss_type
        self.input_dim = input_dim
        self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu")

    def compute(self, X, Y):
        """Compute adaptation loss

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix

        Returns:
            [tensor] -- transfer loss

        Raises:
            ValueError -- if ``loss_type`` is not one of the supported names
        """
        # Helper tensors/modules must live where the data lives; the device guessed
        # in the constructor can differ from the inputs' device.
        device = X.device
        loss_type = self.loss_type
        if loss_type in ("mmd_lin", "mmd"):
            return MMD_loss(kernel_type="linear")(X, Y)
        if loss_type == "coral":
            return CORAL(X, Y, device)
        if loss_type in ("cosine", "cos"):
            return 1 - cosine(X, Y)
        if loss_type == "kl":
            return kl_div(X, Y)
        if loss_type == "js":
            return js(X, Y)
        if loss_type == "mine":
            mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(device)
            return mine_model(X, Y)
        if loss_type == "adv":
            return adv(X, Y, device, input_dim=self.input_dim, hidden_dim=32)
        if loss_type == "mmd_rbf":
            return MMD_loss(kernel_type="rbf")(X, Y)
        if loss_type == "pairwise":
            return torch.norm(pairwise_dist(X, Y))
        # Returning None here would only fail later with an obscure TypeError.
        raise ValueError("unsupported transfer loss type: {!r}".format(loss_type))
|
|
|
|
|
|
def cosine(source, target):
    """Cosine similarity between the batch-mean feature vectors of two matrices.

    ``source`` and ``target`` are reduced over dim 0 (the batch axis) only, so the
    similarity compares two feature vectors. Reducing over every axis would leave
    two scalars, whose "cosine similarity" is just the agreement of their signs.
    """
    source_mean, target_mean = source.mean(0), target.mean(0)
    similarity = nn.CosineSimilarity(dim=0)(source_mean, target_mean)
    return similarity.mean()
|
|
|
|
|
|
class ReverseLayerF(Function):
    """Identity in the forward pass; negates and scales the gradient by ``alpha`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scale for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: the reversed gradient for x, none for alpha.
        reversed_grad = grad_output.neg() * ctx.alpha
        return reversed_grad, None
|
|
|
|
|
|
class Discriminator(nn.Module):
    """Two-layer MLP that maps a feature vector to a probability in (0, 1)."""

    def __init__(self, input_dim=256, hidden_dim=256):
        super(Discriminator, self).__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.dis1 = nn.Linear(input_dim, hidden_dim)
        self.dis2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Linear -> ReLU -> Linear -> sigmoid."""
        hidden = F.relu(self.dis1(x))
        return torch.sigmoid(self.dis2(hidden))
|
|
|
|
|
|
def adv(source, target, device, input_dim=256, hidden_dim=512):
    """Adversarial domain loss.

    A freshly constructed ``Discriminator`` scores the gradient-reversed source
    (label 1) and target (label 0) samples; the two binary cross-entropy terms
    are summed.
    """
    bce = nn.BCELoss()
    # !!! Pay attention to .cuda !!!
    discriminator = Discriminator(input_dim, hidden_dim).to(device)

    # Domain labels as column vectors matching the discriminator output.
    src_labels = torch.ones(len(source)).to(device).unsqueeze(1)
    tar_labels = torch.zeros(len(target)).to(device).unsqueeze(1)

    src_scores = discriminator(ReverseLayerF.apply(source, 1))
    tar_scores = discriminator(ReverseLayerF.apply(target, 1))

    return bce(src_scores, src_labels) + bce(tar_scores, tar_labels)
|
|
|
|
|
|
def CORAL(source, target, device):
    """Squared Frobenius distance between the covariance matrices of ``source`` and ``target``.

    The result is divided by ``4 * d * d`` with ``d = source.size(1)``.
    ``device`` is where the helper row of ones is placed.
    """
    feat_dim = source.size(1)

    def _covariance(samples):
        count = samples.size(0)
        # Row of ones times the matrix gives the per-column sums.
        col_sum = torch.ones((1, count)).to(device) @ samples
        return (samples.t() @ samples - (col_sum.t() @ col_sum) / count) / (count - 1)

    cov_source = _covariance(source)
    cov_target = _covariance(target)

    # frobenius norm
    squared_gap = (cov_source - cov_target).pow(2).sum()
    return squared_gap / (4 * feat_dim * feat_dim)
|
|
|
|
|
|
class MMD_loss(nn.Module):
    """Maximum mean discrepancy between two sample matrices.

    Parameters
    ----------
    kernel_type : str
        "linear" or "rbf"
    kernel_mul : float
        ratio between consecutive bandwidths of the rbf kernel bank
    kernel_num : int
        number of rbf kernels that are summed
    """

    def __init__(self, kernel_type="linear", kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    @staticmethod
    def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Sum of ``kernel_num`` Gaussian kernels over all pairs of rows of ``[source; target]``.

        Unless ``fix_sigma`` is given, the base bandwidth is the mean pairwise
        squared distance (excluding the diagonal count), computed without gradient.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Broadcasting builds the full (n, n, d) difference tensor.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples**2 - n_samples)
        # Centre the bank so the middle kernel uses the base bandwidth.
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    @staticmethod
    def linear_mmd(X, Y):
        """Squared Euclidean distance between the column means of ``X`` and ``Y``."""
        delta = X.mean(dim=0) - Y.mean(dim=0)
        return delta.dot(delta)

    def forward(self, source, target):
        """Return the MMD of ``source`` and ``target`` for the configured kernel.

        Raises
        ------
        ValueError
            if ``kernel_type`` is neither "linear" nor "rbf"
        """
        if self.kernel_type == "linear":
            return self.linear_mmd(source, target)
        if self.kernel_type == "rbf":
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma
            )
            # Computed with autograd enabled: a loss built under no_grad would
            # contribute nothing to training.
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            return torch.mean(XX + YY - XY - YX)
        raise ValueError("unsupported kernel type: {!r}".format(self.kernel_type))
|
|
|
|
|
|
class Mine_estimator(nn.Module):
    """Loss wrapper around a ``Mine`` network.

    ``forward`` returns the negative of
    ``mean(T(X, Y)) - log(mean(exp(T(X, shuffled Y))))``.
    """

    def __init__(self, input_dim=2048, hidden_dim=512):
        super(Mine_estimator, self).__init__()
        self.mine_model = Mine(input_dim, hidden_dim)

    def forward(self, X, Y):
        # Shuffling the rows of Y breaks the pairing with X.
        shuffled = Y[torch.randperm(len(Y))]
        joint_scores = self.mine_model(X, Y)
        marginal_scores = self.mine_model(X, shuffled)
        bound = torch.mean(joint_scores) - torch.log(torch.mean(torch.exp(marginal_scores)))
        return -bound
|
|
|
|
|
|
class Mine(nn.Module):
    """Scoring network: ``fc2(leaky_relu(fc1_x(x) + fc1_y(y)))``, one score per row."""

    def __init__(self, input_dim=2048, hidden_dim=512):
        super(Mine, self).__init__()
        self.fc1_x = nn.Linear(input_dim, hidden_dim)
        self.fc1_y = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, y):
        # Both inputs are projected separately and summed before the non-linearity.
        joint_hidden = self.fc1_x(x) + self.fc1_y(y)
        return self.fc2(F.leaky_relu(joint_hidden))
|
|
|
|
|
|
def pairwise_dist(X, Y):
    """Squared Euclidean distance between every row of ``X`` (n, d) and every row of ``Y`` (m, d).

    Returns a tensor of shape (n, m).
    """
    n, d = X.shape
    m, _ = Y.shape
    assert d == Y.shape[1]
    rows_of_x = X.unsqueeze(1).expand(n, m, d)
    rows_of_y = Y.unsqueeze(0).expand(n, m, d)
    gap = rows_of_x - rows_of_y
    return torch.pow(gap, 2).sum(2)
|
|
|
|
|
|
def pairwise_dist_np(X, Y):
    """NumPy version of ``pairwise_dist``: squared distances between rows of ``X`` (n, d) and ``Y`` (m, d)."""
    n, d = X.shape
    m, _ = Y.shape
    assert d == Y.shape[1]
    # Repeat X along a new middle axis and Y along a new leading axis -> (n, m, d) each.
    x_rep = np.tile(np.expand_dims(X, 1), (1, m, 1))
    y_rep = np.tile(np.expand_dims(Y, 0), (n, 1, 1))
    gap = x_rep - y_rep
    return np.power(gap, 2).sum(2)
|
|
|
|
|
|
def pa(X, Y):
    """Pairwise squared distances via the expansion ``|x|^2 + |y|^2 - 2 x.y``.

    ``X`` is (n, d) and ``Y`` is (m, d); the result is (n, m).
    """
    cross = np.dot(X, Y.T)
    # Column vector of squared row norms of X, row vector for Y.
    x_sq = np.transpose([np.sum(np.square(X), axis=1)])
    y_sq = np.sum(np.square(Y), axis=1)
    return x_sq + y_sq - 2 * cross
|
|
|
|
|
|
def kl_div(source, target):
    """``KLDivLoss`` (batchmean) with ``log(source)`` as input and ``target`` as target.

    The longer of the two inputs is truncated to the length of the shorter one.
    """
    common = min(len(source), len(target))
    if len(source) != len(target):
        source, target = source[:common], target[:common]
    loss_fn = nn.KLDivLoss(reduction="batchmean")
    return loss_fn(source.log(), target)
|
|
|
|
|
|
def js(source, target):
    """Jensen-Shannon divergence between ``source`` and ``target`` (batchmean).

    Computes ``0.5 * KL(source || M) + 0.5 * KL(target || M)`` with
    ``M = 0.5 * (source + target)``. The longer input is truncated to the length
    of the shorter one.

    ``KLDivLoss(input, target)`` measures ``KL(target || exp(input))``, so the
    mixture goes in as the log-input and each distribution as the target; the
    reverse order yields ``KL(M || source)``, which is infinite wherever
    ``source`` is zero.
    """
    common = min(len(source), len(target))
    source, target = source[:common], target[:common]
    mixture = 0.5 * (source + target)
    log_mixture = mixture.log()
    criterion = nn.KLDivLoss(reduction="batchmean")
    kl_source = criterion(log_mixture, source)
    kl_target = criterion(log_mixture, target)
    return 0.5 * (kl_source + kl_target)
|