1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 10:31:00 +08:00
Files
qlib/qlib/contrib/meta/data_selection/utils.py
you-n-g 32c3070b73 Refine DDG-DA (#1472)
* Run ddg-da successfully

* Support include valid; More parameters

* Support L2 reg & visualization

* Blackformat

* Enable fill_method

* Support specify handler & optim dataset

* Fix Pylint
2023-04-07 15:00:21 +08:00

114 lines
3.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch
from torch import nn
from qlib.constant import EPS
from qlib.log import get_module_logger
class ICLoss(nn.Module):
    """Negative mean daily Information Coefficient (IC), usable as a training loss.

    For each date, the Pearson correlation between ``pred`` and ``y`` across that
    date's rows is computed. The loss is the negated average over the dates that
    were not skipped.
    """

    def forward(self, pred, y, idx, skip_size=50):
        """forward.

        Compute ``-mean(daily Pearson IC)`` between ``pred`` and ``y``.

        Each daily IC is the exact Pearson correlation, matching ``pandas.corr()``.
        ``torch.std`` is the unbiased (N-1) estimator, so the centred vectors are
        normalised by ``sqrt(N-1)``. Normalising by ``sqrt(N)`` would shrink every
        daily IC by a factor of (N-1)/N.

        :param pred: prediction tensor aligned row-by-row with ``idx``; assumed 1-D
            (``torch.dot`` requires 1-D inputs).
        :param y: label tensor aligned row-by-row with ``idx``; assumed 1-D.
        :param idx: iterable of ``(date, inst)`` pairs. Assumed sorted by date, so
            rows of the same date are contiguous.
        :param skip_size: days with fewer than ``skip_size`` rows are skipped. Days
            with fewer than 2 rows are always skipped, since a correlation is
            undefined there.
        :return: the negated mean daily IC.
        :raises ValueError: if no day remains after skipping (including empty ``idx``).
        """
        # Row offsets at which a new date begins; the trailing ``None`` lets the
        # last slice run to the end of the tensor.
        diff_point = []
        prev = None
        for i, (date, _inst) in enumerate(idx):
            # Row 0 always opens a day, even if its date happens to equal the
            # initial ``prev`` seed.
            if i == 0 or date != prev:
                diff_point.append(i)
                prev = date
        diff_point.append(None)
        # diff_point holds one more entry than there are distinct dates.
        n_days = len(diff_point) - 1

        # A Pearson correlation needs at least two samples.
        min_size = max(skip_size, 2)
        ic_all = 0.0
        skip_n = 0
        for start_i, end_i in zip(diff_point, diff_point[1:]):
            pred_focus = pred[start_i:end_i]
            if pred_focus.shape[0] < min_size:
                # skip some days which have very small amount of stock.
                skip_n += 1
                continue
            y_focus = y[start_i:end_i]
            pred_std = pred_focus.std()
            y_std = y_focus.std()
            if pred_std < EPS or y_std < EPS:
                # These cases often happen at the end of test data.
                # Usually caused by fillna(0.)
                skip_n += 1
                continue
            # Use sqrt(N-1) to pair with the unbiased std, giving exact Pearson r.
            norm = np.sqrt(pred_focus.shape[0] - 1)
            ic_day = torch.dot(
                (pred_focus - pred_focus.mean()) / norm / pred_std,
                (y_focus - y_focus.mean()) / norm / y_std,
            )
            ic_all += ic_day

        n_valid = n_days - skip_n
        if n_valid <= 0:
            raise ValueError("No enough data for calculating IC")
        if skip_n > 0:
            get_module_logger("ICLoss").info(
                f"{skip_n} days are skipped due to zero std or small scale of valid samples."
            )
        ic_mean = ic_all / n_valid
        return -ic_mean  # ic loss
def preds_to_weight_with_clamp(preds, clip_weight=None, clip_method="tanh"):
    """
    Convert raw predictions into positive sample weights, optionally bounded.

    Parameters
    ----------
    preds: torch.Tensor
        Raw predictions, interpreted on a log scale (the unclipped result is
        ``exp(preds)``).
    clip_weight: float
        The clip threshold. ``None`` disables clipping and returns ``exp(preds)``
        for any ``clip_method``.
    clip_method: str
        The clip method. Current available: "clamp", "tanh", and "sigmoid".

        - "clamp": ``exp(preds)`` clamped into ``[1 / clip_weight, clip_weight]``.
        - "tanh": ``exp(tanh(preds) * log(clip_weight))``.
        - "sigmoid": ``sigmoid(preds) * clip_weight`` (or all ones when
          ``clip_weight == 0``), then rescaled so that the weights sum to their
          element count, i.e. average 1. The rescaling cancels any non-zero
          ``clip_weight`` factor.

    Returns
    -------
    torch.Tensor
        The weights.

    Raises
    ------
    ValueError
        If ``clip_weight`` is not ``None`` and ``clip_method`` is unknown.
    """
    if clip_weight is None:
        return torch.exp(preds)

    if clip_method == "clamp":
        return torch.exp(preds).clamp(1.0 / clip_weight, clip_weight)

    if clip_method == "tanh":
        return torch.exp(torch.tanh(preds) * np.log(clip_weight))

    if clip_method == "sigmoid":
        if clip_weight == 0.0:
            raw = torch.ones_like(preds)
        else:
            # TODO: the clip_weight factor is cancelled by the rescaling below.
            raw = torch.sigmoid(preds) * clip_weight
        # Rescale so the mean weight is 1.
        return raw / torch.sum(raw) * raw.numel()

    raise ValueError("Unknown clip_method")
class SingleMetaBase(nn.Module):
    """Base module that stores the sample-weight clipping configuration.

    Parameters
    ----------
    hist_n: int
        Accepted for interface compatibility; this base class does not use it.
    clip_weight: float or None
        Clip threshold. For "tanh" and "clamp", a value below 1 is replaced by
        its reciprocal, so a positive threshold is always stored as >= 1.
    clip_method: str
        Clipping method: "tanh", "clamp" or "sigmoid".
    """

    def __init__(self, hist_n, clip_weight=None, clip_method="clamp"):
        super().__init__()
        symmetric_bound = clip_method in ("tanh", "clamp")
        if symmetric_bound and clip_weight is not None and clip_weight < 1.0:
            # The interval [1/c, c] is the same for c and 1/c; keep the form with c >= 1.
            clip_weight = 1 / clip_weight
        self.clip_weight = clip_weight
        self.clip_method = clip_method

    def is_enabled(self):
        """Return whether sample weighting is active.

        A ``None`` threshold means unclipped weighting, which counts as enabled.
        Otherwise "sigmoid" requires ``clip_weight > 0`` and every other method
        requires ``clip_weight > 1``.
        """
        if self.clip_weight is None:
            return True
        threshold = 0.0 if self.clip_method == "sigmoid" else 1.0
        return bool(self.clip_weight > threshold)