mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add the HIST and IGMTF model on Alpha360 (#1040)
* Commit the code of HIST and IGMTF on Alpha360 * add stock index * Update README.md * delete useless code * fix the bug of code format with black * fix pylint bugs * fix the bugs of pylint * fix pylint bugs * fix flake8
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
| Qlib notebook tutorial | 📖 [Released](https://github.com/microsoft/qlib/pull/1037) on Apr 7, 2022 |
|
||||
| Ibovespa index data | :rice: [Released](https://github.com/microsoft/qlib/pull/990) on Apr 6, 2022 |
|
||||
| Point-in-Time database | :hammer: [Released](https://github.com/microsoft/qlib/pull/343) on Mar 10, 2022 |
|
||||
@@ -339,6 +340,8 @@ Here is a list of models built on `Qlib`.
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
3
examples/benchmarks/HIST/README.md
Normal file
3
examples/benchmarks/HIST/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# HIST
|
||||
* Code: [https://github.com/Wentao-Xu/HIST](https://github.com/Wentao-Xu/HIST)
|
||||
* Paper: [HIST: A Graph-based Framework for Stock Trend Forecasting via Mining Concept-Oriented Shared InformationAdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/abs/2110.13716).
|
||||
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
BIN
examples/benchmarks/HIST/qlib_csi300_stock_index.npy
Normal file
Binary file not shown.
4
examples/benchmarks/HIST/requirements.txt
Normal file
4
examples/benchmarks/HIST/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
92
examples/benchmarks/HIST/workflow_config_hist_Alpha360.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HIST
|
||||
module_path: qlib.contrib.model.pytorch_hist
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
stock2concept: "benchmarks/HIST/qlib_csi300_stock2concept.npy"
|
||||
stock_index: "benchmarks/HIST/qlib_csi300_stock_index.npy"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/IGMTF/README.md
Normal file
4
examples/benchmarks/IGMTF/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# IGMTF
|
||||
* Code: [https://github.com/Wentao-Xu/IGMTF](https://github.com/Wentao-Xu/IGMTF)
|
||||
* Paper: [IGMTF: An Instance-wise Graph-based Framework for
|
||||
Multivariate Time Series Forecasting](https://arxiv.org/abs/2109.06489).
|
||||
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
4
examples/benchmarks/IGMTF/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,89 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: IGMTF
|
||||
module_path: qlib.contrib.model.pytorch_igmtf
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: ic
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
model_path: "benchmarks/LSTM/model_lstm_csi300.pkl"
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -65,6 +65,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
501
qlib/contrib/model/pytorch_hist.py
Normal file
501
qlib/contrib/model/pytorch_hist.py
Normal file
@@ -0,0 +1,501 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import urllib.request
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class HIST(Model):
|
||||
"""HIST Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
stock2concept=None,
|
||||
stock_index=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HIST")
|
||||
self.logger.info("HIST pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.stock2concept = stock2concept
|
||||
self.stock_index = stock_index
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"HIST parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nstock2concept : {}"
|
||||
"\nstock_index : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
stock2concept,
|
||||
stock_index,
|
||||
GPU,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.HIST_model = HISTModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.HIST_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.HIST_model)))
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.HIST_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.HIST_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.HIST_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train, stock_index):
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.train()
|
||||
|
||||
# organize the train data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.HIST_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, stock_index):
|
||||
|
||||
# prepare training data
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
stock_index = stock_index.values
|
||||
stock_index[np.isnan(stock_index)] = 733
|
||||
self.HIST_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(feature, concept_matrix)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
if not os.path.exists(self.stock2concept):
|
||||
url = "http://fintech.msra.cn/stock_data/downloads/qlib_csi300_stock2concept.npy"
|
||||
urllib.request.urlretrieve(url, self.stock2concept)
|
||||
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_train["stock_index"] = 733
|
||||
df_train["stock_index"] = df_train.index.get_level_values("instrument").map(stock_index)
|
||||
df_valid["stock_index"] = 733
|
||||
df_valid["stock_index"] = df_valid.index.get_level_values("instrument").map(stock_index)
|
||||
|
||||
x_train, y_train, stock_index_train = df_train["feature"], df_train["label"], df_train["stock_index"]
|
||||
x_valid, y_valid, stock_index_valid = df_valid["feature"], df_valid["label"], df_valid["stock_index"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.HIST_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HIST_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train, stock_index_train)
|
||||
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, stock_index_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, stock_index_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.HIST_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.HIST_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
stock2concept_matrix = np.load(self.stock2concept)
|
||||
stock_index = np.load(self.stock_index, allow_pickle=True).item()
|
||||
df_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
df_test["stock_index"] = 733
|
||||
df_test["stock_index"] = df_test.index.get_level_values("instrument").map(stock_index)
|
||||
stock_index_test = df_test["stock_index"].values
|
||||
stock_index_test[np.isnan(stock_index_test)] = 733
|
||||
stock_index_test = stock_index_test.astype("int")
|
||||
df_test = df_test.drop(["stock_index"], axis=1)
|
||||
index = df_test.index
|
||||
|
||||
self.HIST_model.eval()
|
||||
x_values = df_test.values
|
||||
preds = []
|
||||
|
||||
# organize the data into daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(df_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.HIST_model(x_batch, concept_matrix).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class HISTModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.fc_es = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es.weight)
|
||||
self.fc_is = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is.weight)
|
||||
|
||||
self.fc_es_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_middle.weight)
|
||||
self.fc_is_middle = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_middle.weight)
|
||||
|
||||
self.fc_es_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_fore.weight)
|
||||
self.fc_is_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_fore.weight)
|
||||
self.fc_indi_fore = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi_fore.weight)
|
||||
|
||||
self.fc_es_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_es_back.weight)
|
||||
self.fc_is_back = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_is_back.weight)
|
||||
self.fc_indi = nn.Linear(hidden_size, hidden_size)
|
||||
torch.nn.init.xavier_uniform_(self.fc_indi.weight)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax_s2t = torch.nn.Softmax(dim=0)
|
||||
self.softmax_t2s = torch.nn.Softmax(dim=1)
|
||||
|
||||
self.fc_out_es = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_is = nn.Linear(hidden_size, 1)
|
||||
self.fc_out_indi = nn.Linear(hidden_size, 1)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def forward(self, x, concept_matrix):
|
||||
device = torch.device(torch.get_device(x))
|
||||
|
||||
x_hidden = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x_hidden = x_hidden.permute(0, 2, 1) # [N, T, F]
|
||||
x_hidden, _ = self.rnn(x_hidden)
|
||||
x_hidden = x_hidden[:, -1, :]
|
||||
|
||||
# Predefined Concept Module
|
||||
|
||||
stock_to_concept = concept_matrix
|
||||
|
||||
stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1)
|
||||
stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix)
|
||||
|
||||
stock_to_concept_sum = stock_to_concept_sum + (
|
||||
torch.ones(stock_to_concept.shape[0], stock_to_concept.shape[1]).to(device)
|
||||
)
|
||||
stock_to_concept = stock_to_concept / stock_to_concept_sum
|
||||
hidden = torch.t(stock_to_concept).mm(x_hidden)
|
||||
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
concept_to_stock = self.cal_cos_similarity(x_hidden, hidden)
|
||||
concept_to_stock = self.softmax_t2s(concept_to_stock)
|
||||
|
||||
e_shared_info = concept_to_stock.mm(hidden)
|
||||
e_shared_info = self.fc_es(e_shared_info)
|
||||
|
||||
e_shared_back = self.fc_es_back(e_shared_info)
|
||||
output_es = self.fc_es_fore(e_shared_info)
|
||||
output_es = self.leaky_relu(output_es)
|
||||
|
||||
# Hidden Concept Module
|
||||
i_shared_info = x_hidden - e_shared_back
|
||||
hidden = i_shared_info
|
||||
i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
dim = i_stock_to_concept.shape[0]
|
||||
diag = i_stock_to_concept.diagonal(0)
|
||||
i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device)
|
||||
row = torch.linspace(0, dim - 1, dim).to(device).long()
|
||||
column = i_stock_to_concept.max(1)[1].long()
|
||||
value = i_stock_to_concept.max(1)[0]
|
||||
i_stock_to_concept[row, column] = 10
|
||||
i_stock_to_concept[i_stock_to_concept != 10] = 0
|
||||
i_stock_to_concept[row, column] = value
|
||||
i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0) != 0).float() * diag)
|
||||
hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t()
|
||||
hidden = hidden[hidden.sum(1) != 0]
|
||||
|
||||
i_concept_to_stock = self.cal_cos_similarity(i_shared_info, hidden)
|
||||
i_concept_to_stock = self.softmax_t2s(i_concept_to_stock)
|
||||
i_shared_info = i_concept_to_stock.mm(hidden)
|
||||
i_shared_info = self.fc_is(i_shared_info)
|
||||
|
||||
i_shared_back = self.fc_is_back(i_shared_info)
|
||||
output_is = self.fc_is_fore(i_shared_info)
|
||||
output_is = self.leaky_relu(output_is)
|
||||
|
||||
# Individual Information Module
|
||||
individual_info = x_hidden - e_shared_back - i_shared_back
|
||||
output_indi = individual_info
|
||||
output_indi = self.fc_indi(output_indi)
|
||||
output_indi = self.leaky_relu(output_indi)
|
||||
|
||||
# Stock Trend Prediction
|
||||
all_info = output_es + output_is + output_indi
|
||||
pred_all = self.fc_out(all_info).squeeze()
|
||||
|
||||
return pred_all
|
||||
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
446
qlib/contrib/model/pytorch_igmtf.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class IGMTF(Model):
|
||||
"""IGMTF Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("IGMTF")
|
||||
self.logger.info("IMGTF pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"IGMTF parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
model_path,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.igmtf_model = IGMTFModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.igmtf_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.igmtf_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.igmtf_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.igmtf_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.igmtf_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "ic":
|
||||
x = pred[mask]
|
||||
y = label[mask]
|
||||
|
||||
vx = x - torch.mean(x)
|
||||
vy = y - torch.mean(y)
|
||||
return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)))
|
||||
|
||||
if self.metric == ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def get_train_hidden(self, x_train):
|
||||
x_train_values = x_train.values
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
self.igmtf_model.eval()
|
||||
train_hidden = []
|
||||
train_hidden_day = []
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
out = self.igmtf_model(feature, get_hidden=True)
|
||||
train_hidden.append(out.detach().cpu())
|
||||
train_hidden_day.append(out.detach().cpu().mean(dim=0).unsqueeze(dim=0))
|
||||
|
||||
train_hidden = np.asarray(train_hidden, dtype=object)
|
||||
train_hidden_day = torch.cat(train_hidden_day)
|
||||
|
||||
return train_hidden, train_hidden_day
|
||||
|
||||
def train_epoch(self, x_train, y_train, train_hidden, train_hidden_day):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.igmtf_model.train()
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.igmtf_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.igmtf_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
|
||||
pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.igmtf_model.state_dict()
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.igmtf_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
self.train_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train, train_hidden, train_hidden_day)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid, train_hidden, train_hidden_day)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.igmtf_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.igmtf_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_train = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L)
|
||||
train_hidden, train_hidden_day = self.get_train_hidden(x_train)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.igmtf_model.eval()
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = (
|
||||
self.igmtf_model(x_batch, train_hidden=train_hidden, train_hidden_day=train_hidden_day)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class IGMTFModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
self.lins = nn.Sequential()
|
||||
for i in range(2):
|
||||
self.lins.add_module("linear" + str(i), nn.Linear(hidden_size, hidden_size))
|
||||
self.lins.add_module("leakyrelu" + str(i), nn.LeakyReLU())
|
||||
self.fc_output = nn.Linear(hidden_size * 2, hidden_size * 2)
|
||||
self.project1 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.project2 = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.fc_out_pred = nn.Linear(hidden_size * 2, 1)
|
||||
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.d_feat = d_feat
|
||||
|
||||
def cal_cos_similarity(self, x, y): # the 2nd dimension of x and y are the same
|
||||
xy = x.mm(torch.t(y))
|
||||
x_norm = torch.sqrt(torch.sum(x * x, dim=1)).reshape(-1, 1)
|
||||
y_norm = torch.sqrt(torch.sum(y * y, dim=1)).reshape(-1, 1)
|
||||
cos_similarity = xy / (x_norm.mm(torch.t(y_norm)) + 1e-6)
|
||||
return cos_similarity
|
||||
|
||||
def sparse_dense_mul(self, s, d):
|
||||
i = s._indices()
|
||||
v = s._values()
|
||||
dv = d[i[0, :], i[1, :]] # get values from relevant entries of dense matrix
|
||||
return torch.sparse.FloatTensor(i, v * dv, s.size())
|
||||
|
||||
def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, k_day=10, n_neighbor=10):
|
||||
# x: [N, F*T]
|
||||
device = x.device
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
out = out[:, -1, :]
|
||||
out = self.lins(out)
|
||||
mini_batch_out = out
|
||||
if get_hidden is True:
|
||||
return mini_batch_out
|
||||
|
||||
mini_batch_out_day = torch.mean(mini_batch_out, dim=0).unsqueeze(0)
|
||||
day_similarity = self.cal_cos_similarity(mini_batch_out_day, train_hidden_day.to(device))
|
||||
day_index = torch.topk(day_similarity, k_day, dim=1)[1]
|
||||
sample_train_hidden = train_hidden[day_index.long().cpu()].squeeze()
|
||||
sample_train_hidden = torch.cat(list(sample_train_hidden)).to(device)
|
||||
sample_train_hidden = self.lins(sample_train_hidden)
|
||||
cos_similarity = self.cal_cos_similarity(self.project1(mini_batch_out), self.project2(sample_train_hidden))
|
||||
|
||||
row = (
|
||||
torch.linspace(0, x.shape[0] - 1, x.shape[0])
|
||||
.reshape([-1, 1])
|
||||
.repeat(1, n_neighbor)
|
||||
.reshape(1, -1)
|
||||
.to(device)
|
||||
)
|
||||
column = torch.topk(cos_similarity, n_neighbor, dim=1)[1].reshape(1, -1)
|
||||
mask = torch.sparse_coo_tensor(
|
||||
torch.cat([row, column]),
|
||||
torch.ones([row.shape[1]]).to(device) / n_neighbor,
|
||||
(x.shape[0], sample_train_hidden.shape[0]),
|
||||
)
|
||||
cos_similarity = self.sparse_dense_mul(mask, cos_similarity)
|
||||
|
||||
agg_out = torch.sparse.mm(cos_similarity, self.project2(sample_train_hidden))
|
||||
# out = self.fc_out(out).squeeze()
|
||||
out = self.fc_out_pred(torch.cat([mini_batch_out, agg_out], axis=1)).squeeze()
|
||||
return out
|
||||
Reference in New Issue
Block a user