From 8f4d320832d3ed138cd90f4a0bb9df03dfa1fcfd Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Mon, 30 Aug 2021 07:32:04 +0000 Subject: [PATCH] bug fix & use oracle transport pretrain --- .../TRA/workflow_config_tra_Alpha158.yaml | 93 ++--- .../workflow_config_tra_Alpha158_full.yaml | 81 ++-- .../TRA/workflow_config_tra_Alpha360.yaml | 81 ++-- qlib/contrib/model/pytorch_tra.py | 368 ++++++++++++------ 4 files changed, 375 insertions(+), 248 deletions(-) diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml index bf4dcb7d8..09ff8893b 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -6,30 +6,30 @@ market: &market csi300 benchmark: &benchmark SH000300 data_handler_config: &data_handler_config - start_time: 2008-01-01 - end_time: 2020-08-01 - fit_start_time: 2008-01-01 - fit_end_time: 2014-12-31 - instruments: *market - infer_processors: - - class: FilterCol - kwargs: - fields_group: feature - col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", - "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", - "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"] - - class: RobustZScoreNorm - kwargs: - fields_group: feature - clip_outlier: true - - class: Fillna - kwargs: - fields_group: feature - learn_processors: - - class: CSRankNorm - kwargs: - fields_group: label - label: ["Ref($close, -2) / Ref($close, -1) - 1"] + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] num_states: &num_states 3 @@ -37,7 +37,10 @@ memory_mode: &memory_mode sample tra_config: &tra_config num_states: *num_states - hidden_size: 16 + rnn_arch: LSTM + hidden_size: 32 + num_layers: 1 + dropout: 0.0 tau: 1.0 src_info: LR_TPE @@ -50,21 +53,21 @@ model_config: &model_config dropout: 0.0 port_analysis_config: &port_analysis_config - strategy: - class: TopkDropoutStrategy - module_path: qlib.contrib.strategy.strategy - kwargs: - topk: 50 - n_drop: 5 - backtest: - verbose: False - limit_threshold: 0.095 - account: 100000000 - benchmark: *benchmark - deal_price: close - open_cost: 0.0005 - close_cost: 0.0015 - min_cost: 5 + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 task: model: @@ -76,13 +79,13 @@ task: model_type: RNN lr: 1e-3 n_epochs: 100 - max_steps_per_epoch: 100 - early_stop: 10 - smooth_steps: 5 + max_steps_per_epoch: + early_stop: 20 + logdir: output/Alpha158 seed: 0 - logdir: lamb: 1.0 - rho: 1.0 + rho: 0.99 + alpha: 0.5 transport_method: router memory_mode: *memory_mode eval_train: False diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml index 8d3c8e582..dd413b00a 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -6,24 +6,24 @@ market: &market csi300 benchmark: &benchmark SH000300 data_handler_config: &data_handler_config - start_time: 2008-01-01 - end_time: 2020-08-01 - fit_start_time: 2008-01-01 - fit_end_time: 2014-12-31 - instruments: *market - infer_processors: - - class: RobustZScoreNorm - kwargs: - fields_group: feature - clip_outlier: true - - class: Fillna - kwargs: - fields_group: feature - learn_processors: - - class: CSRankNorm - kwargs: - fields_group: label - label: ["Ref($close, -2) / Ref($close, -1) - 1"] + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] num_states: &num_states 3 @@ -31,7 +31,10 @@ memory_mode: &memory_mode sample tra_config: &tra_config num_states: *num_states - hidden_size: 16 + rnn_arch: LSTM + hidden_size: 32 + num_layers: 1 + dropout: 0.0 tau: 1.0 src_info: LR_TPE @@ -44,21 +47,21 @@ model_config: &model_config dropout: 0.2 port_analysis_config: &port_analysis_config - strategy: - class: TopkDropoutStrategy - module_path: qlib.contrib.strategy.strategy - kwargs: - topk: 50 - n_drop: 5 - backtest: - verbose: False - limit_threshold: 0.095 - account: 100000000 - benchmark: *benchmark - deal_price: close - open_cost: 0.0005 - close_cost: 0.0015 - min_cost: 5 + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 task: model: @@ -70,13 +73,13 @@ task: model_type: RNN lr: 1e-3 n_epochs: 100 - max_steps_per_epoch: 100 - early_stop: 10 - smooth_steps: 5 + max_steps_per_epoch: + early_stop: 20 + logdir: output/Alpha158_full seed: 0 - logdir: lamb: 1.0 - rho: 1.0 + rho: 0.99 + alpha: 0.5 transport_method: router memory_mode: *memory_mode eval_train: False diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml index dbdeaf060..84dee5d72 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -6,24 +6,24 @@ market: &market csi300 benchmark: &benchmark SH000300 data_handler_config: &data_handler_config - start_time: 2008-01-01 - end_time: 2020-08-01 - fit_start_time: 2008-01-01 - fit_end_time: 2014-12-31 - instruments: *market - infer_processors: - - class: RobustZScoreNorm - kwargs: - fields_group: feature - clip_outlier: true - - class: Fillna - kwargs: - fields_group: feature - learn_processors: - - class: CSRankNorm - kwargs: - fields_group: label - label: ["Ref($close, -2) / Ref($close, -1) - 1"] + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] num_states: &num_states 3 @@ -31,7 +31,10 @@ memory_mode: &memory_mode sample tra_config: &tra_config num_states: *num_states - hidden_size: 16 + rnn_arch: LSTM + hidden_size: 32 + num_layers: 1 + dropout: 0.0 tau: 1.0 src_info: LR_TPE @@ -44,21 +47,21 @@ model_config: &model_config dropout: 0.0 port_analysis_config: &port_analysis_config - strategy: - class: TopkDropoutStrategy - module_path: qlib.contrib.strategy.strategy - kwargs: - topk: 50 - n_drop: 5 - backtest: - verbose: False - limit_threshold: 0.095 - account: 100000000 - benchmark: *benchmark - deal_price: close - open_cost: 0.0005 - close_cost: 0.0015 - min_cost: 5 + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 task: model: @@ -70,13 +73,13 @@ task: model_type: RNN lr: 1e-3 n_epochs: 100 - max_steps_per_epoch: 100 - early_stop: 10 - smooth_steps: 5 - logdir: + max_steps_per_epoch: + early_stop: 20 + logdir: output/Alpha360 seed: 0 lamb: 1.0 - rho: 1.0 + rho: 0.99 + alpha: 0.5 transport_method: router memory_mode: *memory_mode eval_train: False diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index f6c659533..5a583a965 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import io import os import copy import math @@ -8,6 +9,8 @@ import json import collections import numpy as np import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt import torch import torch.nn as nn @@ -40,10 +43,11 @@ class TRAModel(Model): lr (float): learning rate n_epochs (int): number of total epochs early_stop (int): early stop when performance not improved at this step - smooth_steps (int): number of steps for parameter smoothing + update_freq (int): gradient update frequency max_steps_per_epoch (int): maximum number of steps in one epoch lamb (float): regularization parameter rho (float): exponential decay rate for `lamb` + alpha (float): fusion parameter for calculating transport loss matrix seed (int): random seed logdir (str): local log directory eval_train (bool): whether evaluate train set between epochs @@ -65,16 +69,18 @@ class TRAModel(Model): lr=1e-3, n_epochs=500, early_stop=50, - smooth_steps=5, + update_freq=1, max_steps_per_epoch=None, lamb=0.0, rho=0.99, + alpha=1.0, seed=0, logdir=None, eval_train=False, eval_test=False, pretrain=False, init_state=None, + reset_router=False, freeze_model=False, freeze_predictors=False, transport_method="none", @@ -102,16 +108,18 @@ class TRAModel(Model): self.lr = lr self.n_epochs = n_epochs self.early_stop = early_stop - self.smooth_steps = smooth_steps + self.update_freq = update_freq self.max_steps_per_epoch = max_steps_per_epoch self.lamb = lamb self.rho = rho + self.alpha = alpha self.seed = seed self.logdir = logdir self.eval_train = eval_train self.eval_test = eval_test self.pretrain = pretrain self.init_state = init_state + self.reset_router = reset_router self.freeze_model = freeze_model self.freeze_predictors = freeze_predictors self.transport_method = transport_method @@ -139,20 +147,24 @@ class TRAModel(Model): print(self.tra) if self.init_state: - self.logger.warninging(f"load state dict from `init_state`") + self.logger.warning(f"load state dict from `init_state`") state_dict = torch.load(self.init_state, map_location="cpu") self.model.load_state_dict(state_dict["model"]) - try: - self.tra.load_state_dict(state_dict["tra"]) - except: - self.logger.warninging("cannot load tra model, will skip") + res = load_state_dict_unsafe(self.tra, state_dict["tra"]) + self.logger.warning(str(res)) + + if self.reset_router: + self.logger.warning(f"reset TRA.router parameters") + self.tra.fc.reset_parameters() + self.tra.router.reset_parameters() if self.freeze_model: - self.logger.warninging(f"freeze model parameters") + self.logger.warning(f"freeze model parameters") for param in self.model.parameters(): param.requires_grad_(False) + if self.freeze_predictors: - self.logger.warninging(f"freeze TRA.predictors parameters") + self.logger.warning(f"freeze TRA.predictors parameters") for param in self.tra.predictors.parameters(): param.requires_grad_(False) @@ -169,7 +181,11 @@ class TRAModel(Model): self.model.train() self.tra.train() data_set.train() + self.optimizer.zero_grad() + P_all = [] + prob_all = [] + choice_all = [] max_steps = len(data_set) if self.max_steps_per_epoch is not None: if epoch == 0 and self.max_steps_per_epoch < max_steps: @@ -184,49 +200,76 @@ class TRAModel(Model): if cur_step > max_steps: break - self.global_step += 1 + if not is_pretrain: + self.global_step += 1 data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"] index = batch["daily_index"] if self.use_daily_transport else batch["index"] - hidden = self.model(data) + with torch.set_grad_enabled(not self.freeze_model): + hidden = self.model(data) + all_preds, choice, prob = self.tra(hidden, state) - if not is_pretrain and self.transport_method != "none": + if is_pretrain or self.transport_method != "none": + # NOTE: use oracle transport for pre-training loss, pred, L, P = self.transport_fn( - all_preds, label, choice, prob, count, self.transport_method, training=True + all_preds, + label, + choice, + prob, + state.mean(dim=1), + count, + self.transport_method if not is_pretrain else "oracle", + self.alpha, + training=True, ) data_set.assign_data(index, L) # save loss to memory - lamb = self.lamb * (self.rho ** self.global_step) # regularization decay + if self.use_daily_transport: # only save for daily transport + P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index)) + prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index)) + choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index)) + decay = self.rho ** (self.global_step // 100) # decay every 100 steps + lamb = 0 if is_pretrain else self.lamb * decay reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment - if self._writer is not None: + if self._writer is not None and not is_pretrain: self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step) self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step) self._writer.add_scalar("training/lamb", lamb, self.global_step) - prob_mean = prob.mean(axis=0).detach() - self._writer.add_scalar("training/prob_max", prob_mean.max(), self.global_step) - self._writer.add_scalar("training/prob_min", prob_mean.min(), self.global_step) - P_mean = P.mean(axis=0).detach() - self._writer.add_scalar("training/P_max", P_mean.max(), self.global_step) - self._writer.add_scalar("training/P_min", P_mean.min(), self.global_step) + if not self.use_daily_transport: + P_mean = P.mean(axis=0).detach() + self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step) loss = loss - lamb * reg else: pred = all_preds.mean(dim=1) loss = loss_fn(pred, label) - loss.backward() - self.optimizer.step() - self.optimizer.zero_grad() + (loss / self.update_freq).backward() + if cur_step % self.update_freq == 0: + self.optimizer.step() + self.optimizer.zero_grad() - if self._writer is not None: + if self._writer is not None and not is_pretrain: self._writer.add_scalar("training/total_loss", loss.item(), self.global_step) total_loss += loss.item() total_count += 1 + if self.use_daily_transport and len(P_all): + P_all = pd.concat(P_all, axis=0) + prob_all = pd.concat(prob_all, axis=0) + choice_all = pd.concat(choice_all, axis=0) + P_all.index = data_set.restore_daily_index(P_all.index) + prob_all.index = P_all.index + choice_all.index = P_all.index + if not is_pretrain: + self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC") + self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC") + self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC") + total_loss /= total_count - if self._writer is not None: + if self._writer is not None and not is_pretrain: self._writer.add_scalar("training/loss", total_loss, epoch) return total_loss @@ -239,6 +282,7 @@ class TRAModel(Model): preds = [] probs = [] + P_all = [] metrics = [] for batch in tqdm(data_set): data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"] @@ -248,11 +292,21 @@ class TRAModel(Model): hidden = self.model(data) all_preds, choice, prob = self.tra(hidden, state) - if not is_pretrain and self.transport_method != "none": + if is_pretrain or self.transport_method != "none": loss, pred, L, P = self.transport_fn( - all_preds, label, choice, prob, count, self.transport_method, training=False + all_preds, + label, + choice, + prob, + state.mean(dim=1), + count, + self.transport_method if not is_pretrain else "oracle", + self.alpha, + training=False, ) data_set.assign_data(index, L) # save loss to memory + if P is not None and return_pred: + P_all.append(pd.DataFrame(P.cpu().numpy(), index=index)) else: pred = all_preds.mean(dim=1) @@ -276,7 +330,7 @@ class TRAModel(Model): "ICIR": metrics.IC.mean() / metrics.IC.std(), } - if self._writer is not None and epoch >= 0: + if self._writer is not None and epoch >= 0 and not is_pretrain: for key, value in metrics.items(): self._writer.add_scalar(prefix + "/" + key, value, epoch) @@ -285,6 +339,7 @@ class TRAModel(Model): preds.index = data_set.restore_index(preds.index) preds.index = preds.index.swaplevel() preds.sort_index(inplace=True) + if probs: probs = pd.concat(probs, axis=0) if self.use_daily_transport: @@ -294,9 +349,18 @@ class TRAModel(Model): probs.index = probs.index.swaplevel() probs.sort_index(inplace=True) - return metrics, preds, probs + if len(P_all): + P_all = pd.concat(P_all, axis=0) + if self.use_daily_transport: + P_all.index = data_set.restore_daily_index(P_all.index) + else: + P_all.index = data_set.restore_index(P_all.index) + P_all.index = P_all.index.swaplevel() + P_all.sort_index(inplace=True) - def _fit(self, train_set, valid_set, test_set, evals_result, start_epoch=0, is_pretrain=True): + return metrics, preds, probs, P_all + + def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True): best_score = -1 best_epoch = 0 @@ -305,29 +369,18 @@ class TRAModel(Model): "model": copy.deepcopy(self.model.state_dict()), "tra": copy.deepcopy(self.tra.state_dict()), } - params_list = { - "model": collections.deque(maxlen=self.smooth_steps), - "tra": collections.deque(maxlen=self.smooth_steps), - } - # train - if not is_pretrain and self.transport_method == "router": + if not is_pretrain and self.transport_method != "none": self.logger.info("init memory...") self.test_epoch(-1, train_set) - for epoch in range(start_epoch, start_epoch + self.n_epochs): + for epoch in range(self.n_epochs): self.logger.info("Epoch %d:", epoch) self.logger.info("training...") self.train_epoch(epoch, train_set, is_pretrain=is_pretrain) self.logger.info("evaluating...") - # average params for inference - params_list["model"].append(copy.deepcopy(self.model.state_dict())) - params_list["tra"].append(copy.deepcopy(self.tra.state_dict())) - self.model.load_state_dict(average_params(params_list["model"])) - self.tra.load_state_dict(average_params(params_list["tra"])) - # NOTE: during evaluating, the whole memory will be refreshed if not is_pretrain and (self.transport_method == "router" or self.eval_train): train_set.clear_memory() # NOTE: clear the shared memory @@ -360,15 +413,11 @@ class TRAModel(Model): self.logger.info("early stop @ %s" % epoch) break - # restore parameters - self.model.load_state_dict(params_list["model"][-1]) - self.tra.load_state_dict(params_list["tra"][-1]) - self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) self.model.load_state_dict(best_params["model"]) self.tra.load_state_dict(best_params["tra"]) - return best_score, epoch + return best_score def fit(self, dataset, evals_result=dict()): @@ -383,29 +432,27 @@ class TRAModel(Model): evals_result["valid"] = [] evals_result["test"] = [] - epoch = 0 if self.pretrain: - self.logger.info("pretraining...") - self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr) - _, epoch = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True) - - self.logger.info("reset TRA") - self.tra.reset_parameters() # reset both router and predictors + self.optimizer = optim.Adam( + list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr + ) + self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True) + # reset optimizer self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr) self.logger.info("training...") - best_score, _ = self._fit(train_set, valid_set, test_set, evals_result, start_epoch=epoch, is_pretrain=False) + best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False) self.logger.info("inference") - train_metrics, train_preds, train_probs = self.test_epoch(-1, train_set, return_pred=True) + train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True) self.logger.info("train metrics: %s" % train_metrics) - valid_metrics, valid_preds, valid_probs = self.test_epoch(-1, valid_set, return_pred=True) + valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True) self.logger.info("valid metrics: %s" % valid_metrics) - test_metrics, test_preds, test_probs = self.test_epoch(-1, test_set, return_pred=True) + test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True) self.logger.info("test metrics: %s" % test_metrics) if self.logdir: @@ -426,6 +473,11 @@ class TRAModel(Model): valid_probs.to_pickle(self.logdir + "/valid_prob.pkl") test_probs.to_pickle(self.logdir + "/test_prob.pkl") + if len(train_P): + train_P.to_pickle(self.logdir + "/train_P.pkl") + valid_P.to_pickle(self.logdir + "/valid_P.pkl") + test_P.to_pickle(self.logdir + "/test_P.pkl") + info = { "config": { "model_config": self.model_config, @@ -434,10 +486,10 @@ class TRAModel(Model): "lr": self.lr, "n_epochs": self.n_epochs, "early_stop": self.early_stop, - "smooth_steps": self.smooth_steps, "max_steps_per_epoch": self.max_steps_per_epoch, "lamb": self.lamb, "rho": self.rho, + "alpha": self.alpha, "seed": self.seed, "logdir": self.logdir, "pretrain": self.pretrain, @@ -460,7 +512,7 @@ class TRAModel(Model): test_set = dataset.prepare(segment) - metrics, preds, probs = self.test_epoch(-1, test_set, return_pred=True) + metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True) self.logger.info("test metrics: %s" % metrics) return preds @@ -476,7 +528,7 @@ class RNN(nn.Module): num_layers (int): number of hidden layers rnn_arch (str): rnn architecture use_attn (bool): whether use attention layer. - we use concat attention as https://github.com/fulifeng/Adv-AGRU/ + we use concat attention as https://github.com/fulifeng/Adv-ALSTM/ dropout (float): dropout rate """ @@ -498,10 +550,14 @@ class RNN(nn.Module): self.rnn_arch = rnn_arch self.use_attn = use_attn - self.input_proj = nn.Linear(input_size, hidden_size) + if hidden_size < input_size: + # compression + self.input_proj = nn.Linear(input_size, hidden_size) + else: + self.input_proj = None self.rnn = getattr(nn, rnn_arch)( - input_size=hidden_size, + input_size=min(input_size, hidden_size), hidden_size=hidden_size, num_layers=num_layers, batch_first=True, @@ -517,7 +573,8 @@ class RNN(nn.Module): def forward(self, x): - x = self.input_proj(x) + if self.input_proj is not None: + x = self.input_proj(x) rnn_out, last_out = self.rnn(x) if self.rnn_arch == "LSTM": @@ -617,24 +674,36 @@ class TRA(nn.Module): src_info (str): information for the router """ - def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"): + def __init__( + self, + input_size, + num_states=1, + hidden_size=8, + rnn_arch="GRU", + num_layers=1, + dropout=0.0, + tau=1.0, + src_info="LR_TPE", + ): super().__init__() assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`" self.num_states = num_states self.tau = tau + self.rnn_arch = rnn_arch self.src_info = src_info self.predictors = nn.Linear(input_size, num_states) if self.num_states > 1: if "TPE" in src_info: - self.router = nn.GRU( + self.router = getattr(nn, rnn_arch)( input_size=num_states, hidden_size=hidden_size, - num_layers=1, + num_layers=num_layers, batch_first=True, + dropout=dropout, ) self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states) else: @@ -652,7 +721,10 @@ class TRA(nn.Module): return preds, None, None if "TPE" in self.src_info: - out = self.router(hist_loss)[0][:, -1] # TPE + out = self.router(hist_loss)[1] # TPE + if self.rnn_arch == "LSTM": + out = out[0] + out = out.mean(dim=0) if "LR" in self.src_info: out = torch.cat([hidden, out], dim=-1) # LR_TPE else: @@ -677,26 +749,6 @@ def evaluate(pred): return {"MSE": MSE, "MAE": MAE, "IC": IC} -def average_params(params_list): - assert isinstance(params_list, (tuple, list, collections.deque)) - n = len(params_list) - if n == 1: - return params_list[0] - new_params = collections.OrderedDict() - keys = None - for i, params in enumerate(params_list): - if keys is None: - keys = params.keys() - for k, v in params.items(): - if k not in keys: - raise ValueError("the %d-th model has different params" % i) - if k not in new_params: - new_params[k] = v / n - else: - new_params[k] += v / n - return new_params - - def shoot_infs(inp_tensor): """Replaces inf by maximum of tensor""" mask_inf = torch.isinf(inp_tensor) @@ -716,7 +768,7 @@ def shoot_infs(inp_tensor): return inp_tensor -def sinkhorn(Q, n_iters=3, epsilon=0.01): +def sinkhorn(Q, n_iters=3, epsilon=0.1): # epsilon should be adjusted according to logits value's scale with torch.no_grad(): Q = torch.exp(Q / epsilon) @@ -734,7 +786,16 @@ def loss_fn(pred, label): return (pred[mask] - label[mask]).pow(2).mean(dim=0) -def transport_sample(all_preds, label, choice, prob, count, transport_method, training=False): +def minmax_norm(x): + xmin = x.min(dim=-1, keepdim=True).values + xmax = x.max(dim=-1, keepdim=True).values + mask = (xmin == xmax).squeeze() + x = (x - xmin) / (xmax - xmin + 1e-12) + x[mask] = 1 + return x + + +def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False): """ sample-wise transport @@ -743,39 +804,43 @@ def transport_sample(all_preds, label, choice, prob, count, transport_method, tr label (torch.Tensor): label, [sample] choice (torch.Tensor): gumbel softmax choice, [sample x states] prob (torch.Tensor): router predicted probility, [sample x states] + hist_loss (torch.Tensor): history loss matrix, [sample x states] count (list): sample counts for each day, empty list for sample-wise transport transport_method (str): transportation method + alpha (float): fusion parameter for calculating transport loss matrix training (bool): indicate training or inference """ assert all_preds.shape == choice.shape assert len(all_preds) == len(label) assert transport_method in ["oracle", "router"] - all_loss = (all_preds - label[:, None]).pow(2) # [sample x states] - all_loss[torch.isnan(label)] = 0.0 + all_loss = torch.zeros_like(all_preds) + mask = ~torch.isnan(label) + all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states] + + L = minmax_norm(all_loss.detach()) + Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport + Lh = minmax_norm(Lh) + P = sinkhorn(-Lh) + del Lh if transport_method == "router": - if training: # router training + if training: pred = (all_preds * choice).sum(dim=1) # gumbel softmax - else: # router inference + else: pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax - elif not training: # oracle inference: always choose the model with the smallest loss - pred = all_preds[range(len(all_preds)), all_loss.argmin(dim=-1)] - else: # oracle training: pred is not needed - pred = None + else: + pred = (all_preds * P).sum(dim=1) - L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize - P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training - - if pred is not None: # router training/inference & oracle inference loss + if transport_method == "router": loss = loss_fn(pred, label) - else: # oracle training loss + else: loss = (all_loss * P).sum(dim=1).mean() return loss, pred, L, P -def transport_daily(all_preds, label, choice, prob, count, transport_method, training=False): +def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False): """ daily transport @@ -784,8 +849,10 @@ def transport_daily(all_preds, label, choice, prob, count, transport_method, tra label (torch.Tensor): label, [sample] choice (torch.Tensor): gumbel softmax choice, [days x states] prob (torch.Tensor): router predicted probility, [days x states] + hist_loss (torch.Tensor): history loss matrix, [days x states] count (list): sample counts for each day, [days] transport_method (str): transportation method + alpha (float): fusion parameter for calculating transport loss matrix training (bool): indicate training or inference """ assert len(prob) == len(count) @@ -793,34 +860,85 @@ def transport_daily(all_preds, label, choice, prob, count, transport_method, tra assert transport_method in ["oracle", "router"] all_loss = [] # loss of all predictions - pred = [] # final predictions start = 0 for i, cnt in enumerate(count): slc = slice(start, start + cnt) # samples from the i-th day start += cnt tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day all_loss.append(tloss) - if transport_method == "router": - if training: # router training - tpred = all_preds[slc] @ choice[i] # gumbel softmax - else: # router inference - tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax - elif not training: # oracle inference: always choose the model with the smallest loss - tpred = all_preds[slc][:, tloss.argmin(dim=-1)] - else: # oracle training: pred is not needed - tpred = None - if tpred is not None: - pred.append(tpred) all_loss = torch.stack(all_loss, dim=0) # [days x states] - if pred: - pred = torch.cat(pred, dim=0) # [samples] - L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize - P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training + L = minmax_norm(all_loss.detach()) + Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport + Lh = minmax_norm(Lh) + P = sinkhorn(-Lh) + del Lh - if len(pred): # router training/inference & oracle inference loss + pred = [] + start = 0 + for i, cnt in enumerate(count): + slc = slice(start, start + cnt) # samples from the i-th day + start += cnt + if transport_method == "router": + if training: + tpred = all_preds[slc] @ choice[i] # gumbel softmax + else: + tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax + else: + tpred = all_preds[slc] @ P[i] + pred.append(tpred) + pred = torch.cat(pred, dim=0) # [samples] + + if transport_method == "router": loss = loss_fn(pred, label) - else: # oracle training loss + else: loss = (all_loss * P).sum(dim=1).mean() return loss, pred, L, P + + +def load_state_dict_unsafe(model, state_dict): + """ + Load state dict to provided model while ignore exceptions. + """ + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model) + load = None # break load->load reference cycle + + return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs} + + +def plot(P): + assert isinstance(P, pd.DataFrame) + + fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + P.plot.area(ax=axes[0], xlabel="") + P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="") + plt.tight_layout() + + with io.BytesIO() as buf: + plt.savefig(buf, format="png") + buf.seek(0) + img = plt.imread(buf) + plt.close() + + return np.uint8(img * 255)