mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
225 lines
7.0 KiB
Python
225 lines
7.0 KiB
Python
# coding=utf-8
|
|
# Copyright 2020 The Google Research Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Lint as: python3
|
|
"""Generic helper functions used across codebase."""
|
|
|
|
import os
|
|
import pathlib
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
|
|
|
|
|
|
# Generic.
|
|
def get_single_col_by_input_type(input_type, column_definition):
    """Finds the one column whose input type equals `input_type`.

    Args:
      input_type: Input type of the column to look up.
      column_definition: Column definition list for the experiment. Entries are
        indexed positionally: element 0 is read as the column name and element
        2 as its input type.

    Returns:
      Name of the single matching column.

    Raises:
      ValueError: If zero columns, or more than one column, match `input_type`.
    """
    matching_names = []
    for entry in column_definition:
        if entry[2] == input_type:
            matching_names.append(entry[0])

    # Exactly one hit is the only valid outcome.
    if len(matching_names) == 1:
        return matching_names[0]

    raise ValueError("Invalid number of columns for {}".format(input_type))
|
|
|
|
|
|
def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
    """Collects the names of columns that have a given data type.

    Args:
      data_type: DataType of columns to extract.
      column_definition: Column definition to use. Entries are indexed
        positionally: element 0 is read as the name, element 1 as the data type
        and element 2 as the input type.
      excluded_input_types: Container of input types to leave out.

    Returns:
      List of column names, in definition order, whose data type equals
      `data_type` and whose input type is not in `excluded_input_types`.
    """
    selected = []
    for entry in column_definition:
        if entry[1] == data_type:
            if entry[2] not in excluded_input_types:
                selected.append(entry[0])
    return selected
|
|
|
|
|
|
# Loss functions.
|
|
def tensorflow_quantile_loss(y, y_pred, quantile):
    """Builds the quantile (pinball) loss as a tensorflow expression.

    Standard quantile loss as defined in the "Training Procedure" section of
    the main TFT paper. Positive residuals (`y` above `y_pred`) are weighted by
    `quantile`, negative residuals by `1 - quantile`, and the result is summed
    over the last axis.

    Args:
      y: Targets
      y_pred: Predictions
      quantile: Quantile to use for loss calculations (between 0 & 1)

    Returns:
      Tensor for quantile loss.

    Raises:
      ValueError: If `quantile` is below 0 or above 1.
    """
    # Validate the quantile before any ops are created.
    if quantile < 0 or quantile > 1:
        raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))

    residual = y - y_pred
    under_prediction_term = quantile * tf.maximum(residual, 0.0)
    over_prediction_term = (1.0 - quantile) * tf.maximum(-residual, 0.0)

    return tf.reduce_sum(under_prediction_term + over_prediction_term, axis=-1)
|
|
|
|
|
|
def numpy_normalised_quantile_loss(y, y_pred, quantile):
    """Computes normalised quantile loss for numpy arrays.

    Uses the q-Risk metric as defined in the "Training Procedure" section of the
    main TFT paper: twice the mean quantile-weighted error divided by the mean
    absolute target.

    Args:
      y: Targets. A numpy array, or any object that supports subtraction,
        `np.maximum`, `np.abs` and `.mean()` (e.g. a pandas Series/DataFrame).
      y_pred: Predictions, broadcastable against `y`.
      quantile: Quantile to use for loss calculations (between 0 & 1)

    Returns:
      Normalised quantile loss. This is a float for numpy inputs; for other
      containers the result type follows whatever their `.mean()` returns.
    """
    prediction_underflow = y - y_pred
    weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
        -prediction_underflow, 0.0
    )

    quantile_loss = weighted_errors.mean()
    # np.abs is used instead of y.abs(): numpy arrays have no .abs() method,
    # so the method form raised AttributeError for the documented input type.
    normaliser = np.abs(y).mean()

    return 2 * quantile_loss / normaliser
|
|
|
|
|
|
# OS related functions.
|
|
def create_folder_if_not_exist(directory):
    """Ensures that `directory` exists on disk.

    Missing parent directories are created as well, and a directory that is
    already present is left as it is without raising.

    Args:
      directory: Folder path to create.
    """
    target = pathlib.Path(directory)
    target.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
# Tensorflow related functions.
|
|
def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
    """Builds a tensorflow session config for running on CPU or on one GPU.

    As a side effect this sets CUDA-related environment variables for the
    current process: `CUDA_VISIBLE_DEVICES` in both modes, and additionally
    `CUDA_DEVICE_ORDER` in GPU mode.

    Args:
      tf_device: 'cpu' or 'gpu'. Any value other than 'cpu' takes the GPU path.
      gpu_id: GPU ID to use if relevant

    Returns:
      Tensorflow config.
    """
    if tf_device != "cpu":
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

        print("Selecting GPU ID={}".format(gpu_id))

        gpu_config = tf.ConfigProto(log_device_placement=False)
        gpu_config.gpu_options.allow_growth = True
        return gpu_config

    # CPU mode: hide all CUDA devices and request zero GPUs in the config.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # for training on cpu
    return tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
|
|
|
|
|
|
def save(tf_session, model_folder, cp_name, scope=None):
    """Saves Tensorflow graph to checkpoint.

    Without a scope, a default saver is used. With a scope, all global
    variables under that variable scope are written to the checkpoint.

    Args:
      tf_session: Session containing graph
      model_folder: Folder to save models
      cp_name: Name of Tensorflow checkpoint; the file is written as
        `<model_folder>/<cp_name>.ckpt`.
      scope: Variable scope containing variables to save
    """
    # Save model
    if scope is None:
        saver = tf.train.Saver()
    else:
        # Use GLOBAL_VARIABLES, the same collection `load` restores for a scope.
        # Saving only TRAINABLE_VARIABLES would leave non-trainable variables
        # under the scope out of the checkpoint, so a later scoped restore
        # could not find them.
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)

    save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
    print("Model saved to: {0}".format(save_path))
|
|
|
|
|
|
def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
    """Restores variables from a Tensorflow checkpoint into a session.

    The checkpoint contents are always printed before restoring. With a scope,
    only the global variables under that variable scope are restored; without
    one, a default saver is used.

    Args:
      tf_session: Session to load graph into
      model_folder: Folder containing serialised model
      cp_name: Name of Tensorflow checkpoint; read from
        `<model_folder>/<cp_name>.ckpt`.
      scope: Variable scope to use.
      verbose: Whether to print additional debugging information (graph node
        names compared before and after the restore).
    """
    checkpoint_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))

    print("Loading model from {0}".format(checkpoint_path))

    print_weights_in_checkpoint(model_folder, cp_name)

    # Snapshot of graph node names taken before the saver is created.
    nodes_before = {node.name for node in tf.get_default_graph().as_graph_def().node}

    if scope is not None:
        scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
        saver = tf.train.Saver(var_list=scoped_vars, max_to_keep=100000)
    else:
        saver = tf.train.Saver()

    saver.restore(tf_session, checkpoint_path)

    # Second snapshot, taken after the saver exists and the restore has run.
    nodes_after = {node.name for node in tf.get_default_graph().as_graph_def().node}

    if verbose:
        # "Restored" lists names only in the first snapshot; "Existing" lists
        # names only in the second one.
        print("Restored {0}".format(",".join(nodes_before.difference(nodes_after))))
        print("Existing {0}".format(",".join(nodes_after.difference(nodes_before))))
        print("All {0}".format(",".join(nodes_after)))

    print("Done.")
|
|
|
|
|
|
def print_weights_in_checkpoint(model_folder, cp_name):
    """Prints every tensor stored in a Tensorflow checkpoint.

    Args:
      model_folder: Folder containing checkpoint
      cp_name: Name of checkpoint; the file inspected is
        `<model_folder>/<cp_name>.ckpt`.
    """
    checkpoint_file = "{0}.ckpt".format(cp_name)

    print_tensors_in_checkpoint_file(
        file_name=os.path.join(model_folder, checkpoint_file),
        tensor_name="",
        all_tensors=True,
        all_tensor_names=True,
    )
|