1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 20:11:08 +08:00
Files
qlib/qlib/contrib/model/pytorch_tra.py
Linlang 1ef8e61abd fix_pylint_for_CI (#1119)
* fix_pylint_for_CI

* reformat_with_black

* fix_pylint_C3001

* fix_flake8_error
2022-06-09 16:12:33 +08:00

944 lines
33 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import io
import os
import copy
import math
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None
from tqdm import tqdm
from qlib.constant import EPS
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.contrib.data.dataset import MTSDatasetH
# Target device for the networks built in this module: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class TRAModel(Model):
    """
    TRA Model

    Builds a backbone network (RNN/Transformer) plus a Temporal Routing Adaptor (TRA)
    head and drives their training, evaluation and inference on an `MTSDatasetH`.

    Args:
        model_config (dict): model config (will be used by RNN or Transformer)
        tra_config (dict): TRA config (will be used by TRA)
        model_type (str): which backbone model to use (RNN/Transformer)
        lr (float): learning rate
        n_epochs (int): number of total epochs
        early_stop (int): early stop when performance not improved at this step
        update_freq (int): gradient update frequency
        max_steps_per_epoch (int): maximum number of steps in one epoch
        lamb (float): regularization parameter
        rho (float): exponential decay rate for `lamb`
        alpha (float): fusion parameter for calculating transport loss matrix
        seed (int): random seed
        logdir (str): local log directory
        eval_train (bool): whether evaluate train set between epochs
        eval_test (bool): whether evaluate test set between epochs
        pretrain (bool): whether pretrain the backbone model before training TRA.
            Note that only TRA will be optimized after pretraining
        init_state (str): model init state path
        reset_router (bool): whether reset the parameters of TRA's `fc`/`router` modules
            after the model has been created (and `init_state` loaded, if any)
        freeze_model (bool): whether freeze backbone model parameters
        freeze_predictors (bool): whether freeze predictors parameters
        transport_method (str): transport method, can be none/router/oracle
        memory_mode (str): memory mode, the same argument for MTSDatasetH
    """

    def __init__(
        self,
        model_config,
        tra_config,
        model_type="RNN",
        lr=1e-3,
        n_epochs=500,
        early_stop=50,
        update_freq=1,
        max_steps_per_epoch=None,
        lamb=0.0,
        rho=0.99,
        alpha=1.0,
        seed=None,
        logdir=None,
        eval_train=False,
        eval_test=False,
        pretrain=False,
        init_state=None,
        reset_router=False,
        freeze_model=False,
        freeze_predictors=False,
        transport_method="none",
        memory_mode="sample",
    ):
        self.logger = get_module_logger("TRA")

        assert memory_mode in ["sample", "daily"], "invalid memory mode"
        assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
        # Keys omitted from `tra_config` fall back to TRA's own defaults
        # (`num_states=1`, `src_info="LR_TPE"`) instead of raising KeyError.
        assert (
            transport_method == "none" or tra_config.get("num_states", 1) > 1
        ), "optimal transport requires `num_states` > 1"
        assert (
            memory_mode != "daily" or tra_config.get("src_info", "LR_TPE") == "TPE"
        ), "daily transport can only support TPE as `src_info`"

        if transport_method == "router" and not eval_train:
            self.logger.warning("`eval_train` will be ignored when using TRA.router")

        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        self.model_config = model_config
        self.tra_config = tra_config
        self.model_type = model_type
        self.lr = lr
        self.n_epochs = n_epochs
        self.early_stop = early_stop
        self.update_freq = update_freq
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lamb = lamb
        self.rho = rho
        self.alpha = alpha
        self.seed = seed
        self.logdir = logdir
        self.eval_train = eval_train
        self.eval_test = eval_test
        self.pretrain = pretrain
        self.init_state = init_state
        self.reset_router = reset_router
        self.freeze_model = freeze_model
        self.freeze_predictors = freeze_predictors
        self.transport_method = transport_method
        self.use_daily_transport = memory_mode == "daily"
        self.transport_fn = transport_daily if self.use_daily_transport else transport_sample

        # The tensorboard writer stays None when no logdir is given or when
        # tensorboard could not be imported; every use below must check for that.
        self._writer = None
        if self.logdir is not None:
            if os.path.exists(self.logdir):
                self.logger.warning(f"logdir {self.logdir} is not empty")
            os.makedirs(self.logdir, exist_ok=True)
            if SummaryWriter is not None:
                self._writer = SummaryWriter(log_dir=self.logdir)

        self._init_model()

    def _init_model(self):
        """Create the backbone, the TRA head and the optimizer, honouring init/freeze options.

        Raises:
            ValueError: if `model_type` is not one of the supported backbones.
        """
        self.logger.info("init TRAModel...")

        # Resolve the backbone from a fixed registry rather than `eval`, so that a
        # `model_type` string coming from a config file is never executed as code.
        backbones = {"RNN": RNN, "Transformer": Transformer}
        if self.model_type not in backbones:
            raise ValueError(f"unsupported model_type `{self.model_type}`, expected one of {sorted(backbones)}")
        self.model = backbones[self.model_type](**self.model_config).to(device)
        print(self.model)

        self.tra = TRA(self.model.output_size, **self.tra_config).to(device)
        print(self.tra)

        if self.init_state:
            self.logger.warning("load state dict from `init_state`")
            # NOTE(security): `torch.load` unpickles the file; only point
            # `init_state` at checkpoints from a trusted source.
            state_dict = torch.load(self.init_state, map_location="cpu")
            self.model.load_state_dict(state_dict["model"])
            # TRA weights are loaded leniently; mismatches are reported, not raised.
            res = load_state_dict_unsafe(self.tra, state_dict["tra"])
            self.logger.warning(str(res))

        if self.reset_router:
            self.logger.warning("reset TRA.router parameters")
            # `fc` only exists when num_states > 1 and `router` only when the
            # router consumes TPE, so skip whichever module is absent.
            for name in ("fc", "router"):
                module = getattr(self.tra, name, None)
                if module is not None:
                    module.reset_parameters()

        if self.freeze_model:
            self.logger.warning("freeze model parameters")
            for param in self.model.parameters():
                param.requires_grad_(False)

        if self.freeze_predictors:
            self.logger.warning("freeze TRA.predictors parameters")
            for param in self.tra.predictors.parameters():
                param.requires_grad_(False)

        self.logger.info("# model params: %d", sum(p.numel() for p in self.model.parameters() if p.requires_grad))
        self.logger.info("# tra params: %d", sum(p.numel() for p in self.tra.parameters() if p.requires_grad))

        self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)

        self.fitted = False
        self.global_step = -1

    def train_epoch(self, epoch, data_set, is_pretrain=False):
        """Run one training epoch over `data_set` and return the mean batch loss.

        Args:
            epoch (int): epoch number, used as the step for epoch-level tensorboard records
            data_set: iterable of batches providing `data`, `state`, `label`, `daily_count`
                and `index`/`daily_index` entries
            is_pretrain (bool): when True, oracle transport is used, `lamb` is forced
                to 0 and nothing is written to tensorboard

        Returns:
            float: average loss over the processed batches (NaN if no batch was processed)
        """
        self.model.train()
        self.tra.train()
        data_set.train()
        self.optimizer.zero_grad()

        P_all = []
        prob_all = []
        choice_all = []
        max_steps = len(data_set)
        if self.max_steps_per_epoch is not None:
            if epoch == 0 and self.max_steps_per_epoch < max_steps:
                self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}")
            max_steps = min(self.max_steps_per_epoch, max_steps)

        log_to_writer = self._writer is not None and not is_pretrain

        cur_step = 0
        total_loss = 0
        total_count = 0
        for batch in tqdm(data_set, total=max_steps):
            cur_step += 1
            if cur_step > max_steps:
                break

            if not is_pretrain:
                self.global_step += 1

            data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
            index = batch["daily_index"] if self.use_daily_transport else batch["index"]

            # Skip building the backbone graph when its weights are frozen.
            with torch.set_grad_enabled(not self.freeze_model):
                hidden = self.model(data)

            all_preds, choice, prob = self.tra(hidden, state)

            if is_pretrain or self.transport_method != "none":
                # NOTE: use oracle transport for pre-training
                loss, pred, L, P = self.transport_fn(
                    all_preds,
                    label,
                    choice,
                    prob,
                    state.mean(dim=1),
                    count,
                    self.transport_method if not is_pretrain else "oracle",
                    self.alpha,
                    training=True,
                )
                data_set.assign_data(index, L)  # save loss to memory
                if self.use_daily_transport:  # only save for daily transport
                    P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
                    prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
                    choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
                decay = self.rho ** (self.global_step // 100)  # decay every 100 steps
                lamb = 0 if is_pretrain else self.lamb * decay
                reg = prob.log().mul(P).sum(dim=1).mean()  # train router to predict TO assignment
                if log_to_writer:
                    self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
                    self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
                    self._writer.add_scalar("training/lamb", lamb, self.global_step)
                    if not self.use_daily_transport:
                        P_mean = P.mean(axis=0).detach()
                        self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
                loss = loss - lamb * reg
            else:
                pred = all_preds.mean(dim=1)
                loss = loss_fn(pred, label)

            # Gradient accumulation: scale each step and update every `update_freq` steps.
            (loss / self.update_freq).backward()
            if cur_step % self.update_freq == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            if log_to_writer:
                self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)

            total_loss += loss.item()
            total_count += 1

        if self.use_daily_transport and len(P_all) > 0:
            P_all = pd.concat(P_all, axis=0)
            prob_all = pd.concat(prob_all, axis=0)
            choice_all = pd.concat(choice_all, axis=0)
            P_all.index = data_set.restore_daily_index(P_all.index)
            prob_all.index = P_all.index
            choice_all.index = P_all.index
            # The writer is None without a logdir/tensorboard; plotting must be skipped then.
            if log_to_writer:
                self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
                self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
                self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")

        # Avoid ZeroDivisionError when the data set yielded no batch at all.
        total_loss = total_loss / total_count if total_count else float("nan")

        if log_to_writer:
            self._writer.add_scalar("training/loss", total_loss, epoch)

        return total_loss

    def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False):
        """Evaluate on `data_set` and optionally collect predictions.

        When transport is active the per-batch loss matrix is written back to the
        data set memory via `assign_data`.

        Args:
            epoch (int): epoch number; metrics are written to tensorboard only when it is >= 0
            data_set: iterable of batches, same layout as in `train_epoch`
            return_pred (bool): whether to collect predictions, router probabilities and
                transport matrices
            prefix (str): tensorboard tag prefix for the metrics
            is_pretrain (bool): when True, oracle transport is used and nothing is
                written to tensorboard

        Returns:
            tuple: `(metrics, preds, probs, P_all)`. `metrics` is a dict with MSE/MAE/IC/ICIR
            averaged over batches; the other three are DataFrames when collected and
            empty lists otherwise.
        """
        self.model.eval()
        self.tra.eval()
        data_set.eval()

        preds = []
        probs = []
        P_all = []
        metrics = []
        for batch in tqdm(data_set):
            data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
            index = batch["daily_index"] if self.use_daily_transport else batch["index"]

            with torch.no_grad():
                hidden = self.model(data)
                all_preds, choice, prob = self.tra(hidden, state)

            if is_pretrain or self.transport_method != "none":
                loss, pred, L, P = self.transport_fn(
                    all_preds,
                    label,
                    choice,
                    prob,
                    state.mean(dim=1),
                    count,
                    self.transport_method if not is_pretrain else "oracle",
                    self.alpha,
                    training=False,
                )
                data_set.assign_data(index, L)  # save loss to memory
                if P is not None and return_pred:
                    P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
            else:
                pred = all_preds.mean(dim=1)

            # One row per sample: final score, label, then every predictor's score.
            X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()]
            columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])]
            pred = pd.DataFrame(X, index=batch["index"], columns=columns)

            metrics.append(evaluate(pred))

            if return_pred:
                preds.append(pred)
                if prob is not None:
                    columns = ["prob_%d" % d for d in range(all_preds.shape[1])]
                    probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns))

        metrics = pd.DataFrame(metrics)
        metrics = {
            "MSE": metrics.MSE.mean(),
            "MAE": metrics.MAE.mean(),
            "IC": metrics.IC.mean(),
            "ICIR": metrics.IC.mean() / metrics.IC.std(),
        }

        if self._writer is not None and epoch >= 0 and not is_pretrain:
            for key, value in metrics.items():
                self._writer.add_scalar(prefix + "/" + key, value, epoch)

        if return_pred:
            preds = pd.concat(preds, axis=0)
            preds.index = data_set.restore_index(preds.index)
            preds.index = preds.index.swaplevel()
            preds.sort_index(inplace=True)

            if probs:
                probs = pd.concat(probs, axis=0)
                if self.use_daily_transport:
                    probs.index = data_set.restore_daily_index(probs.index)
                else:
                    probs.index = data_set.restore_index(probs.index)
                    probs.index = probs.index.swaplevel()
                    probs.sort_index(inplace=True)

            if len(P_all):
                P_all = pd.concat(P_all, axis=0)
                if self.use_daily_transport:
                    P_all.index = data_set.restore_daily_index(P_all.index)
                else:
                    P_all.index = data_set.restore_index(P_all.index)
                    P_all.index = P_all.index.swaplevel()
                    P_all.sort_index(inplace=True)

        return metrics, preds, probs, P_all

    def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
        """Train with early stopping on validation IC and restore the best weights.

        Args:
            train_set, valid_set, test_set: prepared data sets
            evals_result (dict): per-epoch metrics are appended to its
                `train`/`valid`/`test` lists
            is_pretrain (bool): forwarded to `train_epoch`/`test_epoch`

        Returns:
            float: best validation IC observed (-1 if no epoch improved on it)
        """
        best_score = -1
        best_epoch = 0
        stop_rounds = 0
        best_params = {
            "model": copy.deepcopy(self.model.state_dict()),
            "tra": copy.deepcopy(self.tra.state_dict()),
        }
        # train
        if not is_pretrain and self.transport_method != "none":
            self.logger.info("init memory...")
            self.test_epoch(-1, train_set)

        for epoch in range(self.n_epochs):
            self.logger.info("Epoch %d:", epoch)

            self.logger.info("training...")
            self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)

            self.logger.info("evaluating...")
            # NOTE: during evaluating, the whole memory will be refreshed
            if not is_pretrain and (self.transport_method == "router" or self.eval_train):
                train_set.clear_memory()  # NOTE: clear the shared memory
                train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0]
                evals_result["train"].append(train_metrics)
                self.logger.info("train metrics: %s", train_metrics)

            valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0]
            evals_result["valid"].append(valid_metrics)
            self.logger.info("valid metrics: %s", valid_metrics)

            if self.eval_test:
                test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0]
                evals_result["test"].append(test_metrics)
                self.logger.info("test metrics: %s", test_metrics)

            if valid_metrics["IC"] > best_score:
                best_score = valid_metrics["IC"]
                stop_rounds = 0
                best_epoch = epoch
                best_params = {
                    "model": copy.deepcopy(self.model.state_dict()),
                    "tra": copy.deepcopy(self.tra.state_dict()),
                }
                if self.logdir is not None:
                    torch.save(best_params, os.path.join(self.logdir, "model.bin"))
            else:
                stop_rounds += 1
                if stop_rounds >= self.early_stop:
                    self.logger.info("early stop @ %s", epoch)
                    break

        self.logger.info("best score: %.6lf @ %d", best_score, best_epoch)

        self.model.load_state_dict(best_params["model"])
        self.tra.load_state_dict(best_params["tra"])

        return best_score

    def fit(self, dataset, evals_result=None):
        """Train the model on `dataset` and, when `logdir` is set, save model, predictions and logs.

        Args:
            dataset (MTSDatasetH): data set providing `train`/`valid`/`test` segments
            evals_result (dict, optional): filled in place with per-epoch metrics under
                the `train`/`valid`/`test` keys; a fresh dict is used when omitted
        """
        assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"

        # A mutable default would be shared by every call, so create it per call.
        if evals_result is None:
            evals_result = {}

        train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"])

        self.global_step = -1

        evals_result["train"] = []
        evals_result["valid"] = []
        evals_result["test"] = []

        if self.pretrain:
            self.logger.info("pretraining...")
            # Pretraining optimizes the backbone and the predictors only (no router).
            self.optimizer = optim.Adam(
                list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
            )
            self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)

            # reset optimizer
            self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)

        self.logger.info("training...")
        best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)

        # Only mark the model as usable for `predict` once training has succeeded.
        self.fitted = True

        self.logger.info("inference")
        train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
        self.logger.info("train metrics: %s", train_metrics)

        valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
        self.logger.info("valid metrics: %s", valid_metrics)

        test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
        self.logger.info("test metrics: %s", test_metrics)

        if self.logdir:
            self.logger.info("save model & pred to local directory")

            pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv(
                os.path.join(self.logdir, "logs.csv"), index=False
            )

            torch.save(
                {"model": self.model.state_dict(), "tra": self.tra.state_dict()},
                os.path.join(self.logdir, "model.bin"),
            )

            train_preds.to_pickle(os.path.join(self.logdir, "train_pred.pkl"))
            valid_preds.to_pickle(os.path.join(self.logdir, "valid_pred.pkl"))
            test_preds.to_pickle(os.path.join(self.logdir, "test_pred.pkl"))

            if len(train_probs):
                train_probs.to_pickle(os.path.join(self.logdir, "train_prob.pkl"))
                valid_probs.to_pickle(os.path.join(self.logdir, "valid_prob.pkl"))
                test_probs.to_pickle(os.path.join(self.logdir, "test_prob.pkl"))

            if len(train_P):
                train_P.to_pickle(os.path.join(self.logdir, "train_P.pkl"))
                valid_P.to_pickle(os.path.join(self.logdir, "valid_P.pkl"))
                test_P.to_pickle(os.path.join(self.logdir, "test_P.pkl"))

            info = {
                "config": {
                    "model_config": self.model_config,
                    "tra_config": self.tra_config,
                    "model_type": self.model_type,
                    "lr": self.lr,
                    "n_epochs": self.n_epochs,
                    "early_stop": self.early_stop,
                    "max_steps_per_epoch": self.max_steps_per_epoch,
                    "lamb": self.lamb,
                    "rho": self.rho,
                    "alpha": self.alpha,
                    "seed": self.seed,
                    "logdir": self.logdir,
                    "pretrain": self.pretrain,
                    "init_state": self.init_state,
                    "transport_method": self.transport_method,
                    "use_daily_transport": self.use_daily_transport,
                },
                "best_eval_metric": -best_score,  # NOTE: -1 for minimize
                "metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics},
            }
            with open(os.path.join(self.logdir, "info.json"), "w") as f:
                json.dump(info, f)

    def predict(self, dataset, segment="test"):
        """Return the prediction DataFrame for one segment of `dataset`.

        Args:
            dataset (MTSDatasetH): data set to predict on
            segment (str): segment name passed to `dataset.prepare`

        Returns:
            pd.DataFrame: predictions collected by `test_epoch`

        Raises:
            ValueError: if the model has not been fitted yet
        """
        assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`"

        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        test_set = dataset.prepare(segment)

        metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
        self.logger.info("test metrics: %s", metrics)

        return preds
class RNN(nn.Module):
    """Recurrent backbone with an optional attention read-out.

    Args:
        input_size (int): input size (# features)
        hidden_size (int): hidden size
        num_layers (int): number of hidden layers
        rnn_arch (str): name of the recurrent class looked up on `torch.nn` (e.g. GRU/LSTM)
        use_attn (bool): whether to append an attention-pooled summary of the sequence;
            concat attention as in https://github.com/fulifeng/Adv-ALSTM/
        dropout (float): dropout rate

    Attributes:
        output_size (int): `hidden_size`, doubled when attention is enabled
    """

    def __init__(
        self,
        input_size=16,
        hidden_size=64,
        num_layers=2,
        rnn_arch="GRU",
        use_attn=True,
        dropout=0.0,
        **kwargs,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_arch = rnn_arch
        self.use_attn = use_attn

        # Features are projected down only when the hidden state is narrower than the input.
        compress = hidden_size < input_size
        self.input_proj = nn.Linear(input_size, hidden_size) if compress else None

        recurrent_cls = getattr(nn, rnn_arch)
        self.rnn = recurrent_cls(
            input_size=min(input_size, hidden_size),
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )

        if not self.use_attn:
            self.output_size = hidden_size
            return

        self.W = nn.Linear(hidden_size, hidden_size)
        self.u = nn.Linear(hidden_size, 1, bias=False)
        self.output_size = hidden_size * 2

    def forward(self, x):
        """Encode a batch-first sequence into one vector per sample."""
        inputs = x if self.input_proj is None else self.input_proj(x)

        seq_out, final_state = self.rnn(inputs)
        if self.rnn_arch == "LSTM":
            # LSTM returns (h_n, c_n); only the hidden state is used.
            final_state = final_state[0]
        summary = final_state.mean(dim=0)  # average over layers

        if not self.use_attn:
            return summary

        energy = self.W(seq_out).tanh()
        weights = self.u(energy).softmax(dim=1)  # normalized over time steps
        context = (seq_out * weights).sum(dim=1)
        return torch.cat([summary, context], dim=1)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a time-first input, then apply dropout.

    Args:
        d_model (int): size of the last (feature) dimension of the input; may be odd
        dropout (float): dropout rate applied after adding the encoding
        max_len (int): number of positions for which encodings are precomputed
    """

    # reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd `d_model` there is one cosine column fewer than sine columns,
        # so trim `div_term` to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return `x` plus the encodings of its first `x.size(0)` positions (dim 0 is time)."""
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
class Transformer(nn.Module):
    """Transformer encoder backbone.

    Positional encodings are added on the raw features, which are then projected to
    `hidden_size` and passed through the encoder stack; the representation of the
    last time step is returned.

    Args:
        input_size (int): input size (# features)
        hidden_size (int): hidden size
        num_layers (int): number of transformer layers
        num_heads (int): number of heads in transformer
        dropout (float): dropout rate

    Attributes:
        output_size (int): equal to `hidden_size`
    """

    def __init__(
        self,
        input_size=16,
        hidden_size=64,
        num_layers=2,
        num_heads=2,
        dropout=0.0,
        **kwargs,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.output_size = hidden_size

        self.input_proj = nn.Linear(input_size, hidden_size)
        self.pe = PositionalEncoding(input_size, dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=hidden_size * 4,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a batch-first input and return the encoder output at the last time step."""
        # The encoder and the positional encoding expect time as the first dim.
        time_major = x.permute(1, 0, 2).contiguous()
        projected = self.input_proj(self.pe(time_major))
        encoded = self.encoder(projected)
        return encoded[-1]
class TRA(nn.Module):
    """Temporal Routing Adaptor (TRA)

    TRA takes historical prediction errors & latent representation as inputs,
    then routes the input sample to a specific predictor for training & inference.

    Args:
        input_size (int): input size (RNN/Transformer's hidden size)
        num_states (int): number of latent states (i.e., trading patterns)
            If `num_states=1`, then TRA falls back to traditional methods
        hidden_size (int): hidden size of the router
        rnn_arch (str): name of the recurrent class on `torch.nn` used as the router
        num_layers (int): number of router layers
        dropout (float): router dropout rate
        tau (float): gumbel softmax temperature
        src_info (str): information for the router, one of LR/TPE/LR_TPE
    """

    def __init__(
        self,
        input_size,
        num_states=1,
        hidden_size=8,
        rnn_arch="GRU",
        num_layers=1,
        dropout=0.0,
        tau=1.0,
        src_info="LR_TPE",
    ):
        super().__init__()

        assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"

        self.num_states = num_states
        self.tau = tau
        self.rnn_arch = rnn_arch
        self.src_info = src_info

        self.predictors = nn.Linear(input_size, num_states)

        # A single predictor needs no routing machinery at all.
        if self.num_states <= 1:
            return

        if "TPE" in src_info:
            router_cls = getattr(nn, rnn_arch)
            self.router = router_cls(
                input_size=num_states,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
                dropout=dropout,
            )
            # LR_TPE concatenates the latent representation to the router state.
            fc_inputs = hidden_size + input_size if "LR" in src_info else hidden_size
        else:
            fc_inputs = input_size
        self.fc = nn.Linear(fc_inputs, num_states)

    def reset_parameters(self):
        """Re-initialize every direct child module."""
        for module in self.children():
            module.reset_parameters()

    def forward(self, hidden, hist_loss):
        """Return `(preds, choice, prob)`; `choice` and `prob` are None when `num_states == 1`."""
        preds = self.predictors(hidden)

        if self.num_states == 1:  # no need for router when having only one prediction
            return preds, None, None

        if "TPE" not in self.src_info:
            features = hidden  # LR
        else:
            features = self.router(hist_loss)[1]  # TPE
            if self.rnn_arch == "LSTM":
                features = features[0]
            features = features.mean(dim=0)
            if "LR" in self.src_info:
                features = torch.cat([hidden, features], dim=-1)  # LR_TPE

        logits = self.fc(features)

        choice = F.gumbel_softmax(logits, dim=-1, tau=self.tau, hard=True)
        prob = torch.softmax(logits / self.tau, dim=-1)

        return preds, choice, prob
def evaluate(pred):
    """Compute MSE, MAE and Spearman IC between the `score` and `label` columns.

    Both columns are first converted to percentile ranks.

    Args:
        pred (pd.DataFrame): frame with at least `score` and `label` columns

    Returns:
        dict: metrics under the keys `MSE`, `MAE` and `IC`
    """
    ranked = pred.rank(pct=True)  # transform into percentiles
    score, label = ranked["score"], ranked["label"]
    error = score - label
    return {
        "MSE": (error**2).mean(),
        "MAE": error.abs().mean(),
        "IC": score.corr(label, method="spearman"),
    }
def shoot_infs(inp_tensor):
    """Replace infinite entries, in place, by the maximum of the remaining tensor.

    Infinite entries (of either sign, as reported by `torch.isinf`) are zeroed
    first, so the maximum is taken over the finite values together with those
    zeros. Works for tensors of any rank.

    Args:
        inp_tensor (torch.Tensor): tensor to clean; modified in place

    Returns:
        torch.Tensor: the same tensor object
    """
    mask_inf = torch.isinf(inp_tensor)
    if mask_inf.any():
        # Zero the infs first so that they cannot win the max below.
        inp_tensor[mask_inf] = 0
        inp_tensor[mask_inf] = torch.max(inp_tensor)
    return inp_tensor
def sinkhorn(Q, n_iters=3, epsilon=0.1):
    """Turn a score matrix into a transport plan by alternating normalization.

    The scores are exponentiated with temperature `epsilon`, infinities are
    clipped with `shoot_infs`, then columns (dim 0) and rows (dim 1) are
    normalized in turn for `n_iters` rounds. No gradient is tracked.

    Args:
        Q (torch.Tensor): score matrix
        n_iters (int): number of normalization rounds
        epsilon (float): temperature; should be adjusted according to the scale of the scores

    Returns:
        torch.Tensor: the normalized matrix (a new tensor, `Q` itself is not modified)
    """
    with torch.no_grad():
        plan = shoot_infs(torch.exp(Q / epsilon))
        for _ in range(n_iters):
            plan /= plan.sum(dim=0, keepdim=True)
            plan /= plan.sum(dim=1, keepdim=True)
    return plan
def loss_fn(pred, label):
    """Mean squared error over the samples whose label is not NaN.

    Args:
        pred (torch.Tensor): predictions, either [sample] or [sample x states]
        label (torch.Tensor): labels, [sample]

    Returns:
        torch.Tensor: the mean over dim 0 of the squared errors
            (one value per state for 2-D predictions)
    """
    valid = ~torch.isnan(label)
    # Give the label a trailing axis so that it broadcasts against every state column.
    target = label[:, None] if pred.ndim == 2 else label
    squared_error = (pred[valid] - target[valid]).pow(2)
    return squared_error.mean(dim=0)
def minmax_norm(x):
    """Min-max normalize `x` along its last dimension.

    Rows whose values are all equal (min == max) are set to 1 instead of being
    divided by (almost) zero.

    Args:
        x (torch.Tensor): input tensor; it is not modified

    Returns:
        torch.Tensor: a new tensor with the same shape as `x`
    """
    xmin = x.min(dim=-1, keepdim=True).values
    xmax = x.max(dim=-1, keepdim=True).values
    # Drop only the reduced axis: a bare `squeeze()` would also collapse any other
    # size-1 axis (e.g. a single-row batch) and misalign the mask with `x`.
    mask = (xmin == xmax).squeeze(-1)
    x = (x - xmin) / (xmax - xmin + EPS)
    x[mask] = 1
    return x
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
    """
    sample-wise transport

    Args:
        all_preds (torch.Tensor): predictions from all predictors, [sample x states]
        label (torch.Tensor): label, [sample]
        choice (torch.Tensor): gumbel softmax choice, [sample x states]
        prob (torch.Tensor): router predicted probability, [sample x states]
        hist_loss (torch.Tensor): history loss matrix, [sample x states]
        count (list): sample counts for each day, empty list for sample-wise transport
        transport_method (str): transportation method
        alpha (float): fusion parameter for calculating transport loss matrix
        training (bool): indicate training or inference

    Returns:
        tuple: `(loss, pred, L, P)` - the scalar loss, the final prediction per sample,
        the normalized loss matrix and the transport plan
    """
    assert all_preds.shape == choice.shape
    assert len(all_preds) == len(label)
    assert transport_method in ["oracle", "router"]

    # Squared error of every predictor; samples with a NaN label keep a zero loss.
    valid = ~torch.isnan(label)
    all_loss = torch.zeros_like(all_preds)
    all_loss[valid] = (all_preds[valid] - label[valid, None]).pow(2)  # [sample x states]

    L = minmax_norm(all_loss.detach())
    cost = L * alpha + minmax_norm(hist_loss) * (1 - alpha)  # add hist loss for transport
    cost = minmax_norm(cost)
    P = sinkhorn(-cost)
    del cost

    if transport_method != "router":
        # oracle: weight predictors (and their losses) by the transport plan
        pred = (all_preds * P).sum(dim=1)
        loss = (all_loss * P).sum(dim=1).mean()
        return loss, pred, L, P

    if training:
        pred = (all_preds * choice).sum(dim=1)  # gumbel softmax
    else:
        pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)]  # argmax
    loss = loss_fn(pred, label)

    return loss, pred, L, P
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
    """
    daily transport

    Args:
        all_preds (torch.Tensor): predictions from all predictors, [sample x states]
        label (torch.Tensor): label, [sample]
        choice (torch.Tensor): gumbel softmax choice, [days x states]
        prob (torch.Tensor): router predicted probability, [days x states]
        hist_loss (torch.Tensor): history loss matrix, [days x states]
        count (list): sample counts for each day, [days]
        transport_method (str): transportation method
        alpha (float): fusion parameter for calculating transport loss matrix
        training (bool): indicate training or inference

    Returns:
        tuple: `(loss, pred, L, P)` - the scalar loss, the final prediction per sample,
        the normalized daily loss matrix and the daily transport plan
    """
    assert len(prob) == len(count)
    assert len(all_preds) == sum(count)
    assert transport_method in ["oracle", "router"]

    # Samples are laid out day after day; precompute the slice covering each day.
    day_slices = []
    offset = 0
    for cnt in count:
        day_slices.append(slice(offset, offset + cnt))
        offset += cnt

    # loss of all predictions, one row per day
    all_loss = torch.stack([loss_fn(all_preds[day], label[day]) for day in day_slices], dim=0)  # [days x states]

    L = minmax_norm(all_loss.detach())
    cost = L * alpha + minmax_norm(hist_loss) * (1 - alpha)  # add hist loss for transport
    cost = minmax_norm(cost)
    P = sinkhorn(-cost)
    del cost

    use_router = transport_method == "router"

    daily_preds = []
    for i, day in enumerate(day_slices):
        day_preds = all_preds[day]
        if not use_router:
            daily_preds.append(day_preds @ P[i])
        elif training:
            daily_preds.append(day_preds @ choice[i])  # gumbel softmax
        else:
            daily_preds.append(day_preds[:, prob[i].argmax(dim=-1)])  # argmax
    pred = torch.cat(daily_preds, dim=0)  # [samples]

    if use_router:
        loss = loss_fn(pred, label)
    else:
        loss = (all_loss * P).sum(dim=1).mean()

    return loss, pred, L, P
def load_state_dict_unsafe(model, state_dict):
    """
    Load state dict to provided model while ignore exceptions.

    Problems are collected instead of raised.

    Args:
        model (nn.Module): module to load into
        state_dict (dict): parameters to load; a shallow copy is used internally

    Returns:
        dict: lists stored under `unexpected_keys`, `missing_keys` and `error_msgs`
    """
    report = {"unexpected_keys": [], "missing_keys": [], "error_msgs": []}

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _load_recursive(module, prefix=""):
        # Metadata is keyed by the module path without the trailing dot.
        local_metadata = metadata.get(prefix[:-1], {}) if metadata is not None else {}
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,
            report["missing_keys"],
            report["unexpected_keys"],
            report["error_msgs"],
        )
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _load_recursive(child, prefix + child_name + ".")

    _load_recursive(model)
    _load_recursive = None  # break the closure's self-reference cycle

    return report
def plot(P):
    """Render a two-panel summary of `P` and return it as an image array.

    The left panel is an area plot of `P`; the right panel is a bar chart counting,
    per column, how many rows have their maximum in that column.

    Args:
        P (pd.DataFrame): frame to plot

    Returns:
        np.ndarray: the rendered PNG as an uint8 array (the pixel values read back
        from the PNG, scaled by 255)
    """
    assert isinstance(P, pd.DataFrame)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    try:
        P.plot.area(ax=axes[0], xlabel="")
        P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
        fig.tight_layout()

        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            img = plt.imread(buf)
    finally:
        # Close this exact figure, even when plotting fails, so that figures do not
        # accumulate and no unrelated "current" figure gets closed instead.
        plt.close(fig)

    return np.uint8(img * 255)