1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/examples/benchmarks/TFT/libs/tft_model.py
you-n-g be4646b4b7 Adjust rolling api (#1594)
* Intermediate version

* Fix yaml template & Successfully run rolling

* Be compatible with benchmark

* Get same results with previous linear model

* Black formatting

* Update black

* Update the placeholder mechanism

* Update CI

* Update CI

* Upgrade Black

* Fix CI and simplify code

* Fix CI

* Move the data processing caching mechanism into utils.

* Adjusting DDG-DA

* Organize import
2023-07-14 12:16:12 +08:00

1278 lines
47 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Temporal Fusion Transformer Model.
Contains the full TFT architecture and associated components. Defines functions
for training, evaluation and prediction using simple Pandas Dataframe inputs.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
import json
import os
import shutil
import data_formatters.base
import libs.utils as utils
import numpy as np
import pandas as pd
import tensorflow as tf
# Layer definitions.
# Short module-level aliases for the Keras backend functions and layer classes
# used throughout this file.
concat = tf.keras.backend.concatenate
stack = tf.keras.backend.stack
K = tf.keras.backend
Add = tf.keras.layers.Add
LayerNorm = tf.keras.layers.LayerNormalization
Dense = tf.keras.layers.Dense
Multiply = tf.keras.layers.Multiply
Dropout = tf.keras.layers.Dropout
Activation = tf.keras.layers.Activation
Lambda = tf.keras.layers.Lambda
# Default input types.
# Alias for the InputTypes definition exposed by the data formatter base module.
InputTypes = data_formatters.base.InputTypes
# Layer utility functions.
def linear_layer(size, activation=None, use_time_distributed=False, use_bias=True):
    """Builds a Keras dense layer, optionally wrapped to run per time step.

    Args:
        size: Output size
        activation: Activation function to apply if required
        use_time_distributed: Whether to apply layer across time
        use_bias: Whether bias should be included in layer

    Returns:
        The constructed (not yet applied) Keras layer.
    """
    dense = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)
    if not use_time_distributed:
        return dense
    return tf.keras.layers.TimeDistributed(dense)
def apply_mlp(
    inputs, hidden_size, output_size, output_activation=None, hidden_activation="tanh", use_time_distributed=False
):
    """Runs a two-layer feed-forward network over an input tensor.

    Args:
        inputs: MLP inputs
        hidden_size: Hidden state size
        output_size: Output size of MLP
        output_activation: Activation function to apply on output
        hidden_activation: Activation function of the hidden layer
        use_time_distributed: Whether to apply across time

    Returns:
        Tensor for MLP outputs.
    """

    def _maybe_distribute(layer):
        # Wrap the layer so it is applied per time step when requested.
        if use_time_distributed:
            return tf.keras.layers.TimeDistributed(layer)
        return layer

    hidden_layer = _maybe_distribute(tf.keras.layers.Dense(hidden_size, activation=hidden_activation))
    hidden = hidden_layer(inputs)
    output_layer = _maybe_distribute(tf.keras.layers.Dense(output_size, activation=output_activation))
    return output_layer(hidden)
def apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None):
    """Applies a Gated Linear Unit (GLU) to an input.

    The input (after optional dropout) is fed to two dense layers of equal
    width: a feature transform and a sigmoid gate. Their element-wise product
    is the GLU output.

    Args:
        x: Input to gating layer
        hidden_layer_size: Dimension of GLU
        dropout_rate: Dropout rate to apply if any
        use_time_distributed: Whether to apply across time
        activation: Activation function to apply to the linear feature
            transform if necessary

    Returns:
        Tuple of tensors for: (GLU output, gate)
    """

    def _maybe_distribute(layer):
        # Wrap the layer so it is applied per time step when requested.
        if use_time_distributed:
            return tf.keras.layers.TimeDistributed(layer)
        return layer

    if dropout_rate is not None:
        x = tf.keras.layers.Dropout(dropout_rate)(x)

    feature_layer = _maybe_distribute(tf.keras.layers.Dense(hidden_layer_size, activation=activation))
    features = feature_layer(x)
    gate_layer = _maybe_distribute(tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid"))
    gate = gate_layer(x)

    glu_output = tf.keras.layers.Multiply()([features, gate])
    return glu_output, gate
def add_and_norm(x_list):
    """Sums the given tensors (skip connection) and layer-normalises the result.

    Args:
        x_list: List of inputs to sum for skip connection

    Returns:
        Tensor output from layer.
    """
    summed = Add()(x_list)
    return LayerNorm()(summed)
def gated_residual_network(
    x,
    hidden_layer_size,
    output_size=None,
    dropout_rate=None,
    use_time_distributed=True,
    additional_context=None,
    return_gate=False,
):
    """Applies the gated residual network (GRN) as defined in paper.

    Args:
        x: Network inputs
        hidden_layer_size: Internal state size
        output_size: Size of output layer; when None the hidden size is used and
            the input itself serves as the skip connection
        dropout_rate: Dropout rate if dropout is applied
        use_time_distributed: Whether to apply network across time dimension
        additional_context: Additional context vector to use if relevant
        return_gate: Whether to return GLU gate for diagnostic purposes

    Returns:
        The GRN output tensor, or a tuple of (GRN output, GLU gate) when
        `return_gate` is True.
    """
    # Setup skip connection: identity when sizes match, linear projection otherwise.
    if output_size is None:
        output_size = hidden_layer_size
        skip = x
    else:
        skip = linear_layer(output_size, use_time_distributed=use_time_distributed)(x)

    # Feed-forward path: linear -> (+ context) -> ELU -> linear.
    hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(x)
    if additional_context is not None:
        context_projection = linear_layer(
            hidden_layer_size, activation=None, use_time_distributed=use_time_distributed, use_bias=False
        )
        hidden = hidden + context_projection(additional_context)
    hidden = tf.keras.layers.Activation("elu")(hidden)
    hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(hidden)

    # Gate the transformed features, then add the skip connection and normalise.
    gated, gate = apply_gating_layer(
        hidden, output_size, dropout_rate=dropout_rate, use_time_distributed=use_time_distributed, activation=None
    )
    grn_output = add_and_norm([skip, gated])

    if return_gate:
        return grn_output, gate
    return grn_output
# Attention Components.
def get_decoder_mask(self_attn_inputs):
    """Returns causal mask to apply for self-attention layer.

    The mask is the cumulative sum, along axis 1, of a batch of identity
    matrices whose side is the second dimension of the input.

    Args:
        self_attn_inputs: Inputs to self attention layer to determine mask shape
    """
    input_shape = tf.shape(self_attn_inputs)
    seq_len = input_shape[1]
    batch_shape = input_shape[:1]
    identity = tf.eye(seq_len, batch_shape=batch_shape)
    return K.cumsum(identity, 1)
class ScaledDotProductAttention:
    """Defines scaled dot product attention layer.

    Attributes:
        dropout: Dropout layer applied to the attention weights
        activation: Normalisation function for scaled dot product attention
            (softmax)
    """

    def __init__(self, attn_dropout=0.0):
        self.dropout = Dropout(attn_dropout)
        self.activation = Activation("softmax")

    def __call__(self, q, k, v, mask):
        """Applies scaled dot product attention.

        Args:
            q: Queries
            k: Keys
            v: Values
            mask: Masking if required -- positions where the mask is 0 receive
                a -1e9 additive penalty before the softmax

        Returns:
            Tuple of (layer outputs, attention weights)
        """
        # Scale by sqrt of the key dimensionality.
        key_dim = tf.cast(tf.shape(k)[-1], dtype="float32")
        temper = tf.sqrt(key_dim)

        scores = Lambda(lambda qk: K.batch_dot(qk[0], qk[1], axes=[2, 2]) / temper)([q, k])  # shape=(batch, q, k)

        if mask is not None:
            # Large negative offset where mask == 0, so softmax gives ~0 weight.
            penalty = Lambda(lambda m: (-1e9) * (1.0 - K.cast(m, "float32")))(mask)
            scores = Add()([scores, penalty])

        weights = self.activation(scores)
        weights = self.dropout(weights)
        output = Lambda(lambda wv: K.batch_dot(wv[0], wv[1]))([weights, v])
        return output, weights
class InterpretableMultiHeadAttention:
    """Defines interpretable multi-head attention layer.

    Attributes:
        n_head: Number of heads
        d_k: Key/query dimensionality per head
        d_v: Value dimensionality
        dropout: Dropout rate to apply
        qs_layers: List of query projections, one per head
        ks_layers: List of key projections, one per head
        vs_layers: List of value projections (the same layer repeated per head)
        attention: Scaled dot product attention layer
        w_o: Output weight matrix to project internal state to the original
            TFT state size
    """

    def __init__(self, n_head, d_model, dropout):
        """Initialises layer.

        Args:
            n_head: Number of heads
            d_model: TFT state dimensionality
            dropout: Dropout discard rate
        """
        per_head_dim = d_model // n_head
        self.n_head = n_head
        self.d_k = per_head_dim
        self.d_v = per_head_dim
        self.dropout = dropout

        self.qs_layers = []
        self.ks_layers = []
        self.vs_layers = []

        # A single value projection is shared by all heads to facilitate interp.
        shared_value_layer = Dense(per_head_dim, use_bias=False)
        for _ in range(n_head):
            self.qs_layers.append(Dense(per_head_dim, use_bias=False))
            self.ks_layers.append(Dense(per_head_dim, use_bias=False))
            self.vs_layers.append(shared_value_layer)

        self.attention = ScaledDotProductAttention()
        self.w_o = Dense(d_model, use_bias=False)

    def __call__(self, q, k, v, mask=None):
        """Applies interpretable multihead attention.

        Using T to denote the number of time steps fed into the transformer.

        Args:
            q: Query tensor of shape=(?, T, d_model)
            k: Key of shape=(?, T, d_model)
            v: Values of shape=(?, T, d_model)
            mask: Masking if required with shape=(?, T, T)

        Returns:
            Tuple of (layer outputs, attention weights)
        """
        multi_head = self.n_head > 1
        head_outputs = []
        head_weights = []
        for idx in range(self.n_head):
            q_proj = self.qs_layers[idx](q)
            k_proj = self.ks_layers[idx](k)
            v_proj = self.vs_layers[idx](v)
            attended, weights = self.attention(q_proj, k_proj, v_proj, mask)
            head_outputs.append(Dropout(self.dropout)(attended))
            head_weights.append(weights)

        if multi_head:
            combined = K.stack(head_outputs)
        else:
            combined = head_outputs[0]
        attn = K.stack(head_weights)

        # Heads are averaged rather than concatenated.
        if multi_head:
            outputs = K.mean(combined, axis=0)
        else:
            outputs = combined
        outputs = self.w_o(outputs)
        outputs = Dropout(self.dropout)(outputs)  # output dropout
        return outputs, attn
class TFTDataCache:
    """Class-level key/value store for batched TFT data."""

    # Shared by every user of the class; maps cache key -> cached data.
    _data_cache = {}

    @classmethod
    def update(cls, data, key):
        """Stores `data` under `key`, replacing any previous entry.

        Args:
            data: Source to update
            key: Key to dictionary location
        """
        cls._data_cache.update({key: data})

    @classmethod
    def get(cls, key):
        """Returns a copy (via `.copy()`) of the data stored at `key`."""
        cached = cls._data_cache[key]
        return cached.copy()

    @classmethod
    def contains(cls, key):
        """Returns boolean indicating whether key is present in cache."""
        store = cls._data_cache
        return key in store
# TFT model definitions.
class TemporalFusionTransformer:
"""Defines Temporal Fusion Transformer.
Attributes:
name: Name of model
time_steps: Total number of input time steps per forecast date (i.e. Width
of Temporal fusion decoder N)
input_size: Total number of inputs
output_size: Total number of outputs
category_counts: Number of categories per categorical variable
n_multiprocessing_workers: Number of workers to use for parallel
computations
column_definition: List of tuples of (string, DataType, InputType) that
define each column
quantiles: Quantiles to forecast for TFT
use_cudnn: Whether to use Keras CuDNNLSTM or standard LSTM layers
hidden_layer_size: Internal state size of TFT
dropout_rate: Dropout discard rate
max_gradient_norm: Maximum norm for gradient clipping
learning_rate: Initial learning rate of ADAM optimizer
minibatch_size: Size of minibatches for training
num_epochs: Maximum number of epochs for training
early_stopping_patience: Maximum number of iterations of non-improvement
before early stopping kicks in
num_encoder_steps: Size of LSTM encoder -- i.e. number of past time steps
before forecast date to use
num_stacks: Number of self-attention layers to apply (default is 1 for basic
TFT)
num_heads: Number of heads for interpretable multi-head attention
model: Keras model for TFT
"""
def __init__(self, raw_params, use_cudnn=False):
    """Builds TFT from parameters.

    Reads every hyperparameter from `raw_params`, resets the temporary
    checkpoint folder, prints the parameters and builds the Keras model.

    Args:
        raw_params: Parameters to define TFT
        use_cudnn: Whether to use CUDNN GPU optimised LSTM
    """
    self.name = type(self).__name__
    params = dict(raw_params)  # copy locally

    def _int(key):
        return int(params[key])

    def _float(key):
        return float(params[key])

    def _json(key):
        # Values are stringified first, then parsed as JSON.
        return json.loads(str(params[key]))

    # Data parameters
    self.time_steps = _int("total_time_steps")
    self.input_size = _int("input_size")
    self.output_size = _int("output_size")
    self.category_counts = _json("category_counts")
    self.n_multiprocessing_workers = _int("multiprocessing_workers")

    # Relevant indices for TFT
    self._input_obs_loc = _json("input_obs_loc")
    self._static_input_loc = _json("static_input_loc")
    self._known_regular_input_idx = _json("known_regular_inputs")
    self._known_categorical_input_idx = _json("known_categorical_inputs")
    self.column_definition = params["column_definition"]

    # Network params
    self.quantiles = [0.1, 0.5, 0.9]
    self.use_cudnn = use_cudnn  # Whether to use GPU optimised LSTM
    self.hidden_layer_size = _int("hidden_layer_size")
    self.dropout_rate = _float("dropout_rate")
    self.max_gradient_norm = _float("max_gradient_norm")
    self.learning_rate = _float("learning_rate")
    self.minibatch_size = _int("minibatch_size")
    self.num_epochs = _int("num_epochs")
    self.early_stopping_patience = _int("early_stopping_patience")
    self.num_encoder_steps = _int("num_encoder_steps")
    self.num_stacks = _int("stack_size")
    self.num_heads = _int("num_heads")

    # Serialisation options
    self._temp_folder = os.path.join(params["model_folder"], "tmp")
    self.reset_temp_folder()

    # Extra components to store Tensorflow nodes for attention computations
    self._input_placeholder = None
    self._attention_components = None
    self._prediction_parts = None

    print("*** {} params ***".format(self.name))
    for key, value in params.items():
        print("# {} = {}".format(key, value))

    # Build model
    self.model = self.build_model()
def get_tft_embeddings(self, all_inputs):
    """Transforms raw inputs to embeddings.

    Applies linear transformation onto continuous variables and uses embeddings
    for categorical variables. The last axis of `all_inputs` is treated as
    regular (continuous) columns first, followed by the categorical columns.

    Args:
        all_inputs: Inputs to transform

    Returns:
        Tuple of (unknown_inputs, known_combined_layer, obs_inputs,
        static_inputs). `unknown_inputs` is None when there are no unknown
        inputs and `static_inputs` is None when no static locations are
        configured.

    Raises:
        ValueError: If an observed input is also marked as known or static, or
            if the last dimension of `all_inputs` differs from `input_size`.
    """
    time_steps = self.time_steps
    hidden_size = self.hidden_layer_size
    obs_loc = self._input_obs_loc
    static_loc = self._static_input_loc
    known_regular_idx = self._known_regular_input_idx
    known_categorical_idx = self._known_categorical_input_idx

    # Sanity checks
    if any(i in obs_loc for i in known_regular_idx):
        raise ValueError("Observation cannot be known a priori!")
    if any(i in static_loc for i in obs_loc):
        raise ValueError("Observation cannot be static!")
    if all_inputs.get_shape().as_list()[-1] != self.input_size:
        raise ValueError(
            "Illegal number of inputs! Inputs observed={}, expected={}".format(
                all_inputs.get_shape().as_list()[-1], self.input_size
            )
        )

    num_categorical_variables = len(self.category_counts)
    num_regular_variables = self.input_size - num_categorical_variables

    # One embedding model per categorical variable, all of hidden-layer width.
    embeddings = []
    for i in range(num_categorical_variables):
        embeddings.append(
            tf.keras.Sequential(
                [
                    tf.keras.layers.InputLayer([time_steps]),
                    tf.keras.layers.Embedding(
                        self.category_counts[i], hidden_size, input_length=time_steps, dtype=tf.float32
                    ),
                ]
            )
        )

    regular_inputs = all_inputs[:, :, :num_regular_variables]
    categorical_inputs = all_inputs[:, :, num_regular_variables:]

    embedded_inputs = [embeddings[i](categorical_inputs[Ellipsis, i]) for i in range(num_categorical_variables)]

    # Static inputs: taken from the first time step only.
    if static_loc:
        static_regular = [
            tf.keras.layers.Dense(hidden_size)(regular_inputs[:, 0, i : i + 1])
            for i in range(num_regular_variables)
            if i in static_loc
        ]
        static_categorical = [
            embedded_inputs[i][:, 0, :]
            for i in range(num_categorical_variables)
            if i + num_regular_variables in static_loc
        ]
        static_inputs = tf.keras.backend.stack(static_regular + static_categorical, axis=1)
    else:
        static_inputs = None

    def convert_real_to_embedding(x):
        """Applies linear transformation for time-varying inputs."""
        return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_size))(x)

    # Targets
    obs_inputs = tf.keras.backend.stack(
        [convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) for i in obs_loc], axis=-1
    )

    # Observed (a priori unknown) inputs
    wired_embeddings = [
        embeddings[i](categorical_inputs[:, :, i])
        for i in range(num_categorical_variables)
        if i not in known_categorical_idx and i + num_regular_variables not in obs_loc
    ]
    unknown_regular = [
        convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
        for i in range(regular_inputs.shape[-1])
        if i not in known_regular_idx and i not in obs_loc
    ]
    unknown_parts = unknown_regular + wired_embeddings
    if unknown_parts:
        unknown_inputs = tf.keras.backend.stack(unknown_parts, axis=-1)
    else:
        unknown_inputs = None

    # A priori known inputs (static ones are excluded)
    known_regular_inputs = [
        convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
        for i in known_regular_idx
        if i not in static_loc
    ]
    known_categorical_inputs = [
        embedded_inputs[i] for i in known_categorical_idx if i + num_regular_variables not in static_loc
    ]
    known_combined_layer = tf.keras.backend.stack(known_regular_inputs + known_categorical_inputs, axis=-1)

    return unknown_inputs, known_combined_layer, obs_inputs, static_inputs
def _get_single_col_by_type(self, input_type):
    """Looks up the single column name for `input_type` in the column definition."""
    column_definition = self.column_definition
    return utils.get_single_col_by_input_type(input_type, column_definition)
def training_data_cached(self):
    """Returns True when both the "train" and "valid" entries are in the cache."""
    return all(TFTDataCache.contains(key) for key in ("train", "valid"))
def cache_batched_data(self, data, cache_key, num_samples=-1):
    """Batches data once and stores it in the cache for use during training.

    Args:
        data: Data to batch and cache
        cache_key: Key used for cache
        num_samples: Maximum number of samples to extract (-1 to use all data)
    """
    if num_samples > 0:
        batched = self._batch_sampled_data(data, max_samples=num_samples)
    else:
        batched = self._batch_data(data)
    TFTDataCache.update(batched, cache_key)
    print('Cached data "{}" updated'.format(cache_key))
def _batch_sampled_data(self, data, max_samples):
    """Samples segments into a compatible format.

    Every window of `self.time_steps` consecutive rows within one entity is a
    candidate segment; at most `max_samples` of them are drawn without
    replacement.

    Note: `data` is sorted in place by (id, time). The returned arrays always
    have `max_samples` rows; when fewer segments are available the trailing
    rows keep their initial (zero / empty) values.

    Args:
        data: Sources data to sample and batch
        max_samples: Maximum number of samples in batch

    Returns:
        Dictionary of batched data with the maximum samples specified.

    Raises:
        ValueError: If `max_samples` is less than 1.
    """
    if max_samples < 1:
        raise ValueError("Illegal number of samples specified! samples={}".format(max_samples))

    id_col = self._get_single_col_by_type(InputTypes.ID)
    time_col = self._get_single_col_by_type(InputTypes.TIME)
    target_col = self._get_single_col_by_type(InputTypes.TARGET)
    input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]

    data.sort_values(by=[id_col, time_col], inplace=True)

    print("Getting valid sampling locations.")
    valid_sampling_locations = []
    split_data_map = {}
    for identifier, df in data.groupby(id_col):
        print("Getting locations for {}".format(identifier))
        num_entries = len(df)
        if num_entries >= self.time_steps:
            # Each location is the exclusive end index of one window.
            valid_sampling_locations += [
                (identifier, self.time_steps + i) for i in range(num_entries - self.time_steps + 1)
            ]
        split_data_map[identifier] = df

    inputs = np.zeros((max_samples, self.time_steps, self.input_size))
    outputs = np.zeros((max_samples, self.time_steps, self.output_size))
    time = np.empty((max_samples, self.time_steps, 1), dtype=object)
    identifiers = np.empty((max_samples, self.time_steps, 1), dtype=object)

    if len(valid_sampling_locations) > max_samples:
        print("Extracting {} samples...".format(max_samples))
        ranges = [
            valid_sampling_locations[i]
            for i in np.random.choice(len(valid_sampling_locations), max_samples, replace=False)
        ]
    else:
        print("Max samples={} exceeds # available segments={}".format(max_samples, len(valid_sampling_locations)))
        ranges = valid_sampling_locations

    for i, (identifier, start_idx) in enumerate(ranges):
        # Parenthesised on purpose: `i + 1 % 1000` would parse as `i + (1 % 1000)`.
        if (i + 1) % 1000 == 0:
            print(i + 1, "of", max_samples, "samples done...")
        sliced = split_data_map[identifier].iloc[start_idx - self.time_steps : start_idx]
        inputs[i, :, :] = sliced[input_cols]
        outputs[i, :, :] = sliced[[target_col]]
        time[i, :, 0] = sliced[time_col]
        identifiers[i, :, 0] = sliced[id_col]

    # Targets are only kept for the decoder steps.
    decoder_outputs = outputs[:, self.num_encoder_steps :, :]
    sampled_data = {
        "inputs": inputs,
        "outputs": decoder_outputs,
        "active_entries": np.ones_like(decoder_outputs),
        "time": time,
        "identifier": identifiers,
    }
    return sampled_data
def _batch_data(self, data):
    """Batches data for training.

    Converts raw dataframe from a 2-D tabular format to a batched 3-D array
    to feed into Keras model. Each entity (group of the ID column) contributes
    one window per run of `self.time_steps` consecutive rows; entities with
    fewer rows are skipped.

    Args:
        data: DataFrame to batch

    Returns:
        Dictionary with keys "identifier", "time", "outputs", "inputs" and
        "active_entries". "inputs" has shape=(?, self.time_steps,
        self.input_size); "outputs" and "active_entries" only cover the steps
        after `self.num_encoder_steps`.

    Raises:
        ValueError: If no entity has at least `self.time_steps` rows.
    """
    lags = self.time_steps

    def _batch_single_entity(input_data):
        """Returns all sliding windows of length `lags`, or None if too short."""
        num_rows = len(input_data)
        if num_rows < lags:
            return None
        x = input_data.values
        return np.stack([x[i : num_rows - (lags - 1) + i, :] for i in range(lags)], axis=1)

    id_col = self._get_single_col_by_type(InputTypes.ID)
    time_col = self._get_single_col_by_type(InputTypes.TIME)
    target_col = self._get_single_col_by_type(InputTypes.TARGET)
    input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]

    # Output key -> source columns; identical for every entity.
    col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}

    windows = {key: [] for key in col_mappings}
    for _, sliced in data.groupby(id_col):
        for key, cols in col_mappings.items():
            arr = _batch_single_entity(sliced[cols])
            # Entities shorter than one full window yield None and are dropped.
            if arr is not None:
                windows[key].append(arr)

    if not windows["inputs"]:
        raise ValueError(
            "Cannot batch data: no entity has at least {} time steps".format(lags)
        )

    # Combine all data
    data_map = {key: np.concatenate(arrays, axis=0) for key, arrays in windows.items()}

    # Shorten target so we only get decoder steps
    data_map["outputs"] = data_map["outputs"][:, self.num_encoder_steps :, :]

    # Every retained target position is marked active.
    data_map["active_entries"] = np.ones_like(data_map["outputs"])

    return data_map
def _get_active_locations(self, x):
"""Formats sample weights for Keras training."""
return (np.sum(x, axis=-1) > 0.0) * 1.0
def _build_base_graph(self):
    """Returns graph defining layers of the TFT.

    Builds, in order: input embeddings, static variable selection and the four
    static context vectors, temporal variable selection for historical and
    future inputs, the LSTM encoder/decoder, static enrichment, masked
    self-attention and the final gated skip connections.

    Returns:
        Tuple of (transformer_layer, all_inputs, attention_components) where
        `all_inputs` is the Keras input tensor and `attention_components` is a
        dict of tensors kept for explainability.

    Raises:
        ValueError: If no static inputs are configured; the static context
            vectors used throughout the graph cannot be built without them.
    """
    # Size definitions.
    time_steps = self.time_steps
    combined_input_size = self.input_size
    encoder_steps = self.num_encoder_steps

    # Inputs.
    all_inputs = tf.keras.layers.Input(
        shape=(
            time_steps,
            combined_input_size,
        )
    )

    unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)

    # Fail with a clear message instead of an AttributeError on None below.
    if static_inputs is None:
        raise ValueError("TFT requires at least one static input, but static_input_loc is empty")

    # Isolate known and observed historical inputs.
    if unknown_inputs is not None:
        historical_inputs = concat(
            [
                unknown_inputs[:, :encoder_steps, :],
                known_combined_layer[:, :encoder_steps, :],
                obs_inputs[:, :encoder_steps, :],
            ],
            axis=-1,
        )
    else:
        historical_inputs = concat(
            [known_combined_layer[:, :encoder_steps, :], obs_inputs[:, :encoder_steps, :]], axis=-1
        )

    # Isolate only known future inputs.
    future_inputs = known_combined_layer[:, encoder_steps:, :]

    def static_combine_and_mask(embedding):
        """Applies variable selection network to static inputs.

        Args:
            embedding: Transformed static inputs

        Returns:
            Tuple of (combined static vector, variable selection weights)
        """
        _, num_static, _ = embedding.get_shape().as_list()

        flatten = tf.keras.layers.Flatten()(embedding)

        # Nonlinear transformation with gated residual network.
        mlp_outputs = gated_residual_network(
            flatten,
            self.hidden_layer_size,
            output_size=num_static,
            dropout_rate=self.dropout_rate,
            use_time_distributed=False,
            additional_context=None,
        )

        # One softmax weight per static variable.
        sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
        sparse_weights = K.expand_dims(sparse_weights, axis=-1)

        # Each static variable is transformed by its own GRN.
        trans_emb_list = []
        for i in range(num_static):
            e = gated_residual_network(
                embedding[:, i : i + 1, :],
                self.hidden_layer_size,
                dropout_rate=self.dropout_rate,
                use_time_distributed=False,
            )
            trans_emb_list.append(e)

        transformed_embedding = concat(trans_emb_list, axis=1)

        # Weighted sum over the static variables.
        combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
        static_vec = K.sum(combined, axis=1)

        return static_vec, sparse_weights

    static_encoder, static_weights = static_combine_and_mask(static_inputs)

    # Four separate context vectors derived from the same static encoding.
    static_context_variable_selection = gated_residual_network(
        static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
    )
    static_context_enrichment = gated_residual_network(
        static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
    )
    static_context_state_h = gated_residual_network(
        static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
    )
    static_context_state_c = gated_residual_network(
        static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
    )

    def lstm_combine_and_mask(embedding):
        """Apply temporal variable selection networks.

        Args:
            embedding: Transformed inputs.

        Returns:
            Tuple of (selected temporal features, variable selection weights,
            gate of the selection GRN).
        """
        # Local sizes come from the embedding itself (history and future differ).
        _, time_steps, embedding_dim, num_inputs = embedding.get_shape().as_list()

        flatten = K.reshape(embedding, [-1, time_steps, embedding_dim * num_inputs])

        expanded_static_context = K.expand_dims(static_context_variable_selection, axis=1)

        # Variable selection weights
        mlp_outputs, static_gate = gated_residual_network(
            flatten,
            self.hidden_layer_size,
            output_size=num_inputs,
            dropout_rate=self.dropout_rate,
            use_time_distributed=True,
            additional_context=expanded_static_context,
            return_gate=True,
        )

        sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
        sparse_weights = tf.expand_dims(sparse_weights, axis=2)

        # Non-linear Processing & weight application
        trans_emb_list = []
        for i in range(num_inputs):
            grn_output = gated_residual_network(
                embedding[Ellipsis, i],
                self.hidden_layer_size,
                dropout_rate=self.dropout_rate,
                use_time_distributed=True,
            )
            trans_emb_list.append(grn_output)

        transformed_embedding = stack(trans_emb_list, axis=-1)

        # Weighted sum over the input variables (last axis).
        combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
        temporal_ctx = K.sum(combined, axis=-1)

        return temporal_ctx, sparse_weights, static_gate

    historical_features, historical_flags, _ = lstm_combine_and_mask(historical_inputs)
    future_features, future_flags, _ = lstm_combine_and_mask(future_inputs)

    # LSTM layer
    def get_lstm(return_state):
        """Returns LSTM cell initialized with default parameters."""
        if self.use_cudnn:
            lstm = tf.keras.layers.CuDNNLSTM(
                self.hidden_layer_size,
                return_sequences=True,
                return_state=return_state,
                stateful=False,
            )
        else:
            lstm = tf.keras.layers.LSTM(
                self.hidden_layer_size,
                return_sequences=True,
                return_state=return_state,
                stateful=False,
                # Additional params to ensure LSTM matches CuDNN, See TF 2.0 :
                # (https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM)
                activation="tanh",
                recurrent_activation="sigmoid",
                recurrent_dropout=0,
                unroll=False,
                use_bias=True,
            )
        return lstm

    # Encoder starts from the static context; decoder starts from encoder state.
    history_lstm, state_h, state_c = get_lstm(return_state=True)(
        historical_features, initial_state=[static_context_state_h, static_context_state_c]
    )

    future_lstm = get_lstm(return_state=False)(future_features, initial_state=[state_h, state_c])

    lstm_layer = concat([history_lstm, future_lstm], axis=1)

    # Apply gated skip connection
    input_embeddings = concat([historical_features, future_features], axis=1)

    lstm_layer, _ = apply_gating_layer(lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)
    temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings])

    # Static enrichment layers
    expanded_static_context = K.expand_dims(static_context_enrichment, axis=1)
    enriched, _ = gated_residual_network(
        temporal_feature_layer,
        self.hidden_layer_size,
        dropout_rate=self.dropout_rate,
        use_time_distributed=True,
        additional_context=expanded_static_context,
        return_gate=True,
    )

    # Decoder self attention
    self_attn_layer = InterpretableMultiHeadAttention(
        self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate
    )

    mask = get_decoder_mask(enriched)
    x, self_att = self_attn_layer(enriched, enriched, enriched, mask=mask)

    x, _ = apply_gating_layer(x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None)
    x = add_and_norm([x, enriched])

    # Nonlinear processing on outputs
    decoder = gated_residual_network(
        x, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=True
    )

    # Final skip connection
    decoder, _ = apply_gating_layer(decoder, self.hidden_layer_size, activation=None)
    transformer_layer = add_and_norm([decoder, temporal_feature_layer])

    # Attention components for explainability
    attention_components = {
        # Temporal attention weights
        "decoder_self_attn": self_att,
        # Static variable selection weights
        "static_flags": static_weights[Ellipsis, 0],
        # Variable selection weights of past inputs
        "historical_flags": historical_flags[Ellipsis, 0, :],
        # Variable selection weights of future inputs
        "future_flags": future_flags[Ellipsis, 0, :],
    }

    return transformer_layer, all_inputs, attention_components
def build_model(self):
    """Build model and defines training losses.

    Adds a per-time-step dense output layer (one output block per quantile) on
    top of the decoder steps of the base graph and compiles the model with
    the combined quantile loss.

    Returns:
        Fully defined Keras model.
    """
    with tf.variable_scope(self.name):
        transformer_layer, all_inputs, attention_components = self._build_base_graph()

        # Only the steps after the encoder are projected to outputs.
        output_projection = tf.keras.layers.TimeDistributed(
            tf.keras.layers.Dense(self.output_size * len(self.quantiles))
        )
        outputs = output_projection(transformer_layer[Ellipsis, self.num_encoder_steps :, :])

        self._attention_components = attention_components

        adam = tf.keras.optimizers.Adam(lr=self.learning_rate, clipnorm=self.max_gradient_norm)

        model = tf.keras.Model(inputs=all_inputs, outputs=outputs)

        print(model.summary())

        valid_quantiles = self.quantiles
        output_size = self.output_size

        class QuantileLossCalculator:
            """Computes the combined quantile loss for prespecified quantiles.

            Attributes:
                quantiles: Quantiles to compute losses
            """

            def __init__(self, quantiles):
                """Initializes computer with quantiles for loss calculations.

                Args:
                    quantiles: Quantiles to use for computations.
                """
                self.quantiles = quantiles

            def quantile_loss(self, a, b):
                """Returns quantile loss for specified quantiles.

                Args:
                    a: Targets
                    b: Predictions
                """
                selected = set(self.quantiles)
                loss = 0.0
                for idx, quantile in enumerate(valid_quantiles):
                    if quantile not in selected:
                        continue
                    # Each quantile owns one contiguous block of the last axis.
                    begin = output_size * idx
                    end = output_size * (idx + 1)
                    loss += utils.tensorflow_quantile_loss(
                        a[Ellipsis, begin:end],
                        b[Ellipsis, begin:end],
                        quantile,
                    )
                return loss

        quantile_loss = QuantileLossCalculator(valid_quantiles).quantile_loss

        model.compile(loss=quantile_loss, optimizer=adam, sample_weight_mode="temporal")

        self._input_placeholder = all_inputs

    return model
def fit(self, train_df=None, valid_df=None):
    """Fits deep neural network for given training and validation data.

    When a dataframe is omitted, the corresponding cached data ("train" /
    "valid") is used instead. After training, the best checkpoint written to
    the temp folder is reloaded if it exists.

    Args:
        train_df: DataFrame for training data
        valid_df: DataFrame for validation data
    """
    print("*** Fitting {} ***".format(self.name))

    # Add relevant callbacks
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=self.early_stopping_patience, min_delta=1e-4),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=self.get_keras_saved_path(self._temp_folder),
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
        ),
        tf.keras.callbacks.TerminateOnNaN(),
    ]

    print("Getting batched_data")
    if train_df is None:
        print("Using cached training data")
        train_data = TFTDataCache.get("train")
    else:
        train_data = self._batch_data(train_df)

    if valid_df is None:
        print("Using cached validation data")
        valid_data = TFTDataCache.get("valid")
    else:
        valid_data = self._batch_data(valid_df)

    print("Using keras standard fit")

    # The model emits one output block per quantile, so the targets are
    # repeated once per quantile along the last axis.
    num_quantiles = len(self.quantiles)

    def _unpack(data):
        """Returns (inputs, quantile-tiled targets, temporal sample weights)."""
        targets = np.concatenate([data["outputs"]] * num_quantiles, axis=-1)
        return data["inputs"], targets, self._get_active_locations(data["active_entries"])

    data, labels, active_flags = _unpack(train_data)
    val_data, val_labels, val_flags = _unpack(valid_data)

    self.model.fit(
        x=data,
        y=labels,
        sample_weight=active_flags,
        epochs=self.num_epochs,
        batch_size=self.minibatch_size,
        validation_data=(val_data, val_labels, val_flags),
        callbacks=callbacks,
        shuffle=True,
        use_multiprocessing=True,
        workers=self.n_multiprocessing_workers,
    )

    # Load best checkpoint again
    tmp_checkpoint = self.get_keras_saved_path(self._temp_folder)
    if os.path.exists(tmp_checkpoint):
        self.load(self._temp_folder, use_keras_loadings=True)
    else:
        print("Cannot load from {}, skipping ...".format(self._temp_folder))
def evaluate(self, data=None, eval_metric="loss"):
    """Computes an evaluation metric on the given data.

    Args:
      data: Dataframe for evaluation. If None, the batch stored under "valid"
        in ``TFTDataCache`` is used instead.
      eval_metric: Name of the metric to return; it is looked up among
        ``self.model.metrics_names``.

    Returns:
      Computed value of the requested metric.
    """
    if data is None:
        print("Using cached validation data")
        raw_data = TFTDataCache.get("valid")
    else:
        raw_data = self._batch_data(data)

    inputs = raw_data["inputs"]
    outputs = raw_data["outputs"]
    active_entries = self._get_active_locations(raw_data["active_entries"])

    # Repeat the targets once per quantile along the last axis so they line up
    # with the per-quantile output slices, instead of assuming three quantiles.
    tiled_outputs = np.concatenate([outputs] * len(self.quantiles), axis=-1)

    metric_values = self.model.evaluate(
        x=inputs,
        y=tiled_outputs,
        sample_weight=active_entries,
        workers=16,
        use_multiprocessing=True,
    )

    # Label the returned value(s) with the model's metric names for lookup.
    metrics = pd.Series(metric_values, self.model.metrics_names)
    return metrics[eval_metric]
def predict(self, df, return_targets=False):
    """Computes predictions for a given input dataset.

    Args:
      df: Input dataframe
      return_targets: Whether to also return outputs aligned with predictions to
        facilitate evaluation

    Returns:
      Dictionary of dataframes keyed by quantile label ("p10", "p50", ...),
      plus a "targets" entry when ``return_targets`` is set. Each dataframe has
      "forecast_time" and "identifier" columns followed by one "t+i" column
      per decoder step.

    Raises:
      NotImplementedError: If ``self.output_size`` is not 1.
    """
    # Fail fast: reject unsupported configurations before paying for batching
    # and model inference.
    if self.output_size != 1:
        raise NotImplementedError("Current version only supports 1D targets!")

    data = self._batch_data(df)
    inputs = data["inputs"]
    time = data["time"]
    identifier = data["identifier"]
    outputs = data["outputs"]

    combined = self.model.predict(inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size)

    # One column per step after the encoder steps.
    horizon_cols = ["t+{}".format(i) for i in range(self.time_steps - self.num_encoder_steps)]

    def format_outputs(prediction):
        """Returns formatted dataframes for prediction."""
        flat_prediction = pd.DataFrame(prediction[:, :, 0], columns=horizon_cols)
        # The forecast time is the time index at the last encoder step.
        flat_prediction["forecast_time"] = time[:, self.num_encoder_steps - 1, 0]
        flat_prediction["identifier"] = identifier[:, 0, 0]

        # Arrange in order
        return flat_prediction[["forecast_time", "identifier"] + horizon_cols]

    # Extract predictions for each quantile into different entries.
    # round() guards against float truncation (e.g. 0.29 * 100 -> 28.99...).
    process_map = {
        "p{}".format(int(round(q * 100))): combined[Ellipsis, i * self.output_size : (i + 1) * self.output_size]
        for i, q in enumerate(self.quantiles)
    }

    if return_targets:
        # Add targets if relevant
        process_map["targets"] = outputs

    return {k: format_outputs(process_map[k]) for k in process_map}
def get_attention(self, df):
    """Extracts the TFT attention weights for a dataset.

    Args:
      df: Input dataframe

    Returns:
      Dictionary with one numpy array per key of ``self._attention_components``,
      plus "identifiers" and "time" entries taken from the batched data.
    """
    batched = self._batch_data(df)
    inputs = batched["inputs"]
    identifiers = batched["identifier"]
    time = batched["time"]

    def get_batch_attention_weights(input_batch):
        """Evaluates every attention component on a single minibatch."""
        feed_target = self._input_placeholder
        return {
            name: tf.keras.backend.get_session().run(
                self._attention_components[name], {feed_target: input_batch.astype(np.float32)}
            )
            for name in self._attention_components
        }

    # Number of minibatches, counting a trailing partial batch if there is one.
    batch_size = self.minibatch_size
    n = inputs.shape[0]
    num_batches = n // batch_size
    if n % batch_size > 0:
        num_batches += 1

    # Evaluate the attention components one minibatch at a time.
    attention_by_batch = []
    for i in range(num_batches):
        chunk = inputs[i * batch_size : (i + 1) * batch_size, Ellipsis]
        attention_by_batch.append(get_batch_attention_weights(chunk))

    attention_weights = {}
    for name in self._attention_components:
        pieces = [batch_weights[name] for batch_weights in attention_by_batch]

        # Rank-4 results are joined along axis 1, everything else along axis 0
        # (presumably the sample axis comes second for rank-4 tensors - TODO confirm).
        axis = 1 if len(pieces[0].shape) == 4 else 0
        merged = np.concatenate(pieces, axis=axis)

        # Drop the per-batch list before storing the merged array.
        del pieces
        gc.collect()
        attention_weights[name] = merged

    attention_weights["identifiers"] = identifiers[:, 0, 0]
    attention_weights["time"] = time[:, :, 0]

    return attention_weights
# Serialisation.
def reset_temp_folder(self):
    """Wipes the folder holding temporary Keras training outputs and recreates it empty."""
    print("Resetting temp folder...")
    folder = self._temp_folder
    # Create the folder first (presumably so the removal below never hits a
    # missing path), then delete it with its contents and recreate it.
    utils.create_folder_if_not_exist(folder)
    shutil.rmtree(folder)
    os.makedirs(folder)
def get_keras_saved_path(self, model_folder):
    """Builds the Keras checkpoint path, "<model_folder>/<name>.check"."""
    checkpoint_name = "{}.check".format(self.name)
    return os.path.join(model_folder, checkpoint_name)
def save(self, model_folder):
    """Saves optimal TFT weights.

    Args:
      model_folder: Location to serialize model.
    """
    # The TensorFlow variables are serialised directly rather than through
    # Keras, to avoid a spurious Keras issue that leads to different
    # performance evaluation results when the model is reloaded
    # (https://github.com/keras-team/keras/issues/4875).
    session = tf.keras.backend.get_session()
    utils.save(session, model_folder, cp_name=self.name, scope=self.name)
def load(self, model_folder, use_keras_loadings=False):
    """Restores TFT weights from disk.

    Args:
      model_folder: Folder containing serialized models.
      use_keras_loadings: Whether to load from the Keras checkpoint instead of
        the directly serialised TensorFlow variables.
    """
    if not use_keras_loadings:
        # Loads tensorflow graph for optimal models.
        session = tf.keras.backend.get_session()
        utils.load(session, model_folder, cp_name=self.name, scope=self.name)
        return

    # Loads temporary Keras model saved during training.
    serialisation_path = self.get_keras_saved_path(model_folder)
    print("Loading model from {}".format(serialisation_path))
    self.model.load_weights(serialisation_path)
@classmethod
def get_hyperparm_choices(cls):
    """Returns the candidate values of each hyperparameter for random search."""
    search_space = dict(
        dropout_rate=[0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9],
        hidden_layer_size=[10, 20, 40, 80, 160, 240, 320],
        minibatch_size=[64, 128, 256],
        learning_rate=[1e-4, 1e-3, 1e-2],
        max_gradient_norm=[0.01, 1.0, 100.0],
        num_heads=[1, 4],
        stack_size=[1],
    )
    return search_space