1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add A New Baseline: TCN (#668)

This commit is contained in:
fengcunguang
2021-11-04 20:30:52 +08:00
committed by GitHub
parent 5ee2d9496b
commit f0b9a807ea
9 changed files with 893 additions and 1 deletions

View File

@@ -294,6 +294,7 @@ Here is a list of models built on `Qlib`.
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](qlib/contrib/model/pytorch_tcn.py)
Your PR of new Quant models is highly welcomed.

View File

@@ -34,6 +34,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
| TCN | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
@@ -55,6 +56,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,100 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: FilterCol
kwargs:
fields_group: feature
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
]
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCN
module_path: qlib.contrib.model.pytorch_tcn_ts
kwargs:
d_feat: 20
num_layers: 5
n_chans: 32
kernel_size: 7
dropout: 0.5
n_epochs: 200
lr: 1e-4
early_stop: 20
batch_size: 2000
metric: loss
loss: mse
optimizer: adam
n_jobs: 20
GPU: 0
dataset:
class: TSDatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
step_len: 20
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,90 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TCN
module_path: qlib.contrib.model.pytorch_tcn
kwargs:
d_feat: 6
num_layers: 5
n_chans: 128
kernel_size: 3
dropout: 0.5
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 2000
metric: loss
loss: mse
optimizer: adam
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -30,8 +30,9 @@ try:
from .pytorch_nn import DNNModelPytorch
from .pytorch_tabnet import TabnetModel
from .pytorch_sfm import SFM_Model
from .pytorch_tcn import TCN
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN)
except ModuleNotFoundError:
pytorch_classes = ()
print("Please install necessary libs for PyTorch models.")

317
qlib/contrib/model/pytorch_tcn.py Executable file
View File

@@ -0,0 +1,317 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import weight_norm
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from .tcn import TemporalConvNet
class TCN(Model):
"""TCN Model
Parameters
----------
d_feat : int
input dimension for each time step
n_chans: int
number of channels
metric: str
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
d_feat=6,
n_chans=128,
kernel_size=5,
num_layers=5,
dropout=0.5,
n_epochs=200,
lr=0.0001,
metric="",
batch_size=2000,
early_stop=20,
loss="mse",
optimizer="adam",
GPU=0,
seed=None,
**kwargs
):
# Set logger.
self.logger = get_module_logger("TCN")
self.logger.info("TCN pytorch version...")
# set hyper-parameters.
self.d_feat = d_feat
self.n_chans = n_chans
self.kernel_size = kernel_size
self.num_layers = num_layers
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed
self.logger.info(
"TCN parameters setting:"
"\nd_feat : {}"
"\nn_chans : {}"
"\nkernel_size : {}"
"\nnum_layers : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
n_chans,
kernel_size,
num_layers,
dropout,
n_epochs,
lr,
metric,
batch_size,
early_stop,
optimizer.lower(),
loss,
GPU,
self.use_gpu,
seed,
)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.tcn_model = TCNModel(
num_input=self.d_feat,
output_size=1,
num_channels=[self.n_chans] * self.num_layers,
kernel_size=self.kernel_size,
dropout=self.dropout,
)
self.logger.info("model:\n{:}".format(self.tcn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.tcn_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.tcn_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.tcn_model.parameters(), lr=self.lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.tcn_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
y_train_values = np.squeeze(y_train.values)
self.tcn_model.train()
indices = np.arange(len(x_train_values))
np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
pred = self.tcn_model(feature)
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.tcn_model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_x, data_y):
x_values = data_x.values
y_values = np.squeeze(data_y.values)
self.tcn_model.eval()
scores = []
losses = []
indices = np.arange(len(x_values))
for i in range(len(indices))[:: self.batch_size]:
if len(indices) - i < self.batch_size:
break
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
with torch.no_grad():
pred = self.tcn_model(feature)
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset: DatasetH,
evals_result=dict(),
save_path=None,
):
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(x_train, y_train)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(x_train, y_train)
val_loss, val_score = self.test_epoch(x_valid, y_valid)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.tcn_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.tcn_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.tcn_model.eval()
x_values = x_test.values
sample_num = x_values.shape[0]
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.tcn_model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)
class TCNModel(nn.Module):
def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
super().__init__()
self.num_input = num_input
self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
self.linear = nn.Linear(num_channels[-1], output_size)
def forward(self, x):
x = x.reshape(x.shape[0], self.num_input, -1)
output = self.tcn(x)
output = self.linear(output[:, :, -1])
return output.squeeze()

View File

@@ -0,0 +1,300 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset.handler import DataHandlerLP
from .tcn import TemporalConvNet
class TCN(Model):
"""TCN Model
Parameters
----------
d_feat : int
input dimension for each time step
metric: str
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
the GPU ID(s) used for training
"""
def __init__(
self,
d_feat=6,
n_chans=128,
kernel_size=5,
num_layers=2,
dropout=0.0,
n_epochs=200,
lr=0.001,
metric="",
batch_size=2000,
early_stop=20,
loss="mse",
optimizer="adam",
n_jobs=10,
GPU=0,
seed=None,
**kwargs
):
# Set logger.
self.logger = get_module_logger("TCN")
self.logger.info("TCN pytorch version...")
# set hyper-parameters.
self.d_feat = d_feat
self.n_chans = n_chans
self.kernel_size = kernel_size
self.num_layers = num_layers
self.dropout = dropout
self.n_epochs = n_epochs
self.lr = lr
self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.n_jobs = n_jobs
self.seed = seed
self.logger.info(
"TCN parameters setting:"
"\nd_feat : {}"
"\nn_chans : {}"
"\nkernel_size : {}"
"\nnum_layers : {}"
"\ndropout : {}"
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\ndevice : {}"
"\nn_jobs : {}"
"\nuse_GPU : {}"
"\nseed : {}".format(
d_feat,
n_chans,
kernel_size,
num_layers,
dropout,
n_epochs,
lr,
metric,
batch_size,
early_stop,
optimizer.lower(),
loss,
self.device,
n_jobs,
self.use_gpu,
seed,
)
)
if self.seed is not None:
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.TCN_model = TCNModel(
num_input=self.d_feat,
output_size=1,
num_channels=[self.n_chans] * self.num_layers,
kernel_size=self.kernel_size,
dropout=self.dropout,
)
self.logger.info("model:\n{:}".format(self.TCN_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.TCN_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.TCN_model.parameters(), lr=self.lr)
elif optimizer.lower() == "gd":
self.train_optimizer = optim.SGD(self.TCN_model.parameters(), lr=self.lr)
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
self.fitted = False
self.TCN_model.to(self.device)
@property
def use_gpu(self):
return self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
return torch.mean(loss)
def loss_fn(self, pred, label):
mask = ~torch.isnan(label)
if self.loss == "mse":
return self.mse(pred[mask], label[mask])
raise ValueError("unknown loss `%s`" % self.loss)
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
def train_epoch(self, data_loader):
self.TCN_model.train()
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
pred = self.TCN_model(feature.float())
loss = self.loss_fn(pred, label)
self.train_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.TCN_model.parameters(), 3.0)
self.train_optimizer.step()
def test_epoch(self, data_loader):
self.TCN_model.eval()
scores = []
losses = []
for data in data_loader:
feature = data[:, :, 0:-1].to(self.device)
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)
with torch.no_grad():
pred = self.TCN_model(feature.float())
loss = self.loss_fn(pred, label)
losses.append(loss.item())
score = self.metric_fn(pred, label)
scores.append(score.item())
return np.mean(losses), np.mean(scores)
def fit(
self,
dataset,
evals_result=dict(),
save_path=None,
):
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
# process nan brought by dataloader
dl_train.config(fillna_type="ffill+bfill")
# process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill")
train_loader = DataLoader(
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
)
valid_loader = DataLoader(
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
)
save_path = get_or_create_path(save_path)
stop_steps = 0
train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
# train
self.logger.info("training...")
self.fitted = True
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
self.logger.info("training...")
self.train_epoch(train_loader)
self.logger.info("evaluating...")
train_loss, train_score = self.test_epoch(train_loader)
val_loss, val_score = self.test_epoch(valid_loader)
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
evals_result["train"].append(train_score)
evals_result["valid"].append(val_score)
if val_score > best_score:
best_score = val_score
stop_steps = 0
best_epoch = step
best_param = copy.deepcopy(self.TCN_model.state_dict())
else:
stop_steps += 1
if stop_steps >= self.early_stop:
self.logger.info("early stop")
break
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
self.TCN_model.load_state_dict(best_param)
torch.save(best_param, save_path)
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.TCN_model.eval()
preds = []
for data in test_loader:
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():
pred = self.TCN_model(feature.float()).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
class TCNModel(nn.Module):
def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
super().__init__()
self.num_input = num_input
self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
self.linear = nn.Linear(num_channels[-1], output_size)
def forward(self, x):
output = self.tcn(x)
output = self.linear(output[:, :, -1])
return output.squeeze()

77
qlib/contrib/model/tcn.py Normal file
View File

@@ -0,0 +1,77 @@
# MIT License
# Copyright (c) 2018 CMU Locus Lab
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, : -self.chomp_size].contiguous()
class TemporalBlock(nn.Module):
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
super(TemporalBlock, self).__init__()
self.conv1 = weight_norm(
nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
)
self.chomp1 = Chomp1d(padding)
self.relu1 = nn.ReLU()
self.dropout1 = nn.Dropout(dropout)
self.conv2 = weight_norm(
nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
)
self.chomp2 = Chomp1d(padding)
self.relu2 = nn.ReLU()
self.dropout2 = nn.Dropout(dropout)
self.net = nn.Sequential(
self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2
)
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
self.relu = nn.ReLU()
self.init_weights()
def init_weights(self):
self.conv1.weight.data.normal_(0, 0.01)
self.conv2.weight.data.normal_(0, 0.01)
if self.downsample is not None:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
out = self.net(x)
res = x if self.downsample is None else self.downsample(x)
return self.relu(out + res)
class TemporalConvNet(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i - 1]
out_channels = num_channels[i]
layers += [
TemporalBlock(
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=dilation_size,
padding=(kernel_size - 1) * dilation_size,
dropout=dropout,
)
]
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)