1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Successfully run training

This commit is contained in:
Young
2024-07-10 06:25:30 +00:00
parent a9fc3435ab
commit 4c057f645e
2 changed files with 52 additions and 8 deletions

View File

@@ -488,12 +488,40 @@ class GeneralPTNN(Model):
raise ValueError("unknown metric `%s`" % self.metric)
def _get_fl(self, data: torch.Tensor):
"""
get feature and label from data
- Handle the different data shape of time series and tabular data
Parameters
----------
data : torch.Tensor
input data which maybe 3 dimension or 2 dimension
- 3dim: [batch_size, time_step, feature_dim]
- 2dim: [batch_size, feature_dim]
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
"""
if data.dim() == 3:
# it is a time series dataset
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
elif data.dim() == 2:
# it is a tabular dataset
feature = data[:, 0:-1].to(self.device)
label = data[:, -1].to(self.device)
else:
raise ValueError("Unsupported data shape.")
return feature, label
def train_epoch(self, data_loader):
self.dnn_model.train()
for data, weight in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
feature , label = self._get_fl(data)
pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
@@ -526,19 +554,18 @@ class GeneralPTNN(Model):
def fit(
self,
dataset,
dataset: Union[DatasetH, TSDatasetH],
evals_result=dict(),
save_path=None,
reweighter=None,
):
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
if reweighter is None:
wl_train = np.ones(len(dl_train))
wl_valid = np.ones(len(dl_valid))
@@ -548,6 +575,15 @@ class GeneralPTNN(Model):
else:
raise ValueError("Unsupported reweighter type.")
# Preprocess the data so that it matches the Dataset interface expected by DataLoader
if ists:
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
else:
# If it is a tabular dataset, convert the DataFrame to a NumPy array so that DataLoader can index it
dl_train = dl_train.values
dl_valid = dl_valid.values
train_loader = DataLoader(
ConcatDataset(dl_train, wl_train),
batch_size=self.batch_size,
@@ -562,6 +598,7 @@ class GeneralPTNN(Model):
num_workers=self.n_jobs,
drop_last=True,
)
del dl_train, dl_valid, wl_train, wl_valid
save_path = get_or_create_path(save_path)
@@ -605,7 +642,7 @@ class GeneralPTNN(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: Union[DatasetH, TSDatasetH]):
if not self.fitted:
raise ValueError("model is not fitted yet!")

View File

@@ -67,9 +67,16 @@ class TestNN(TestAutoData):
"dropout":0.,
},
),
GeneralPTNN(
n_epochs=2,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
pt_model_kwargs={
"input_dim":3,
},
),
]
for ds, model in zip((tsds, tbds), model_l):
for ds, model in reversed(list(zip((tsds, tbds), model_l))):
model.fit(ds) # It works
model.predict(ds) # It works
break