mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update settings.
This commit is contained in:
@@ -57,7 +57,7 @@ task:
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-1
|
||||
lr: 5e-2
|
||||
early_stop: 10
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
|
||||
@@ -32,15 +32,17 @@ from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
|
||||
class DailyBatchSampler(Sampler):
|
||||
|
||||
def __init__(self, data_source):
|
||||
|
||||
self.data_source = data_source
|
||||
self.data = self.data_source.data.loc[self.data_source.get_index()]
|
||||
self.daily_count = self.data.groupby(level=0).size().values
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)
|
||||
self.daily_count = self.data.groupby(level=0).size().values[1:]
|
||||
self.daily_index = np.roll(np.cumsum(self.daily_count), 1)[1:]
|
||||
|
||||
def __iter__(self):
|
||||
for idx, count in zip(self.daily_index, self.daily_count):
|
||||
yield slice(idx, idx + count)
|
||||
yield np.arange(idx, idx + count)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
@@ -202,6 +204,8 @@ class GATs(Model):
|
||||
self.GAT_model.train()
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
data = data.squeeze()
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
@@ -222,6 +226,7 @@ class GATs(Model):
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
data = data.squeeze()
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
@@ -335,6 +340,7 @@ class GATs(Model):
|
||||
|
||||
for data in test_loader:
|
||||
|
||||
data = data.squeeze()
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user