1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00
This commit is contained in:
Hong Zhang
2020-11-26 13:49:13 +08:00
parent f185f48185
commit 398f67f8d8
2 changed files with 6 additions and 15 deletions

View File

@@ -7,19 +7,16 @@ from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.pytorch_gats import GAT
from qlib.contrib.data.handler import ALPHA360_Denoise
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data
# from qlib.model.learner import train_model
from qlib.utils import init_instance_by_config
import pickle
if __name__ == "__main__":

View File

@@ -28,14 +28,12 @@ class GAT(Model):
Parameters
----------
input_dim : int
input dimension
output_dim : int
output dimension
layers : tuple
layer sizes
lr : float
learning rate
d_feat : int
input dimensions for each time step
metric : str
the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
@@ -398,10 +396,6 @@ class GATModel(nn.Module):
hidden = self.bn1(hidden)
gamma = self.cal_convariance(hidden, hidden)
# gamma = hidden.mm(torch.t(hidden))
# gamma = self.leaky_relu(gamma)
# gamma = self.softmax(gamma)
# gamma = gamma * (torch.ones(x.shape[0], x.shape[0]).to(device) - torch.diag(torch.ones(x.shape[0])).to(device))
output = gamma.mm(hidden)
output = self.fc(output)
output = self.bn2(output)