mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
Add sample for Gats.
This commit is contained in:
@@ -22,6 +22,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
@@ -30,6 +31,21 @@ from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
def __init__(self, data_souce):
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.loc[self.data_source.get_index()]
|
||||
self.daily_count = self.data.groupby(level=0).size().values
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
|
||||
|
||||
def __iter__(self):
|
||||
for idx, count in zip(self.daily_index, self.daily_count):
|
||||
yield slice(idx, idx + count)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
|
||||
class GATs(Model):
|
||||
"""GATs Model
|
||||
|
||||
@@ -235,8 +251,11 @@ class GATs(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
@@ -307,7 +326,8 @@ class GATs(Model):
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
sampler_test = DailyBatchSampler(dl_test)
|
||||
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
preds = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user