1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

format by black

This commit is contained in:
Dong Zhou
2021-07-21 13:28:43 +08:00
committed by you-n-g
parent 07655f2d5b
commit a2c38c979e
2 changed files with 10 additions and 5 deletions

View File

@@ -122,7 +122,7 @@ class MTSDatasetH(DatasetH):
shuffle=True,
drop_last=False,
input_size=None,
**kwargs
**kwargs,
):
assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage"

View File

@@ -13,6 +13,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
except:
@@ -84,8 +85,10 @@ class TRAModel(Model):
assert memory_mode in ["sample", "daily"], "invalid memory mode"
assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}"
assert transport_method == "none" or tra_config['num_states'] > 1, "optimal transport requires `num_states` > 1"
assert memory_mode != "daily" or tra_config['src_info'] == 'TPE', "daily transport can only support TPE as `src_info`"
assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1"
assert (
memory_mode != "daily" or tra_config["src_info"] == "TPE"
), "daily transport can only support TPE as `src_info`"
if transport_method == "router" and not eval_train:
self.logger.warning("`eval_train` will be ignored when using TRA.router")
@@ -246,7 +249,9 @@ class TRAModel(Model):
all_preds, choice, prob = self.tra(hidden, state)
if not is_pretrain and self.transport_method != "none":
loss, pred, L, P = self.transport_fn(all_preds, label, choice, prob, count, self.transport_method, training=False)
loss, pred, L, P = self.transport_fn(
all_preds, label, choice, prob, count, self.transport_method, training=False
)
data_set.assign_data(index, L) # save loss to memory
else:
pred = all_preds.mean(dim=1)
@@ -614,7 +619,7 @@ class TRA(nn.Module):
def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"):
super().__init__()
assert src_info in ['LR', 'TPE', 'LR_TPE'], 'invalid `src_info`'
assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`"
self.num_states = num_states
self.tau = tau