mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Successfully run training
This commit is contained in:
@@ -488,12 +488,40 @@ class GeneralPTNN(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
def _get_fl(self, data: torch.Tensor):
|
||||
"""
|
||||
get feature and label from data
|
||||
- Handle the different data shape of time series and tabular data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : torch.Tensor
|
||||
input data which maybe 3 dimension or 2 dimension
|
||||
- 3dim: [batch_size, time_step, feature_dim]
|
||||
- 2dim: [batch_size, feature_dim]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
if data.dim() == 3:
|
||||
# it is a time series dataset
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
elif data.dim() == 2:
|
||||
# it is a tabular dataset
|
||||
feature = data[:, 0:-1].to(self.device)
|
||||
label = data[:, -1].to(self.device)
|
||||
else:
|
||||
raise ValueError("Unsupported data shape.")
|
||||
return feature, label
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
self.dnn_model.train()
|
||||
|
||||
for data, weight in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
feature , label = self._get_fl(data)
|
||||
|
||||
pred = self.dnn_model(feature.float())
|
||||
loss = self.loss_fn(pred, label, weight.to(self.device))
|
||||
@@ -526,19 +554,18 @@ class GeneralPTNN(Model):
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset,
|
||||
dataset: Union[DatasetH, TSDatasetH],
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
reweighter=None,
|
||||
):
|
||||
ists = isinstance(dataset, TSDatasetH) # is this time series dataset
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
if reweighter is None:
|
||||
wl_train = np.ones(len(dl_train))
|
||||
wl_valid = np.ones(len(dl_valid))
|
||||
@@ -548,6 +575,15 @@ class GeneralPTNN(Model):
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
|
||||
# Preprocess for data. To align to Dataset Interface for DataLoader
|
||||
if ists:
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
else:
|
||||
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
|
||||
dl_train = dl_train.values
|
||||
dl_valid = dl_valid.values
|
||||
|
||||
train_loader = DataLoader(
|
||||
ConcatDataset(dl_train, wl_train),
|
||||
batch_size=self.batch_size,
|
||||
@@ -562,6 +598,7 @@ class GeneralPTNN(Model):
|
||||
num_workers=self.n_jobs,
|
||||
drop_last=True,
|
||||
)
|
||||
del dl_train, dl_valid, wl_train, wl_valid
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
@@ -605,7 +642,7 @@ class GeneralPTNN(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: Union[DatasetH, TSDatasetH]):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
|
||||
@@ -67,9 +67,16 @@ class TestNN(TestAutoData):
|
||||
"dropout":0.,
|
||||
},
|
||||
),
|
||||
GeneralPTNN(
|
||||
n_epochs=2,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
|
||||
pt_model_kwargs={
|
||||
"input_dim":3,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for ds, model in zip((tsds, tbds), model_l):
|
||||
for ds, model in reversed(list(zip((tsds, tbds), model_l))):
|
||||
model.fit(ds) # It works
|
||||
model.predict(ds) # It works
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user