mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
bug fix & use oracle transport pretrain
This commit is contained in:
@@ -6,30 +6,30 @@ market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
@@ -37,7 +37,10 @@ memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
@@ -50,21 +53,21 @@ model_config: &model_config
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
@@ -76,13 +79,13 @@ task:
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 10
|
||||
smooth_steps: 5
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158
|
||||
seed: 0
|
||||
logdir:
|
||||
lamb: 1.0
|
||||
rho: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
|
||||
@@ -6,24 +6,24 @@ market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
@@ -31,7 +31,10 @@ memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
@@ -44,21 +47,21 @@ model_config: &model_config
|
||||
dropout: 0.2
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
@@ -70,13 +73,13 @@ task:
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 10
|
||||
smooth_steps: 5
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha158_full
|
||||
seed: 0
|
||||
logdir:
|
||||
lamb: 1.0
|
||||
rho: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
|
||||
@@ -6,24 +6,24 @@ market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
num_states: &num_states 3
|
||||
|
||||
@@ -31,7 +31,10 @@ memory_mode: &memory_mode sample
|
||||
|
||||
tra_config: &tra_config
|
||||
num_states: *num_states
|
||||
hidden_size: 16
|
||||
rnn_arch: LSTM
|
||||
hidden_size: 32
|
||||
num_layers: 1
|
||||
dropout: 0.0
|
||||
tau: 1.0
|
||||
src_info: LR_TPE
|
||||
|
||||
@@ -44,21 +47,21 @@ model_config: &model_config
|
||||
dropout: 0.0
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
task:
|
||||
model:
|
||||
@@ -70,13 +73,13 @@ task:
|
||||
model_type: RNN
|
||||
lr: 1e-3
|
||||
n_epochs: 100
|
||||
max_steps_per_epoch: 100
|
||||
early_stop: 10
|
||||
smooth_steps: 5
|
||||
logdir:
|
||||
max_steps_per_epoch:
|
||||
early_stop: 20
|
||||
logdir: output/Alpha360
|
||||
seed: 0
|
||||
lamb: 1.0
|
||||
rho: 1.0
|
||||
rho: 0.99
|
||||
alpha: 0.5
|
||||
transport_method: router
|
||||
memory_mode: *memory_mode
|
||||
eval_train: False
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import copy
|
||||
import math
|
||||
@@ -8,6 +9,8 @@ import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -40,10 +43,11 @@ class TRAModel(Model):
|
||||
lr (float): learning rate
|
||||
n_epochs (int): number of total epochs
|
||||
early_stop (int): early stop when performance not improved at this step
|
||||
smooth_steps (int): number of steps for parameter smoothing
|
||||
update_freq (int): gradient update frequency
|
||||
max_steps_per_epoch (int): maximum number of steps in one epoch
|
||||
lamb (float): regularization parameter
|
||||
rho (float): exponential decay rate for `lamb`
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
seed (int): random seed
|
||||
logdir (str): local log directory
|
||||
eval_train (bool): whether evaluate train set between epochs
|
||||
@@ -65,16 +69,18 @@ class TRAModel(Model):
|
||||
lr=1e-3,
|
||||
n_epochs=500,
|
||||
early_stop=50,
|
||||
smooth_steps=5,
|
||||
update_freq=1,
|
||||
max_steps_per_epoch=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
alpha=1.0,
|
||||
seed=0,
|
||||
logdir=None,
|
||||
eval_train=False,
|
||||
eval_test=False,
|
||||
pretrain=False,
|
||||
init_state=None,
|
||||
reset_router=False,
|
||||
freeze_model=False,
|
||||
freeze_predictors=False,
|
||||
transport_method="none",
|
||||
@@ -102,16 +108,18 @@ class TRAModel(Model):
|
||||
self.lr = lr
|
||||
self.n_epochs = n_epochs
|
||||
self.early_stop = early_stop
|
||||
self.smooth_steps = smooth_steps
|
||||
self.update_freq = update_freq
|
||||
self.max_steps_per_epoch = max_steps_per_epoch
|
||||
self.lamb = lamb
|
||||
self.rho = rho
|
||||
self.alpha = alpha
|
||||
self.seed = seed
|
||||
self.logdir = logdir
|
||||
self.eval_train = eval_train
|
||||
self.eval_test = eval_test
|
||||
self.pretrain = pretrain
|
||||
self.init_state = init_state
|
||||
self.reset_router = reset_router
|
||||
self.freeze_model = freeze_model
|
||||
self.freeze_predictors = freeze_predictors
|
||||
self.transport_method = transport_method
|
||||
@@ -139,20 +147,24 @@ class TRAModel(Model):
|
||||
print(self.tra)
|
||||
|
||||
if self.init_state:
|
||||
self.logger.warninging(f"load state dict from `init_state`")
|
||||
self.logger.warning(f"load state dict from `init_state`")
|
||||
state_dict = torch.load(self.init_state, map_location="cpu")
|
||||
self.model.load_state_dict(state_dict["model"])
|
||||
try:
|
||||
self.tra.load_state_dict(state_dict["tra"])
|
||||
except:
|
||||
self.logger.warninging("cannot load tra model, will skip")
|
||||
res = load_state_dict_unsafe(self.tra, state_dict["tra"])
|
||||
self.logger.warning(str(res))
|
||||
|
||||
if self.reset_router:
|
||||
self.logger.warning(f"reset TRA.router parameters")
|
||||
self.tra.fc.reset_parameters()
|
||||
self.tra.router.reset_parameters()
|
||||
|
||||
if self.freeze_model:
|
||||
self.logger.warninging(f"freeze model parameters")
|
||||
self.logger.warning(f"freeze model parameters")
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
if self.freeze_predictors:
|
||||
self.logger.warninging(f"freeze TRA.predictors parameters")
|
||||
self.logger.warning(f"freeze TRA.predictors parameters")
|
||||
for param in self.tra.predictors.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
@@ -169,7 +181,11 @@ class TRAModel(Model):
|
||||
self.model.train()
|
||||
self.tra.train()
|
||||
data_set.train()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
P_all = []
|
||||
prob_all = []
|
||||
choice_all = []
|
||||
max_steps = len(data_set)
|
||||
if self.max_steps_per_epoch is not None:
|
||||
if epoch == 0 and self.max_steps_per_epoch < max_steps:
|
||||
@@ -184,49 +200,76 @@ class TRAModel(Model):
|
||||
if cur_step > max_steps:
|
||||
break
|
||||
|
||||
self.global_step += 1
|
||||
if not is_pretrain:
|
||||
self.global_step += 1
|
||||
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
index = batch["daily_index"] if self.use_daily_transport else batch["index"]
|
||||
|
||||
hidden = self.model(data)
|
||||
with torch.set_grad_enabled(not self.freeze_model):
|
||||
hidden = self.model(data)
|
||||
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
# NOTE: use oracle transport for pre-training
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds, label, choice, prob, count, self.transport_method, training=True
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=True,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
lamb = self.lamb * (self.rho ** self.global_step) # regularization decay
|
||||
if self.use_daily_transport: # only save for daily transport
|
||||
P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index))
|
||||
prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index))
|
||||
choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index))
|
||||
decay = self.rho ** (self.global_step // 100) # decay every 100 steps
|
||||
lamb = 0 if is_pretrain else self.lamb * decay
|
||||
reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict OT assignment
|
||||
if self._writer is not None:
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step)
|
||||
self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step)
|
||||
self._writer.add_scalar("training/lamb", lamb, self.global_step)
|
||||
prob_mean = prob.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/prob_max", prob_mean.max(), self.global_step)
|
||||
self._writer.add_scalar("training/prob_min", prob_mean.min(), self.global_step)
|
||||
P_mean = P.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/P_max", P_mean.max(), self.global_step)
|
||||
self._writer.add_scalar("training/P_min", P_mean.min(), self.global_step)
|
||||
if not self.use_daily_transport:
|
||||
P_mean = P.mean(axis=0).detach()
|
||||
self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step)
|
||||
loss = loss - lamb * reg
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
loss = loss_fn(pred, label)
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
(loss / self.update_freq).backward()
|
||||
if cur_step % self.update_freq == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self._writer is not None:
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/total_loss", loss.item(), self.global_step)
|
||||
|
||||
total_loss += loss.item()
|
||||
total_count += 1
|
||||
|
||||
if self.use_daily_transport and len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
prob_all = pd.concat(prob_all, axis=0)
|
||||
choice_all = pd.concat(choice_all, axis=0)
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
prob_all.index = P_all.index
|
||||
choice_all.index = P_all.index
|
||||
if not is_pretrain:
|
||||
self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC")
|
||||
self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC")
|
||||
|
||||
total_loss /= total_count
|
||||
|
||||
if self._writer is not None:
|
||||
if self._writer is not None and not is_pretrain:
|
||||
self._writer.add_scalar("training/loss", total_loss, epoch)
|
||||
|
||||
return total_loss
|
||||
@@ -239,6 +282,7 @@ class TRAModel(Model):
|
||||
|
||||
preds = []
|
||||
probs = []
|
||||
P_all = []
|
||||
metrics = []
|
||||
for batch in tqdm(data_set):
|
||||
data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"]
|
||||
@@ -248,11 +292,21 @@ class TRAModel(Model):
|
||||
hidden = self.model(data)
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
if is_pretrain or self.transport_method != "none":
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds, label, choice, prob, count, self.transport_method, training=False
|
||||
all_preds,
|
||||
label,
|
||||
choice,
|
||||
prob,
|
||||
state.mean(dim=1),
|
||||
count,
|
||||
self.transport_method if not is_pretrain else "oracle",
|
||||
self.alpha,
|
||||
training=False,
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
if P is not None and return_pred:
|
||||
P_all.append(pd.DataFrame(P.cpu().numpy(), index=index))
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
|
||||
@@ -276,7 +330,7 @@ class TRAModel(Model):
|
||||
"ICIR": metrics.IC.mean() / metrics.IC.std(),
|
||||
}
|
||||
|
||||
if self._writer is not None and epoch >= 0:
|
||||
if self._writer is not None and epoch >= 0 and not is_pretrain:
|
||||
for key, value in metrics.items():
|
||||
self._writer.add_scalar(prefix + "/" + key, value, epoch)
|
||||
|
||||
@@ -285,6 +339,7 @@ class TRAModel(Model):
|
||||
preds.index = data_set.restore_index(preds.index)
|
||||
preds.index = preds.index.swaplevel()
|
||||
preds.sort_index(inplace=True)
|
||||
|
||||
if probs:
|
||||
probs = pd.concat(probs, axis=0)
|
||||
if self.use_daily_transport:
|
||||
@@ -294,9 +349,18 @@ class TRAModel(Model):
|
||||
probs.index = probs.index.swaplevel()
|
||||
probs.sort_index(inplace=True)
|
||||
|
||||
return metrics, preds, probs
|
||||
if len(P_all):
|
||||
P_all = pd.concat(P_all, axis=0)
|
||||
if self.use_daily_transport:
|
||||
P_all.index = data_set.restore_daily_index(P_all.index)
|
||||
else:
|
||||
P_all.index = data_set.restore_index(P_all.index)
|
||||
P_all.index = P_all.index.swaplevel()
|
||||
P_all.sort_index(inplace=True)
|
||||
|
||||
def _fit(self, train_set, valid_set, test_set, evals_result, start_epoch=0, is_pretrain=True):
|
||||
return metrics, preds, probs, P_all
|
||||
|
||||
def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True):
|
||||
|
||||
best_score = -1
|
||||
best_epoch = 0
|
||||
@@ -305,29 +369,18 @@ class TRAModel(Model):
|
||||
"model": copy.deepcopy(self.model.state_dict()),
|
||||
"tra": copy.deepcopy(self.tra.state_dict()),
|
||||
}
|
||||
params_list = {
|
||||
"model": collections.deque(maxlen=self.smooth_steps),
|
||||
"tra": collections.deque(maxlen=self.smooth_steps),
|
||||
}
|
||||
|
||||
# train
|
||||
if not is_pretrain and self.transport_method == "router":
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
self.logger.info("init memory...")
|
||||
self.test_epoch(-1, train_set)
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + self.n_epochs):
|
||||
for epoch in range(self.n_epochs):
|
||||
self.logger.info("Epoch %d:", epoch)
|
||||
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(epoch, train_set, is_pretrain=is_pretrain)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
# average params for inference
|
||||
params_list["model"].append(copy.deepcopy(self.model.state_dict()))
|
||||
params_list["tra"].append(copy.deepcopy(self.tra.state_dict()))
|
||||
self.model.load_state_dict(average_params(params_list["model"]))
|
||||
self.tra.load_state_dict(average_params(params_list["tra"]))
|
||||
|
||||
# NOTE: during evaluating, the whole memory will be refreshed
|
||||
if not is_pretrain and (self.transport_method == "router" or self.eval_train):
|
||||
train_set.clear_memory() # NOTE: clear the shared memory
|
||||
@@ -360,15 +413,11 @@ class TRAModel(Model):
|
||||
self.logger.info("early stop @ %s" % epoch)
|
||||
break
|
||||
|
||||
# restore parameters
|
||||
self.model.load_state_dict(params_list["model"][-1])
|
||||
self.tra.load_state_dict(params_list["tra"][-1])
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_params["model"])
|
||||
self.tra.load_state_dict(best_params["tra"])
|
||||
|
||||
return best_score, epoch
|
||||
return best_score
|
||||
|
||||
def fit(self, dataset, evals_result=dict()):
|
||||
|
||||
@@ -383,29 +432,27 @@ class TRAModel(Model):
|
||||
evals_result["valid"] = []
|
||||
evals_result["test"] = []
|
||||
|
||||
epoch = 0
|
||||
if self.pretrain:
|
||||
|
||||
self.logger.info("pretraining...")
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
_, epoch = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
|
||||
|
||||
self.logger.info("reset TRA")
|
||||
self.tra.reset_parameters() # reset both router and predictors
|
||||
self.optimizer = optim.Adam(
|
||||
list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr
|
||||
)
|
||||
self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True)
|
||||
|
||||
# reset optimizer
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
self.logger.info("training...")
|
||||
best_score, _ = self._fit(train_set, valid_set, test_set, evals_result, start_epoch=epoch, is_pretrain=False)
|
||||
best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False)
|
||||
|
||||
self.logger.info("inference")
|
||||
train_metrics, train_preds, train_probs = self.test_epoch(-1, train_set, return_pred=True)
|
||||
train_metrics, train_preds, train_probs, train_P = self.test_epoch(-1, train_set, return_pred=True)
|
||||
self.logger.info("train metrics: %s" % train_metrics)
|
||||
|
||||
valid_metrics, valid_preds, valid_probs = self.test_epoch(-1, valid_set, return_pred=True)
|
||||
valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True)
|
||||
self.logger.info("valid metrics: %s" % valid_metrics)
|
||||
|
||||
test_metrics, test_preds, test_probs = self.test_epoch(-1, test_set, return_pred=True)
|
||||
test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % test_metrics)
|
||||
|
||||
if self.logdir:
|
||||
@@ -426,6 +473,11 @@ class TRAModel(Model):
|
||||
valid_probs.to_pickle(self.logdir + "/valid_prob.pkl")
|
||||
test_probs.to_pickle(self.logdir + "/test_prob.pkl")
|
||||
|
||||
if len(train_P):
|
||||
train_P.to_pickle(self.logdir + "/train_P.pkl")
|
||||
valid_P.to_pickle(self.logdir + "/valid_P.pkl")
|
||||
test_P.to_pickle(self.logdir + "/test_P.pkl")
|
||||
|
||||
info = {
|
||||
"config": {
|
||||
"model_config": self.model_config,
|
||||
@@ -434,10 +486,10 @@ class TRAModel(Model):
|
||||
"lr": self.lr,
|
||||
"n_epochs": self.n_epochs,
|
||||
"early_stop": self.early_stop,
|
||||
"smooth_steps": self.smooth_steps,
|
||||
"max_steps_per_epoch": self.max_steps_per_epoch,
|
||||
"lamb": self.lamb,
|
||||
"rho": self.rho,
|
||||
"alpha": self.alpha,
|
||||
"seed": self.seed,
|
||||
"logdir": self.logdir,
|
||||
"pretrain": self.pretrain,
|
||||
@@ -460,7 +512,7 @@ class TRAModel(Model):
|
||||
|
||||
test_set = dataset.prepare(segment)
|
||||
|
||||
metrics, preds, probs = self.test_epoch(-1, test_set, return_pred=True)
|
||||
metrics, preds, _, _ = self.test_epoch(-1, test_set, return_pred=True)
|
||||
self.logger.info("test metrics: %s" % metrics)
|
||||
|
||||
return preds
|
||||
@@ -476,7 +528,7 @@ class RNN(nn.Module):
|
||||
num_layers (int): number of hidden layers
|
||||
rnn_arch (str): rnn architecture
|
||||
use_attn (bool): whether use attention layer.
|
||||
we use concat attention as https://github.com/fulifeng/Adv-AGRU/
|
||||
we use concat attention as https://github.com/fulifeng/Adv-ALSTM/
|
||||
dropout (float): dropout rate
|
||||
"""
|
||||
|
||||
@@ -498,10 +550,14 @@ class RNN(nn.Module):
|
||||
self.rnn_arch = rnn_arch
|
||||
self.use_attn = use_attn
|
||||
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
if hidden_size < input_size:
|
||||
# compression
|
||||
self.input_proj = nn.Linear(input_size, hidden_size)
|
||||
else:
|
||||
self.input_proj = None
|
||||
|
||||
self.rnn = getattr(nn, rnn_arch)(
|
||||
input_size=hidden_size,
|
||||
input_size=min(input_size, hidden_size),
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
@@ -517,7 +573,8 @@ class RNN(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.input_proj(x)
|
||||
if self.input_proj is not None:
|
||||
x = self.input_proj(x)
|
||||
|
||||
rnn_out, last_out = self.rnn(x)
|
||||
if self.rnn_arch == "LSTM":
|
||||
@@ -617,24 +674,36 @@ class TRA(nn.Module):
|
||||
src_info (str): information for the router
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
num_states=1,
|
||||
hidden_size=8,
|
||||
rnn_arch="GRU",
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
tau=1.0,
|
||||
src_info="LR_TPE",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
self.rnn_arch = rnn_arch
|
||||
self.src_info = src_info
|
||||
|
||||
self.predictors = nn.Linear(input_size, num_states)
|
||||
|
||||
if self.num_states > 1:
|
||||
if "TPE" in src_info:
|
||||
self.router = nn.GRU(
|
||||
self.router = getattr(nn, rnn_arch)(
|
||||
input_size=num_states,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states)
|
||||
else:
|
||||
@@ -652,7 +721,10 @@ class TRA(nn.Module):
|
||||
return preds, None, None
|
||||
|
||||
if "TPE" in self.src_info:
|
||||
out = self.router(hist_loss)[0][:, -1] # TPE
|
||||
out = self.router(hist_loss)[1] # TPE
|
||||
if self.rnn_arch == "LSTM":
|
||||
out = out[0]
|
||||
out = out.mean(dim=0)
|
||||
if "LR" in self.src_info:
|
||||
out = torch.cat([hidden, out], dim=-1) # LR_TPE
|
||||
else:
|
||||
@@ -677,26 +749,6 @@ def evaluate(pred):
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
|
||||
def average_params(params_list):
|
||||
assert isinstance(params_list, (tuple, list, collections.deque))
|
||||
n = len(params_list)
|
||||
if n == 1:
|
||||
return params_list[0]
|
||||
new_params = collections.OrderedDict()
|
||||
keys = None
|
||||
for i, params in enumerate(params_list):
|
||||
if keys is None:
|
||||
keys = params.keys()
|
||||
for k, v in params.items():
|
||||
if k not in keys:
|
||||
raise ValueError("the %d-th model has different params" % i)
|
||||
if k not in new_params:
|
||||
new_params[k] = v / n
|
||||
else:
|
||||
new_params[k] += v / n
|
||||
return new_params
|
||||
|
||||
|
||||
def shoot_infs(inp_tensor):
|
||||
"""Replaces inf by maximum of tensor"""
|
||||
mask_inf = torch.isinf(inp_tensor)
|
||||
@@ -716,7 +768,7 @@ def shoot_infs(inp_tensor):
|
||||
return inp_tensor
|
||||
|
||||
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.01):
|
||||
def sinkhorn(Q, n_iters=3, epsilon=0.1):
|
||||
# epsilon should be adjusted according to logits value's scale
|
||||
with torch.no_grad():
|
||||
Q = torch.exp(Q / epsilon)
|
||||
@@ -734,7 +786,16 @@ def loss_fn(pred, label):
|
||||
return (pred[mask] - label[mask]).pow(2).mean(dim=0)
|
||||
|
||||
|
||||
def transport_sample(all_preds, label, choice, prob, count, transport_method, training=False):
|
||||
def minmax_norm(x):
|
||||
xmin = x.min(dim=-1, keepdim=True).values
|
||||
xmax = x.max(dim=-1, keepdim=True).values
|
||||
mask = (xmin == xmax).squeeze()
|
||||
x = (x - xmin) / (xmax - xmin + 1e-12)
|
||||
x[mask] = 1
|
||||
return x
|
||||
|
||||
|
||||
def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
sample-wise transport
|
||||
|
||||
@@ -743,39 +804,43 @@ def transport_sample(all_preds, label, choice, prob, count, transport_method, tr
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [sample x states]
|
||||
prob (torch.Tensor): router predicted probility, [sample x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [sample x states]
|
||||
count (list): sample counts for each day, empty list for sample-wise transport
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert all_preds.shape == choice.shape
|
||||
assert len(all_preds) == len(label)
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = (all_preds - label[:, None]).pow(2) # [sample x states]
|
||||
all_loss[torch.isnan(label)] = 0.0
|
||||
all_loss = torch.zeros_like(all_preds)
|
||||
mask = ~torch.isnan(label)
|
||||
all_loss[mask] = (all_preds[mask] - label[mask, None]).pow(2) # [sample x states]
|
||||
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
if transport_method == "router":
|
||||
if training: # router training
|
||||
if training:
|
||||
pred = (all_preds * choice).sum(dim=1) # gumbel softmax
|
||||
else: # router inference
|
||||
else:
|
||||
pred = all_preds[range(len(all_preds)), prob.argmax(dim=-1)] # argmax
|
||||
elif not training: # oracle inference: always choose the model with the smallest loss
|
||||
pred = all_preds[range(len(all_preds)), all_loss.argmin(dim=-1)]
|
||||
else: # oracle training: pred is not needed
|
||||
pred = None
|
||||
else:
|
||||
pred = (all_preds * P).sum(dim=1)
|
||||
|
||||
L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize
|
||||
P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training
|
||||
|
||||
if pred is not None: # router training/inference & oracle inference loss
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else: # oracle training loss
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def transport_daily(all_preds, label, choice, prob, count, transport_method, training=False):
|
||||
def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False):
|
||||
"""
|
||||
daily transport
|
||||
|
||||
@@ -784,8 +849,10 @@ def transport_daily(all_preds, label, choice, prob, count, transport_method, tra
|
||||
label (torch.Tensor): label, [sample]
|
||||
choice (torch.Tensor): gumbel softmax choice, [days x states]
|
||||
prob (torch.Tensor): router predicted probility, [days x states]
|
||||
hist_loss (torch.Tensor): history loss matrix, [days x states]
|
||||
count (list): sample counts for each day, [days]
|
||||
transport_method (str): transportation method
|
||||
alpha (float): fusion parameter for calculating transport loss matrix
|
||||
training (bool): indicate training or inference
|
||||
"""
|
||||
assert len(prob) == len(count)
|
||||
@@ -793,34 +860,85 @@ def transport_daily(all_preds, label, choice, prob, count, transport_method, tra
|
||||
assert transport_method in ["oracle", "router"]
|
||||
|
||||
all_loss = [] # loss of all predictions
|
||||
pred = [] # final predictions
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
tloss = loss_fn(all_preds[slc], label[slc]) # loss of the i-th day
|
||||
all_loss.append(tloss)
|
||||
if transport_method == "router":
|
||||
if training: # router training
|
||||
tpred = all_preds[slc] @ choice[i] # gumbel softmax
|
||||
else: # router inference
|
||||
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
|
||||
elif not training: # oracle inference: always choose the model with the smallest loss
|
||||
tpred = all_preds[slc][:, tloss.argmin(dim=-1)]
|
||||
else: # oracle training: pred is not needed
|
||||
tpred = None
|
||||
if tpred is not None:
|
||||
pred.append(tpred)
|
||||
all_loss = torch.stack(all_loss, dim=0) # [days x states]
|
||||
if pred:
|
||||
pred = torch.cat(pred, dim=0) # [samples]
|
||||
|
||||
L = (all_loss - all_loss.min(dim=1, keepdim=True).values).detach() # normalize
|
||||
P = sinkhorn(-L) if training else None # use sinkhorn to get sample assignment during training
|
||||
L = minmax_norm(all_loss.detach())
|
||||
Lh = L * alpha + minmax_norm(hist_loss) * (1 - alpha) # add hist loss for transport
|
||||
Lh = minmax_norm(Lh)
|
||||
P = sinkhorn(-Lh)
|
||||
del Lh
|
||||
|
||||
if len(pred): # router training/inference & oracle inference loss
|
||||
pred = []
|
||||
start = 0
|
||||
for i, cnt in enumerate(count):
|
||||
slc = slice(start, start + cnt) # samples from the i-th day
|
||||
start += cnt
|
||||
if transport_method == "router":
|
||||
if training:
|
||||
tpred = all_preds[slc] @ choice[i] # gumbel softmax
|
||||
else:
|
||||
tpred = all_preds[slc][:, prob[i].argmax(dim=-1)] # argmax
|
||||
else:
|
||||
tpred = all_preds[slc] @ P[i]
|
||||
pred.append(tpred)
|
||||
pred = torch.cat(pred, dim=0) # [samples]
|
||||
|
||||
if transport_method == "router":
|
||||
loss = loss_fn(pred, label)
|
||||
else: # oracle training loss
|
||||
else:
|
||||
loss = (all_loss * P).sum(dim=1).mean()
|
||||
|
||||
return loss, pred, L, P
|
||||
|
||||
|
||||
def load_state_dict_unsafe(model, state_dict):
|
||||
"""
|
||||
Load state dict to provided model while ignore exceptions.
|
||||
"""
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs}
|
||||
|
||||
|
||||
def plot(P):
|
||||
assert isinstance(P, pd.DataFrame)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||
P.plot.area(ax=axes[0], xlabel="")
|
||||
P.idxmax(axis=1).value_counts().sort_index().plot.bar(ax=axes[1], xlabel="")
|
||||
plt.tight_layout()
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
plt.savefig(buf, format="png")
|
||||
buf.seek(0)
|
||||
img = plt.imread(buf)
|
||||
plt.close()
|
||||
|
||||
return np.uint8(img * 255)
|
||||
|
||||
Reference in New Issue
Block a user