1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
Files
qlib/examples/benchmarks/TFT/libs/utils.py
2020-11-25 20:36:28 +08:00

225 lines
7.0 KiB
Python

# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Generic helper functions used across codebase."""
import os
import pathlib
import numpy as np
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# Generic.
def get_single_col_by_input_type(input_type, column_definition):
    """Returns the name of the one column declared with a given input type.

    Args:
      input_type: Input type of column to extract
      column_definition: Column definition list for experiment. Each entry is
        indexed positionally: element 0 is the column name and element 2 is
        the input type it is compared against.

    Returns:
      The name (element 0) of the single matching entry.

    Raises:
      ValueError: If no entry, or more than one entry, has `input_type`.
    """
    matching_names = []
    for entry in column_definition:
        if entry[2] == input_type:
            matching_names.append(entry[0])

    # Exactly one match is required; anything else is a definition error.
    if len(matching_names) == 1:
        return matching_names[0]

    raise ValueError("Invalid number of columns for {}".format(input_type))
def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
    """Extracts the names of columns that correspond to a defined data_type.

    Args:
      data_type: DataType of columns to extract.
      column_definition: Column definition to use. Each entry is indexed
        positionally as (name, data type, input type).
      excluded_input_types: Set of input types to exclude

    Returns:
      List of names, in definition order, for columns whose data type equals
      `data_type` and whose input type is not in `excluded_input_types`.
    """
    selected_names = []
    for entry in column_definition:
        has_wanted_type = entry[1] == data_type
        if has_wanted_type and entry[2] not in excluded_input_types:
            selected_names.append(entry[0])
    return selected_names
# Loss functions.
def tensorflow_quantile_loss(y, y_pred, quantile):
    """Computes quantile loss for tensorflow.

    Standard quantile loss as defined in the "Training Procedure" section of
    the main TFT paper. Positive errors (target above prediction) are weighted
    by `quantile`, negative errors by `1 - quantile`, and the weighted errors
    are summed over the last axis.

    Args:
      y: Targets
      y_pred: Predictions
      quantile: Quantile to use for loss calculations (between 0 & 1)

    Returns:
      Tensor for quantile loss.

    Raises:
      ValueError: If `quantile` is below 0 or above 1.
    """
    # Reject out-of-range quantiles before any graph ops are created.
    if quantile < 0 or quantile > 1:
        raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))

    error = y - y_pred
    under_prediction = tf.maximum(error, 0.0)  # target above prediction
    over_prediction = tf.maximum(-error, 0.0)  # target below prediction
    weighted = quantile * under_prediction + (1.0 - quantile) * over_prediction

    return tf.reduce_sum(weighted, axis=-1)
def numpy_normalised_quantile_loss(y, y_pred, quantile):
    """Computes normalised quantile loss for numpy arrays.

    Uses the q-Risk metric as defined in the "Training Procedure" section of the
    main TFT paper: twice the mean quantile-weighted error divided by the mean
    absolute target.

    Args:
      y: Targets (numpy array; pandas objects that support `-`, `np.maximum`,
        `np.abs` and `.mean()` also work).
      y_pred: Predictions, broadcastable against `y`.
      quantile: Quantile to use for loss calculations (between 0 & 1). It is not
        range-checked here.

    Returns:
      Float for normalised quantile loss when given numpy arrays. For other
      array-likes the result is whatever their own `.mean()` reductions yield.
    """
    prediction_underflow = y - y_pred
    weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
        -prediction_underflow, 0.0
    )

    quantile_loss = weighted_errors.mean()
    # np.abs rather than y.abs(): numpy.ndarray has no .abs() method, so the
    # method form raised AttributeError for the numpy inputs this function is
    # documented to accept.
    normaliser = np.abs(y).mean()

    return 2 * quantile_loss / normaliser
# OS related functions.
def create_folder_if_not_exist(directory):
    """Ensures that a folder exists, creating it when necessary.

    Missing parent folders are created as well, and an already existing folder
    is left untouched.

    Args:
      directory: Folder path to create.
    """
    target = pathlib.Path(directory)
    # parents=True builds the whole chain; exist_ok=True makes repeat calls safe.
    target.mkdir(parents=True, exist_ok=True)
# Tensorflow related functions.
def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
    """Creates tensorflow config for graphs to run on CPU or GPU.

    Besides building the config, this sets CUDA-related environment variables
    for the current process: on the CPU path `CUDA_VISIBLE_DEVICES` becomes
    "-1"; on the GPU path `CUDA_DEVICE_ORDER` becomes "PCI_BUS_ID" and
    `CUDA_VISIBLE_DEVICES` becomes `gpu_id`.

    Args:
      tf_device: 'cpu' or 'gpu'. Any value other than 'cpu' takes the GPU path.
      gpu_id: GPU ID to use if relevant

    Returns:
      Tensorflow config.
    """
    if tf_device != "cpu":
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

        print("Selecting GPU ID={}".format(gpu_id))

        gpu_config = tf.ConfigProto(log_device_placement=False)
        gpu_config.gpu_options.allow_growth = True
        return gpu_config

    # CPU path: hide every CUDA device and request zero GPU devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # for training on cpu
    return tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
def save(tf_session, model_folder, cp_name, scope=None):
    """Saves Tensorflow graph to checkpoint.

    With no scope, a default `tf.train.Saver()` is used. With a scope, only the
    trainable variables collected under that variable scope are saved.

    Args:
      tf_session: Session containing graph
      model_folder: Folder to save models
      cp_name: Name of Tensorflow checkpoint; ".ckpt" is appended to it.
      scope: Variable scope containing variables to save
    """
    if scope is not None:
        # NOTE(review): load() in this module restores GLOBAL_VARIABLES for a
        # scope while only TRAINABLE_VARIABLES are saved here - confirm that
        # the asymmetry is intended.
        scoped_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
        saver = tf.train.Saver(var_list=scoped_vars, max_to_keep=100000)
    else:
        saver = tf.train.Saver()

    checkpoint_file = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
    save_path = saver.save(tf_session, checkpoint_file)
    print("Model saved to: {0}".format(save_path))
def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
    """Loads Tensorflow graph from checkpoint.

    Prints the checkpoint contents first, then restores the variables into the
    given session. With a scope, only the global variables collected under that
    scope are restored; otherwise a default `tf.train.Saver()` is used.

    Args:
      tf_session: Session to load graph into
      model_folder: Folder containing serialised model
      cp_name: Name of Tensorflow checkpoint; ".ckpt" is appended to it.
      scope: Variable scope to use.
      verbose: Whether to print additional debugging information (graph node
        names before and after the restore).
    """
    checkpoint_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
    print("Loading model from {0}".format(checkpoint_path))

    print_weights_in_checkpoint(model_folder, cp_name)

    # Snapshot of graph node names taken before the Saver is built.
    nodes_before = {node.name for node in tf.get_default_graph().as_graph_def().node}

    if scope is not None:
        scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        saver = tf.train.Saver(var_list=scoped_vars, max_to_keep=100000)
    else:
        saver = tf.train.Saver()

    saver.restore(tf_session, checkpoint_path)

    # Second snapshot, taken after the restore.
    nodes_after = {node.name for node in tf.get_default_graph().as_graph_def().node}

    if verbose:
        only_before = nodes_before.difference(nodes_after)
        only_after = nodes_after.difference(nodes_before)
        print("Restored {0}".format(",".join(only_before)))
        print("Existing {0}".format(",".join(only_after)))
        print("All {0}".format(",".join(nodes_after)))

    print("Done.")
def print_weights_in_checkpoint(model_folder, cp_name):
    """Prints all weights in Tensorflow checkpoint.

    Args:
      model_folder: Folder containing checkpoint
      cp_name: Name of checkpoint; ".ckpt" is appended to it.

    Returns:
      None. The tensors are printed by `print_tensors_in_checkpoint_file`.
    """
    checkpoint_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
    # An empty tensor_name with both all_* flags set asks for every tensor.
    print_tensors_in_checkpoint_file(
        file_name=checkpoint_path,
        tensor_name="",
        all_tensors=True,
        all_tensor_names=True,
    )