1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

Add sample for Gats.

This commit is contained in:
lwwang1995
2020-12-07 15:07:47 +08:00
committed by you-n-g
parent 65a9a72a88
commit 71ad651514

View File

@@ -22,6 +22,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from ...model.base import Model
from ...data.dataset import DatasetH
@@ -30,6 +31,21 @@ from ...contrib.model.pytorch_lstm import LSTMModel
from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_souce):
self.data_source = data_source
self.data = self.data_source.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
def __iter__(self):
for idx, count in zip(self.daily_index, self.daily_count):
yield slice(idx, idx + count)
def __len__(self):
return len(self.data_source)
class GATs(Model):
"""GATs Model
@@ -235,8 +251,11 @@ class GATs(Model):
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
sampler_train = DailyBatchSampler(dl_train)
sampler_valid = DailyBatchSampler(dl_valid)
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
if save_path == None:
save_path = create_save_path(save_path)
@@ -307,7 +326,8 @@ class GATs(Model):
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
sampler_test = DailyBatchSampler(dl_test)
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)
self.ALSTM_model.eval()
preds = []