From a2c38c979efe7d19ae05337dea718bd1ad88127d Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Wed, 21 Jul 2021 13:28:43 +0800 Subject: [PATCH] format by black --- qlib/contrib/data/dataset.py | 2 +- qlib/contrib/model/pytorch_tra.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index 8989a6156..37d979c5b 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -122,7 +122,7 @@ class MTSDatasetH(DatasetH): shuffle=True, drop_last=False, input_size=None, - **kwargs + **kwargs, ): assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage" diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index b8db66916..6f12e3a3a 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F + try: from torch.utils.tensorboard import SummaryWriter except: @@ -84,8 +85,10 @@ class TRAModel(Model): assert memory_mode in ["sample", "daily"], "invalid memory mode" assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}" - assert transport_method == "none" or tra_config['num_states'] > 1, "optimal transport requires `num_states` > 1" - assert memory_mode != "daily" or tra_config['src_info'] == 'TPE', "daily transport can only support TPE as `src_info`" + assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1" + assert ( + memory_mode != "daily" or tra_config["src_info"] == "TPE" + ), "daily transport can only support TPE as `src_info`" if transport_method == "router" and not eval_train: self.logger.warning("`eval_train` will be ignored when using TRA.router") @@ -246,7 +249,9 @@ class TRAModel(Model): all_preds, choice, prob = self.tra(hidden, state) if not is_pretrain and self.transport_method != "none": - loss, pred, L, P = self.transport_fn(all_preds, label, choice, prob, count, self.transport_method, training=False) + loss, pred, L, P = self.transport_fn( + all_preds, label, choice, prob, count, self.transport_method, training=False + ) data_set.assign_data(index, L) # save loss to memory else: pred = all_preds.mean(dim=1) @@ -614,7 +619,7 @@ class TRA(nn.Module): def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"): super().__init__() - assert src_info in ['LR', 'TPE', 'LR_TPE'], 'invalid `src_info`' + assert src_info in ["LR", "TPE", "LR_TPE"], "invalid `src_info`" self.num_states = num_states self.tau = tau