mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
304 lines
8.4 KiB
Python
304 lines
8.4 KiB
Python
from collections import namedtuple
|
|
from torch.nn.utils.rnn import pack_padded_sequence
|
|
from tianshou.data import Batch
|
|
import numpy as np
|
|
import torch
|
|
import copy
|
|
from typing import Union, Optional
|
|
from numbers import Number
|
|
|
|
|
|
def nan_weighted_avg(vals, weights, axis=None):
|
|
"""
|
|
|
|
:param vals: The values to be averaged on.
|
|
:param weights: The weights of weighted avrage.
|
|
:param axis: On which axis to calculate the weighted avrage. (Default value = None)
|
|
|
|
"""
|
|
assert vals.shape == weights.shape, AssertionError(f"{vals.shape} & {weights.shape}")
|
|
vals = vals.copy()
|
|
weights = weights.copy()
|
|
res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
|
|
return np.nan_to_num(res, nan=vals[0])
|
|
|
|
|
|
def robust_auc(y_true, y_pred):
|
|
"""
|
|
|
|
Calculate AUC.
|
|
|
|
"""
|
|
try:
|
|
return roc_auc_score(y_true, y_pred)
|
|
except:
|
|
return np.nan
|
|
|
|
|
|
def merge_dicts(d1, d2):
|
|
"""
|
|
|
|
:param d1: Dict 1.
|
|
:type d1: dict
|
|
:param d2: Dict 2.
|
|
:returns: A new dict that is d1 and d2 deep merged.
|
|
:rtype: dict
|
|
|
|
"""
|
|
merged = copy.deepcopy(d1)
|
|
deep_update(merged, d2, True, [])
|
|
return merged
|
|
|
|
|
|
def deep_update(
|
|
original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
|
|
):
|
|
"""Updates original dict with values from new_dict recursively.
|
|
If new key is introduced in new_dict, then if new_keys_allowed is not
|
|
True, an error will be thrown. Further, for sub-dicts, if the key is
|
|
in the whitelist, then new subkeys can be introduced.
|
|
|
|
:param original: Dictionary with default values.
|
|
:type original: dict
|
|
:param new_dict(dict: dict): Dictionary with values to be updated
|
|
:param new_keys_allowed: Whether new keys are allowed. (Default value = False)
|
|
:type new_keys_allowed: bool
|
|
:param whitelist: List of keys that correspond to dict
|
|
values where new subkeys can be introduced. This is only at the top
|
|
level. (Default value = None)
|
|
:type whitelist: Optional[List[str]]
|
|
:param override_all_if_type_changes: List of top level
|
|
keys with value=dict, for which we always simply override the
|
|
entire value (dict), iff the "type" key in that value dict changes. (Default value = None)
|
|
:type override_all_if_type_changes: Optional[List[str]]
|
|
:param new_dict:
|
|
|
|
"""
|
|
whitelist = whitelist or []
|
|
override_all_if_type_changes = override_all_if_type_changes or []
|
|
|
|
for k, value in new_dict.items():
|
|
if k not in original and not new_keys_allowed:
|
|
raise Exception("Unknown config parameter `{}` ".format(k))
|
|
|
|
# Both orginal value and new one are dicts.
|
|
if isinstance(original.get(k), dict) and isinstance(value, dict):
|
|
# Check old type vs old one. If different, override entire value.
|
|
if (
|
|
k in override_all_if_type_changes
|
|
and "type" in value
|
|
and "type" in original[k]
|
|
and value["type"] != original[k]["type"]
|
|
):
|
|
original[k] = value
|
|
# Whitelisted key -> ok to add new subkeys.
|
|
elif k in whitelist:
|
|
deep_update(original[k], value, True)
|
|
# Non-whitelisted key.
|
|
else:
|
|
deep_update(original[k], value, new_keys_allowed)
|
|
# Original value not a dict OR new value not a dict:
|
|
# Override entire value.
|
|
else:
|
|
original[k] = value
|
|
return original
|
|
|
|
|
|
def get_seqlen(done_seq):
|
|
"""
|
|
|
|
:param done_seq:
|
|
|
|
"""
|
|
seqlen = []
|
|
length = 0
|
|
for i, done in enumerate(done_seq):
|
|
length += 1
|
|
if done:
|
|
seqlen.append(length)
|
|
length = 0
|
|
if length > 0:
|
|
seqlen.append(length)
|
|
return np.array(seqlen)
|
|
|
|
|
|
def generate_seq(seqlen, list):
|
|
"""
|
|
|
|
:param seqlen: param list:
|
|
:param list:
|
|
|
|
"""
|
|
res = []
|
|
index = 0
|
|
maxlen = np.max(seqlen)
|
|
for i in seqlen:
|
|
if isinstance(list, torch.Tensor):
|
|
res.append(torch.cat((list[index : index + i], torch.zeros_like(list[: maxlen - i])), dim=0,))
|
|
else:
|
|
res.append(np.concatenate((list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0))
|
|
index += i
|
|
if isinstance(list, torch.Tensor):
|
|
res = torch.stack(res, dim=0)
|
|
else:
|
|
res = np.stack(res, axis=0)
|
|
return res
|
|
|
|
|
|
def sequence_batch(batch):
|
|
"""
|
|
|
|
:param batch:
|
|
|
|
"""
|
|
seqlen = get_seqlen(batch.done)
|
|
# print(seqlen.max())
|
|
# print(len(seqlen))
|
|
res = Batch()
|
|
# print(batch.keys())
|
|
|
|
for v in batch.keys():
|
|
if v not in ["policy", "info"]:
|
|
res[v] = generate_seq(seqlen, batch[v])
|
|
else:
|
|
res[v] = batch[v]
|
|
res.seqlen = seqlen
|
|
return res
|
|
|
|
|
|
def flatten_seq(seq, seqlen):
|
|
"""
|
|
|
|
:param seq: param seqlen:
|
|
:param seqlen:
|
|
|
|
"""
|
|
res = []
|
|
for i, length in enumerate(seqlen):
|
|
res.append(seq[i][:length])
|
|
if isinstance(seq, torch.Tensor):
|
|
res = torch.cat(res, dim=0)
|
|
else:
|
|
res = np.concatenate(res, axis=0)
|
|
|
|
return res
|
|
|
|
|
|
def flatten_batch(batch):
|
|
"""
|
|
|
|
:param batch:
|
|
|
|
"""
|
|
for v in batch.keys():
|
|
if v in ["policy", "info", "seqlen"]:
|
|
continue
|
|
batch[v] = flatten_seq(batch[v], batch.seqlen)
|
|
return batch
|
|
|
|
|
|
def to_numpy(
|
|
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
|
|
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
|
"""
|
|
|
|
:param x: Union[Batch:
|
|
:param dict: param list:
|
|
:param tuple: param np.ndarray:
|
|
:param torch: Tensor]:
|
|
:param x: Union[Batch:
|
|
:param list:
|
|
:param np.ndarray:
|
|
:param torch.Tensor]:
|
|
:param x: Union[Batch:
|
|
|
|
"""
|
|
if isinstance(x, torch.Tensor):
|
|
x = x.detach().cpu().numpy()
|
|
elif isinstance(x, dict):
|
|
for k, v in x.items():
|
|
x[k] = to_numpy(v)
|
|
elif isinstance(x, Batch):
|
|
x.to_numpy()
|
|
elif isinstance(x, (list, tuple)):
|
|
try:
|
|
x = to_numpy(_parse_value(x))
|
|
except TypeError:
|
|
x = [to_numpy(e) for e in x]
|
|
else: # fallback
|
|
x = np.asanyarray(x)
|
|
return x
|
|
|
|
|
|
def to_torch(
|
|
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
|
dtype: Optional[torch.dtype] = None,
|
|
device: Union[str, int, torch.device] = "cpu",
|
|
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
|
"""
|
|
|
|
:param x: Union[Batch:
|
|
:param dict: param list:
|
|
:param tuple: param np.ndarray:
|
|
:param torch: Tensor]:
|
|
:param dtype: Optional[torch.dtype]: (Default value = None)
|
|
:param device: Union[str:
|
|
:param int: param torch.device]: (Default value = 'cpu')
|
|
:param x: Union[Batch:
|
|
:param list:
|
|
:param np.ndarray:
|
|
:param torch.Tensor]:
|
|
:param dtype: Optional[torch.dtype]: (Default value = None)
|
|
:param device: Union[str:
|
|
:param torch.device]: (Default value = 'cpu')
|
|
:param x: Union[Batch:
|
|
:param dtype: Optional[torch.dtype]: (Default value = None)
|
|
:param device: Union[str:
|
|
|
|
"""
|
|
if isinstance(x, torch.Tensor):
|
|
if dtype is not None:
|
|
x = x.type(dtype)
|
|
x = x.to(device)
|
|
elif isinstance(x, dict):
|
|
for k, v in x.items():
|
|
x[k] = to_torch(v, dtype, device)
|
|
elif isinstance(x, Batch):
|
|
x.to_torch(dtype, device)
|
|
elif isinstance(x, (np.number, np.bool_, Number)):
|
|
x = to_torch(np.asanyarray(x), dtype, device)
|
|
elif isinstance(x, (list, tuple)):
|
|
try:
|
|
x = to_torch(_parse_value(x), dtype, device)
|
|
except TypeError:
|
|
x = [to_torch(e, dtype, device) for e in x]
|
|
else: # fallback
|
|
x = np.asanyarray(x)
|
|
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
|
x = torch.from_numpy(x).to(device)
|
|
if dtype is not None:
|
|
x = x.type(dtype)
|
|
else:
|
|
raise TypeError(f"object {x} cannot be converted to torch.")
|
|
return x
|
|
|
|
|
|
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor) -> Union[dict, Batch, torch.Tensor]:
|
|
"""
|
|
|
|
:param x: Union[torch.Tensor:
|
|
:param dict: param Batch:
|
|
:param np: ndarray]:
|
|
:param y: torch.Tensor:
|
|
:param x: Union[torch.Tensor:
|
|
:param Batch:
|
|
:param np.ndarray]:
|
|
:param y: torch.Tensor:
|
|
:param x: Union[torch.Tensor:
|
|
:param y: torch.Tensor:
|
|
:returns: to_torch(x, dtype=y.dtype, device=y.device)``.
|
|
|
|
"""
|
|
assert isinstance(y, torch.Tensor)
|
|
return to_torch(x, dtype=y.dtype, device=y.device)
|