1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00
Files
qlib/examples/trade/util.py
Yuchen Fang a03b08bb4c format
2021-01-28 00:41:02 +08:00

304 lines
8.4 KiB
Python

from collections import namedtuple
from torch.nn.utils.rnn import pack_padded_sequence
from tianshou.data import Batch
import numpy as np
import torch
import copy
from typing import Union, Optional
from numbers import Number
def nan_weighted_avg(vals, weights, axis=None):
    """
    Weighted average of ``vals`` that falls back to the first value wherever the
    average is NaN (for example when the weights along ``axis`` sum to zero).

    :param vals: The values to be averaged (numpy array).
    :param weights: The weights of the weighted average; must have the same shape as ``vals``.
    :param axis: On which axis to calculate the weighted average; ``None`` averages
        over all elements. (Default value = None)
    :returns: The weighted average. NaN entries are replaced by the first element of
        ``vals`` along ``axis`` (the very first element when ``axis`` is None), and
        +/-inf entries by the largest/smallest finite float, as ``np.nan_to_num`` does.
    :raises AssertionError: If ``vals`` and ``weights`` differ in shape.
    """
    # Raise explicitly so the check survives `python -O` (a bare assert would not).
    if vals.shape != weights.shape:
        raise AssertionError(f"{vals.shape} & {weights.shape}")
    # The inputs are only read, so no defensive copies are needed. A zero weight sum
    # yields NaN, which is replaced below, so the divide/invalid warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
    # The fallback must have the same shape as `res`. `vals[0]` only did that for
    # 1-D input or axis=0; take the first entry along the reduced axis instead.
    if axis is None:
        fallback = np.ravel(vals)[0]
    else:
        fallback = np.take(vals, 0, axis=axis)
    return np.nan_to_num(res, nan=fallback)
def robust_auc(y_true, y_pred):
    """
    Calculate the binary ROC AUC, returning NaN whenever it is undefined.

    The score is the normalised Mann-Whitney U statistic: the probability that a
    randomly chosen positive sample is scored above a randomly chosen negative one,
    with tied scores counting one half.

    :param y_true: Binary labels; the larger of the two distinct labels is treated as
        the positive class.
    :param y_pred: Score for each sample, same shape as ``y_true``.
    :returns: The AUC as a float, or ``np.nan`` when the inputs are empty, differ in
        shape, contain non-numeric or non-finite scores, or do not contain exactly
        two classes.
    """
    # `roc_auc_score` used to be called here without ever being imported, so the
    # bare `except` turned every call into NaN. Compute the AUC with numpy instead.
    try:
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred, dtype=float)
        if y_true.shape != y_pred.shape or y_true.size == 0:
            return np.nan
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()
        if not np.isfinite(y_pred).all():
            return np.nan
        classes = np.unique(y_true)
    except (TypeError, ValueError):
        return np.nan
    if classes.size != 2:
        return np.nan
    positive = y_true == classes[1]
    n_pos = int(positive.sum())
    n_neg = positive.size - n_pos
    if n_pos == 0 or n_neg == 0:
        return np.nan
    # 1-based average ranks of the scores: tied scores share the mean of their ranks.
    _, inverse, counts = np.unique(y_pred, return_inverse=True, return_counts=True)
    mean_rank = np.cumsum(counts) - (counts - 1) / 2.0
    ranks = mean_rank[inverse]
    u_stat = ranks[positive].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))
def merge_dicts(d1, d2):
    """
    Return a new dict holding ``d1`` recursively updated with ``d2``.

    ``d1`` is deep-copied first, so it is never modified. The copy is then merged
    with ``deep_update`` with new keys allowed at every level. Values taken from
    ``d2`` are inserted as they are (``deep_update`` does not copy them).

    :param d1: Dict 1, the base values.
    :type d1: dict
    :param d2: Dict 2, merged on top of ``d1``.
    :type d2: dict
    :returns: A new dict that is d1 and d2 deep merged.
    :rtype: dict
    """
    # deep_update mutates and returns its first argument.
    return deep_update(copy.deepcopy(d1), d2, True, [])
def deep_update(
    original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
):
    """Recursively merge ``new_dict`` into ``original`` in place and return it.

    For every key of ``new_dict``:

    * if the key is missing from ``original`` and ``new_keys_allowed`` is falsy,
      an ``Exception`` is raised;
    * if both the old and the new value are dicts, they are merged recursively,
      unless the key is listed in ``override_all_if_type_changes`` and both dicts
      carry differing ``"type"`` entries, in which case the new dict replaces the
      old one entirely;
    * otherwise the new value replaces the old one.

    Keys listed in ``whitelist`` are merged recursively with new sub-keys allowed.
    ``whitelist`` and ``override_all_if_type_changes`` only apply at the top level;
    they are not passed down to the recursive calls.

    :param original: Dictionary with default values; modified in place.
    :type original: dict
    :param new_dict: Dictionary with values to be merged in.
    :type new_dict: dict
    :param new_keys_allowed: Whether keys absent from ``original`` may be added. (Default value = False)
    :type new_keys_allowed: bool
    :param whitelist: Top-level keys whose dict values may receive new sub-keys. (Default value = None)
    :type whitelist: Optional[List[str]]
    :param override_all_if_type_changes: Top-level keys whose dict value is replaced
        wholesale when its ``"type"`` entry changes. (Default value = None)
    :type override_all_if_type_changes: Optional[List[str]]
    :returns: ``original`` after the update.
    :raises Exception: If a key is unknown and new keys are not allowed.
    """
    subkey_whitelist = whitelist or []
    type_override_keys = override_all_if_type_changes or []
    for key, new_value in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(key))
        old_value = original.get(key)
        if not (isinstance(old_value, dict) and isinstance(new_value, dict)):
            # At least one side is not a dict: overwrite the entire value.
            original[key] = new_value
            continue
        type_changed = (
            key in type_override_keys
            and "type" in new_value
            and "type" in old_value
            and new_value["type"] != old_value["type"]
        )
        if type_changed:
            original[key] = new_value
        else:
            allow_new_subkeys = True if key in subkey_whitelist else new_keys_allowed
            deep_update(old_value, new_value, allow_new_subkeys)
    return original
def get_seqlen(done_seq):
    """
    Split a sequence of done flags into episode lengths.

    Each truthy flag closes the current run (the flagged step included). Trailing
    steps after the last truthy flag form one final, shorter run.

    :param done_seq: Iterable of done flags.
    :returns: ``np.ndarray`` of run lengths, in order.
    """
    lengths = []
    run = 0
    for flag in done_seq:
        run += 1
        if not flag:
            continue
        lengths.append(run)
        run = 0
    if run:
        # Unfinished trailing episode.
        lengths.append(run)
    return np.array(lengths)
def generate_seq(seqlen, list):
    """
    Cut a flat array/tensor into consecutive chunks and zero-pad them into one stack.

    Chunk ``j`` holds the next ``seqlen[j]`` rows of ``list`` and is padded with
    zeros (same dtype, built with ``zeros_like``) up to ``max(seqlen)`` rows. The
    padding is sliced from the front of ``list``, so ``list`` needs at least
    ``max(seqlen) - seqlen[j]`` rows.

    :param seqlen: Sequence of chunk lengths.
    :param list: ``torch.Tensor`` or numpy-compatible array to split. The name
        shadows the builtin but is kept for backward compatibility.
    :returns: Stacked result of shape ``(len(seqlen), max(seqlen), *list.shape[1:])``,
        a tensor if ``list`` is a tensor, otherwise a numpy array.
    """
    is_tensor = isinstance(list, torch.Tensor)
    width = np.max(seqlen)
    rows = []
    start = 0
    for n in seqlen:
        chunk = list[start : start + n]
        pad_source = list[: width - n]
        if is_tensor:
            rows.append(torch.cat((chunk, torch.zeros_like(pad_source)), dim=0))
        else:
            rows.append(np.concatenate((chunk, np.zeros_like(pad_source)), axis=0))
        start += n
    return torch.stack(rows, dim=0) if is_tensor else np.stack(rows, axis=0)
def sequence_batch(batch):
    """
    Regroup a flat ``Batch`` into zero-padded per-episode sequences.

    Episode lengths are derived from ``batch.done`` with ``get_seqlen``. Every field
    except ``"policy"`` and ``"info"`` is reshaped by ``generate_seq``; those two
    fields are copied over unchanged.

    :param batch: ``Batch`` that has a ``done`` field.
    :returns: A new ``Batch`` with the sequence-shaped fields plus ``seqlen``.
    """
    lengths = get_seqlen(batch.done)
    passthrough = ("policy", "info")
    out = Batch()
    for key in batch.keys():
        out[key] = batch[key] if key in passthrough else generate_seq(lengths, batch[key])
    out.seqlen = lengths
    return out
def flatten_seq(seq, seqlen):
    """
    Strip padding from a stack of sequences and concatenate the valid parts.

    :param seq: Indexable stack of sequences (``torch.Tensor`` or numpy-compatible).
    :param seqlen: Number of valid leading entries of each sequence.
    :returns: The concatenation of ``seq[i][:seqlen[i]]`` over ``i``, as a tensor
        if ``seq`` is a tensor, otherwise a numpy array.
    """
    pieces = [seq[row][:n] for row, n in enumerate(seqlen)]
    if isinstance(seq, torch.Tensor):
        return torch.cat(pieces, dim=0)
    return np.concatenate(pieces, axis=0)
def flatten_batch(batch):
    """
    Strip sequence padding from every field of ``batch`` in place.

    Each field except ``"policy"``, ``"info"`` and ``"seqlen"`` is replaced by
    ``flatten_seq(field, batch.seqlen)``.

    :param batch: ``Batch`` carrying a ``seqlen`` field.
    :returns: The same ``batch`` object, modified in place.
    """
    for key in batch.keys():
        if key not in ("policy", "info", "seqlen"):
            batch[key] = flatten_seq(batch[key], batch.seqlen)
    return batch
def to_numpy(
    x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
    """
    Recursively convert ``x`` so that it no longer contains ``torch.Tensor`` objects.

    :param x: A tensor, dict, ``Batch``, list/tuple, or anything ``np.asanyarray`` accepts.
    :returns: A tensor becomes a detached CPU numpy array. A dict is converted value
        by value in place and returned. A ``Batch`` is converted in place via
        ``Batch.to_numpy()``. A list/tuple is first parsed into a single value and
        converted; if parsing raises ``TypeError``, a list of converted elements is
        returned instead. Anything else goes through ``np.asanyarray``.
    """
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    elif isinstance(x, dict):
        # The caller's dict is mutated; only existing keys are reassigned.
        for k, v in x.items():
            x[k] = to_numpy(v)
    elif isinstance(x, Batch):
        x.to_numpy()
    elif isinstance(x, (list, tuple)):
        # `_parse_value` was referenced here without being imported, so this branch
        # raised NameError (which `except TypeError` does not catch). Importing it
        # lazily keeps a missing private helper from breaking the whole module.
        # NOTE(review): private tianshou API; confirm it exists in the pinned version.
        from tianshou.data.batch import _parse_value

        try:
            x = to_numpy(_parse_value(x))
        except TypeError:
            x = [to_numpy(e) for e in x]
    else:  # fallback
        x = np.asanyarray(x)
    return x
def to_torch(
    x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
    dtype: Optional[torch.dtype] = None,
    device: Union[str, int, torch.device] = "cpu",
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
    """
    Recursively convert ``x`` into torch tensors on ``device``.

    :param x: A tensor, dict, ``Batch``, number, list/tuple, or anything
        ``np.asanyarray`` accepts.
    :param dtype: If given, resulting tensors are cast to this dtype. (Default value = None)
    :param device: Device the resulting tensors are moved to. (Default value = 'cpu')
    :returns: A tensor is cast/moved. A dict is converted value by value in place and
        returned. A ``Batch`` is converted in place via ``Batch.to_torch``. Scalars and
        numeric/bool arrays become tensors. A list/tuple is first parsed into a single
        value and converted; if that raises ``TypeError``, a list of converted elements
        is returned instead.
    :raises TypeError: If the fallback array's dtype is neither numeric nor bool.
    """
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        x = x.to(device)
    elif isinstance(x, dict):
        # The caller's dict is mutated; only existing keys are reassigned.
        for k, v in x.items():
            x[k] = to_torch(v, dtype, device)
    elif isinstance(x, Batch):
        x.to_torch(dtype, device)
    elif isinstance(x, (np.number, np.bool_, Number)):
        x = to_torch(np.asanyarray(x), dtype, device)
    elif isinstance(x, (list, tuple)):
        # `_parse_value` was referenced here without being imported, so this branch
        # raised NameError (which `except TypeError` does not catch). Importing it
        # lazily keeps a missing private helper from breaking the whole module.
        # NOTE(review): private tianshou API; confirm it exists in the pinned version.
        from tianshou.data.batch import _parse_value

        try:
            x = to_torch(_parse_value(x), dtype, device)
        except TypeError:
            x = [to_torch(e, dtype, device) for e in x]
    else:  # fallback
        x = np.asanyarray(x)
        if issubclass(x.dtype.type, (np.bool_, np.number)):
            x = torch.from_numpy(x).to(device)
            if dtype is not None:
                x = x.type(dtype)
        else:
            raise TypeError(f"object {x} cannot be converted to torch.")
    return x
def to_torch_as(
    x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor
) -> Union[dict, Batch, torch.Tensor]:
    """
    Convert ``x`` to torch, matching the dtype and device of the tensor ``y``.

    :param x: Data to convert (tensor, dict, ``Batch`` or numpy array), passed to ``to_torch``.
    :param y: Reference tensor whose ``dtype`` and ``device`` are used.
    :returns: ``to_torch(x, dtype=y.dtype, device=y.device)``.
    :raises AssertionError: If ``y`` is not a ``torch.Tensor`` (only when assertions are enabled).
    """
    assert isinstance(y, torch.Tensor)
    target_dtype = y.dtype
    target_device = y.device
    return to_torch(x, dtype=target_dtype, device=target_device)