1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update settings.

This commit is contained in:
lwwang1995
2020-12-07 21:17:02 +08:00
committed by you-n-g
parent 70fb760830
commit 666e1ffcbd
2 changed files with 10 additions and 4 deletions

View File

@@ -57,7 +57,7 @@ task:
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-1
lr: 5e-2
early_stop: 10
batch_size: 800
metric: loss

View File

@@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel
class DailyBatchSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
self.daily_count = self.data.groupby(level=0).size().values[1:]
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:]
def __iter__(self):
for idx, count in zip(self.daily_index, self.daily_count):
yield slice(idx, idx + count)
yield np.arange(idx, idx + count)
def __len__(self):
return len(self.data_source)
@@ -202,6 +204,8 @@ class GATs(Model):
self.GAT_model.train()
for data in data_loader:
data = data.squeeze()
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
@@ -222,6 +226,7 @@ class GATs(Model):
for data in data_loader:
data = data.squeeze()
feature = data[:, :, 0:-1].to(self.device)
# feature[torch.isnan(feature)] = 0
label = data[:, -1, -1].to(self.device)
@@ -335,6 +340,7 @@ class GATs(Model):
for data in test_loader:
data = data.squeeze()
feature = data[:, :, 0:-1].to(self.device)
with torch.no_grad():