mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
format by black
This commit is contained in:
@@ -122,7 +122,7 @@ class MTSDatasetH(DatasetH):
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
input_size=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except:
|
||||
@@ -84,8 +85,10 @@ class TRAModel(Model):
|
||||
|
||||
assert memory_mode in ["sample", "daily"], "invalid memory mode"
|
||||
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
|
||||
assert transport_method == "none" or tra_config['num_states'] > 1, "optimal transport requires `num_states` > 1"
|
||||
assert memory_mode != "daily" or tra_config['src_info'] == 'TPE', "daily transport can only support TPE as `src_info`"
|
||||
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
|
||||
assert (
|
||||
memory_mode != "daily" or tra_config["src_info"] == "TPE"
|
||||
), "daily transport can only support TPE as `src_info`"
|
||||
|
||||
if transport_method == "router" and not eval_train:
|
||||
self.logger.warning("`eval_train` will be ignored when using TRA.router")
|
||||
@@ -246,7 +249,9 @@ class TRAModel(Model):
|
||||
all_preds, choice, prob = self.tra(hidden, state)
|
||||
|
||||
if not is_pretrain and self.transport_method != "none":
|
||||
loss, pred, L, P = self.transport_fn(all_preds, label, choice, prob, count, self.transport_method, training=False)
|
||||
loss, pred, L, P = self.transport_fn(
|
||||
all_preds, label, choice, prob, count, self.transport_method, training=False
|
||||
)
|
||||
data_set.assign_data(index, L) # save loss to memory
|
||||
else:
|
||||
pred = all_preds.mean(dim=1)
|
||||
@@ -614,7 +619,7 @@ class TRA(nn.Module):
|
||||
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
|
||||
super().__init__()
|
||||
|
||||
assert src_info in ['LR', 'TPE', 'LR_TPE'], 'invalid `src_info`'
|
||||
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
|
||||
|
||||
self.num_states = num_states
|
||||
self.tau = tau
|
||||
|
||||
Reference in New Issue
Block a user