mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
gat1
This commit is contained in:
@@ -7,19 +7,16 @@ from pathlib import Path
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gats import GAT
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
@@ -28,14 +28,12 @@ class GAT(Model):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
input dimension
|
||||
output_dim : int
|
||||
output dimension
|
||||
layers : tuple
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -398,10 +396,6 @@ class GATModel(nn.Module):
|
||||
hidden = self.bn1(hidden)
|
||||
|
||||
gamma = self.cal_convariance(hidden, hidden)
|
||||
# gamma = hidden.mm(torch.t(hidden))
|
||||
# gamma = self.leaky_relu(gamma)
|
||||
# gamma = self.softmax(gamma)
|
||||
# gamma = gamma * (torch.ones(x.shape[0], x.shape[0]).to(device) - torch.diag(torch.ones(x.shape[0])).to(device))
|
||||
output = gamma.mm(hidden)
|
||||
output = self.fc(output)
|
||||
output = self.bn2(output)
|
||||
|
||||
Reference in New Issue
Block a user