1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Enhance pytorch nn (#917)

* enhance pytorch_nn

* fix dim bug

* Black format

* Fix pylint error
This commit is contained in:
you-n-g
2022-02-15 19:22:48 +08:00
committed by GitHub
parent 0e8b94a552
commit 60d45ad770
13 changed files with 281 additions and 139 deletions

5
.pylintrc Normal file
View File

@@ -0,0 +1,5 @@
[TYPECHECK]
# https://stackoverflow.com/a/53572939
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*

View File

@@ -14,9 +14,19 @@ Continuous Integration (CI) tools help you stick to the quality standards by run
When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page. When you submit a PR request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
A common error is the mixed use of space and tab. You can fix the bug by inputing the following code in the command line. 1. Qlib will check the code format with black. The PR will raise error if your code does not align to the standard of Qlib(e.g. a common error is the mixed use of space and tab).
You can fix the bug by inputing the following code in the command line.
.. code-block:: python .. code-block:: python
pip install black pip install black
python -m black . -l 120 python -m black . -l 120
2. Qlib will check your code style pylint. The checking command is implemented in [github action workflow](https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66).
Sometime pylint's restrictions are not that reasonable. You can ignore specific errors like this
.. code-block:: python
return -ICLoss()(pred, target, index) # pylint: disable=E1130

View File

@@ -63,8 +63,6 @@ task:
module_path: qlib.contrib.model.pytorch_nn module_path: qlib.contrib.model.pytorch_nn
kwargs: kwargs:
loss: mse loss: mse
input_dim: 157
output_dim: 1
lr: 0.002 lr: 0.002
lr_decay: 0.96 lr_decay: 0.96
lr_decay_steps: 100 lr_decay_steps: 100
@@ -73,6 +71,8 @@ task:
batch_size: 8192 batch_size: 8192
GPU: 0 GPU: 0
weight_decay: 0.0002 weight_decay: 0.0002
pt_model_kwargs:
input_dim: 157
dataset: dataset:
class: DatasetH class: DatasetH
module_path: qlib.data.dataset module_path: qlib.data.dataset

View File

@@ -51,8 +51,6 @@ task:
module_path: qlib.contrib.model.pytorch_nn module_path: qlib.contrib.model.pytorch_nn
kwargs: kwargs:
loss: mse loss: mse
input_dim: 360
output_dim: 1
lr: 0.002 lr: 0.002
lr_decay: 0.96 lr_decay: 0.96
lr_decay_steps: 100 lr_decay_steps: 100
@@ -60,6 +58,8 @@ task:
max_steps: 8000 max_steps: 8000
batch_size: 4096 batch_size: 4096
GPU: 0 GPU: 0
pt_model_kwargs:
input_dim: 360
dataset: dataset:
class: DatasetH class: DatasetH
module_path: qlib.data.dataset module_path: qlib.data.dataset

View File

@@ -9,6 +9,9 @@ from torch import nn
class ICLoss(nn.Module): class ICLoss(nn.Module):
def forward(self, pred, y, idx, skip_size=50): def forward(self, pred, y, idx, skip_size=50):
"""forward. """forward.
FIXME:
- Some times it will be a slightly different from the result from `pandas.corr()`
- It may be caused by the precision problem of model;
:param pred: :param pred:
:param y: :param y:

View File

@@ -10,6 +10,7 @@ from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP from ...data.dataset.handler import DataHandlerLP
from ...model.interpret.base import LightGBMFInt from ...model.interpret.base import LightGBMFInt
from ...data.dataset.weight import Reweighter from ...data.dataset.weight import Reweighter
from qlib.workflow import R
class LGBModel(ModelFT, LightGBMFInt): class LGBModel(ModelFT, LightGBMFInt):
@@ -59,10 +60,12 @@ class LGBModel(ModelFT, LightGBMFInt):
num_boost_round=None, num_boost_round=None,
early_stopping_rounds=None, early_stopping_rounds=None,
verbose_eval=20, verbose_eval=20,
evals_result=dict(), evals_result=None,
reweighter=None, reweighter=None,
**kwargs **kwargs,
): ):
if evals_result is None:
evals_result = {} # in case of unsafety of Python default values
ds_l = self._prepare_data(dataset, reweighter) ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l)) ds, names = list(zip(*ds_l))
self.model = lgb.train( self.model = lgb.train(
@@ -76,10 +79,13 @@ class LGBModel(ModelFT, LightGBMFInt):
), ),
verbose_eval=verbose_eval, verbose_eval=verbose_eval,
evals_result=evals_result, evals_result=evals_result,
**kwargs **kwargs,
) )
for k in names: for k in names:
evals_result[k] = list(evals_result[k].values())[0] for key, val in evals_result[k].items():
name = f"{key}.{k}"
for epoch, m in enumerate(val):
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None: if self.model is None:

View File

@@ -263,8 +263,8 @@ class GATs(Model):
model_dict = self.GAT_model.state_dict() model_dict = self.GAT_model.state_dict()
pretrained_dict = { pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
} # pylint: disable=E1135 }
model_dict.update(pretrained_dict) model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict) self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...") self.logger.info("Loading pretrained model Done...")

View File

@@ -278,8 +278,8 @@ class GATs(Model):
model_dict = self.GAT_model.state_dict() model_dict = self.GAT_model.state_dict()
pretrained_dict = { pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
} # pylint: disable=E1135 }
model_dict.update(pretrained_dict) model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict) self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...") self.logger.info("Loading pretrained model Done...")

View File

@@ -4,11 +4,12 @@
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from collections import defaultdict
import os import os
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from typing import Text, Union from typing import Callable, Optional, Text, Union
from sklearn.metrics import roc_auc_score, mean_squared_error from sklearn.metrics import roc_auc_score, mean_squared_error
import torch import torch
@@ -20,9 +21,18 @@ from ...model.base import Model
from ...data.dataset import DatasetH from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.weight import Reweighter from ...data.dataset.weight import Reweighter
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path from ...utils import (
auto_filter_kwargs,
init_instance_by_config,
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
)
from ...log import get_module_logger from ...log import get_module_logger
from ...workflow import R from ...workflow import R
from qlib.contrib.meta.data_selection.utils import ICLoss
from torch.nn import DataParallel
from torch.utils.data import DataLoader, SequentialSampler
class DNNModelPytorch(Model): class DNNModelPytorch(Model):
@@ -49,9 +59,6 @@ class DNNModelPytorch(Model):
def __init__( def __init__(
self, self,
input_dim=360,
output_dim=1,
layers=(256,),
lr=0.001, lr=0.001,
max_steps=300, max_steps=300,
batch_size=2000, batch_size=2000,
@@ -64,14 +71,23 @@ class DNNModelPytorch(Model):
GPU=0, GPU=0,
seed=None, seed=None,
weight_decay=0.0, weight_decay=0.0,
**kwargs data_parall=False,
scheduler: Optional[Union[Callable]] = "default", # when it is Callable, it accept one argument named optimizer
init_model=None,
eval_train_metric=True,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
pt_model_kwargs={
"input_dim": 360,
"layers": (256,),
},
valid_key=DataHandlerLP.DK_L,
# TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
): ):
# Set logger. # Set logger.
self.logger = get_module_logger("DNNModelPytorch") self.logger = get_module_logger("DNNModelPytorch")
self.logger.info("DNN pytorch version...") self.logger.info("DNN pytorch version...")
# set hyper-parameters. # set hyper-parameters.
self.layers = layers
self.lr = lr self.lr = lr
self.max_steps = max_steps self.max_steps = max_steps
self.batch_size = batch_size self.batch_size = batch_size
@@ -81,41 +97,36 @@ class DNNModelPytorch(Model):
self.lr_decay_steps = lr_decay_steps self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower() self.optimizer = optimizer.lower()
self.loss_type = loss self.loss_type = loss
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") if isinstance(GPU, str):
self.device = torch.device(GPU)
else:
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.seed = seed self.seed = seed
self.weight_decay = weight_decay self.weight_decay = weight_decay
self.data_parall = data_parall
self.eval_train_metric = eval_train_metric
self.valid_key = valid_key
self.best_step = None
self.logger.info( self.logger.info(
"DNN parameters setting:" "DNN parameters setting:"
"\nlayers : {}" f"\nlr : {lr}"
"\nlr : {}" f"\nmax_steps : {max_steps}"
"\nmax_steps : {}" f"\nbatch_size : {batch_size}"
"\nbatch_size : {}" f"\nearly_stop_rounds : {early_stop_rounds}"
"\nearly_stop_rounds : {}" f"\neval_steps : {eval_steps}"
"\neval_steps : {}" f"\nlr_decay : {lr_decay}"
"\nlr_decay : {}" f"\nlr_decay_steps : {lr_decay_steps}"
"\nlr_decay_steps : {}" f"\noptimizer : {optimizer}"
"\noptimizer : {}" f"\nloss_type : {loss}"
"\nloss_type : {}" f"\nseed : {seed}"
"\nseed : {}" f"\ndevice : {self.device}"
"\ndevice : {}" f"\nuse_GPU : {self.use_gpu}"
"\nuse_GPU : {}" f"\nweight_decay : {weight_decay}"
"\nweight_decay : {}".format( f"\nenable data parall : {self.data_parall}"
layers, f"\npt_model_uri: {pt_model_uri}"
lr, f"\npt_model_kwargs: {pt_model_kwargs}"
max_steps,
batch_size,
early_stop_rounds,
eval_steps,
lr_decay,
lr_decay_steps,
optimizer,
loss,
seed,
self.device,
self.use_gpu,
weight_decay,
)
) )
if self.seed is not None: if self.seed is not None:
@@ -126,7 +137,14 @@ class DNNModelPytorch(Model):
raise NotImplementedError("loss {} is not supported!".format(loss)) raise NotImplementedError("loss {} is not supported!".format(loss))
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type) if init_model is None:
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
if self.data_parall:
self.dnn_model = DataParallel(self.dnn_model).to(self.device)
else:
self.dnn_model = init_model
self.logger.info("model:\n{:}".format(self.dnn_model)) self.logger.info("model:\n{:}".format(self.dnn_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model))) self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
@@ -137,19 +155,24 @@ class DNNModelPytorch(Model):
else: else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
# Reduce learning rate when loss has stopped decrease if scheduler == "default":
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # Reduce learning rate when loss has stopped decrease
self.train_optimizer, self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
mode="min", self.train_optimizer,
factor=0.5, mode="min",
patience=10, factor=0.5,
verbose=True, patience=10,
threshold=0.0001, verbose=True,
threshold_mode="rel", threshold=0.0001,
cooldown=0, threshold_mode="rel",
min_lr=0.00001, cooldown=0,
eps=1e-08, min_lr=0.00001,
) eps=1e-08,
)
elif scheduler is None:
self.scheduler = None
else:
self.scheduler = scheduler(optimizer=self.train_optimizer)
self.fitted = False self.fitted = False
self.dnn_model.to(self.device) self.dnn_model.to(self.device)
@@ -166,40 +189,44 @@ class DNNModelPytorch(Model):
save_path=None, save_path=None,
reweighter=None, reweighter=None,
): ):
df_train, df_valid = dataset.prepare( has_valid = "valid" in dataset.segments
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L segments = ["train", "valid"]
) vars = ["x", "y", "w"]
x_train, y_train = df_train["feature"], df_train["label"] all_df = defaultdict(dict) # x_train, x_valid y_train, y_valid w_train, w_valid
x_valid, y_valid = df_valid["feature"], df_valid["label"] all_t = defaultdict(dict) # tensors
for seg in segments:
if seg in dataset.segments:
# df_train df_valid
df = dataset.prepare(
seg, col_set=["feature", "label"], data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L
)
all_df["x"][seg] = df["feature"]
all_df["y"][seg] = df["label"]
if reweighter is None:
all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index)
elif isinstance(reweighter, Reweighter):
all_df["w"][seg] = pd.DataFrame(reweighter.reweight(df))
else:
raise ValueError("Unsupported reweighter type.")
if reweighter is None: # get tensors
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) for v in vars:
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index) all_t[v][seg] = torch.from_numpy(all_df[v][seg].values).float()
elif isinstance(reweighter, Reweighter): # if seg == "valid": # accelerate the eval of validation
w_train = pd.DataFrame(reweighter.reweight(df_train)) all_t[v][seg] = all_t[v][seg].to(self.device) # This will consume a lot of memory !!!!
w_valid = pd.DataFrame(reweighter.reweight(df_valid))
else: evals_result[seg] = []
raise ValueError("Unsupported reweighter type.")
save_path = get_or_create_path(save_path) save_path = get_or_create_path(save_path)
stop_steps = 0 stop_steps = 0
train_loss = 0 train_loss = 0
best_loss = np.inf best_loss = np.inf
evals_result["train"] = []
evals_result["valid"] = []
# train # train
self.logger.info("training...") self.logger.info("training...")
self.fitted = True self.fitted = True
# return # return
# prepare training data # prepare training data
x_train_values = torch.from_numpy(x_train.values).float() train_num = all_t["y"]["train"].shape[0]
y_train_values = torch.from_numpy(y_train.values).float()
w_train_values = torch.from_numpy(w_train.values).float()
train_num = y_train_values.shape[0]
# prepare validation data
x_val_auto = torch.from_numpy(x_valid.values).float().to(self.device)
y_val_auto = torch.from_numpy(y_valid.values).float().to(self.device)
w_val_auto = torch.from_numpy(w_valid.values).float().to(self.device)
for step in range(1, self.max_steps + 1): for step in range(1, self.max_steps + 1):
if stop_steps >= self.early_stop_rounds: if stop_steps >= self.early_stop_rounds:
@@ -210,9 +237,9 @@ class DNNModelPytorch(Model):
self.dnn_model.train() self.dnn_model.train()
self.train_optimizer.zero_grad() self.train_optimizer.zero_grad()
choice = np.random.choice(train_num, self.batch_size) choice = np.random.choice(train_num, self.batch_size)
x_batch_auto = x_train_values[choice].to(self.device) x_batch_auto = all_t["x"]["train"][choice].to(self.device)
y_batch_auto = y_train_values[choice].to(self.device) y_batch_auto = all_t["y"]["train"][choice].to(self.device)
w_batch_auto = w_train_values[choice].to(self.device) w_batch_auto = all_t["w"]["train"][choice].to(self.device)
# forward # forward
preds = self.dnn_model(x_batch_auto) preds = self.dnn_model(x_batch_auto)
@@ -226,44 +253,84 @@ class DNNModelPytorch(Model):
train_loss += loss.val train_loss += loss.val
# for evert `eval_steps` steps or at the last steps, we will evaluate the model. # for evert `eval_steps` steps or at the last steps, we will evaluate the model.
if step % self.eval_steps == 0 or step == self.max_steps: if step % self.eval_steps == 0 or step == self.max_steps:
stop_steps += 1 if has_valid:
train_loss /= self.eval_steps stop_steps += 1
train_loss /= self.eval_steps
with torch.no_grad(): with torch.no_grad():
self.dnn_model.eval() self.dnn_model.eval()
loss_val = AverageMeter()
# forward # forward
preds = self.dnn_model(x_val_auto) preds = self._nn_predict(all_t["x"]["valid"], return_cpu=False)
cur_loss_val = self.get_loss(preds, w_val_auto, y_val_auto, self.loss_type) cur_loss_val = self.get_loss(preds, all_t["w"]["valid"], all_t["y"]["valid"], self.loss_type)
loss_val.update(cur_loss_val.item()) loss_val = cur_loss_val.item()
R.log_metrics(val_loss=loss_val.val, step=step) metric_val = (
if verbose: self.get_metric(
self.logger.info( preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["x"]["valid"].index
"[Step {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val) )
) .detach()
evals_result["train"].append(train_loss) .cpu()
evals_result["valid"].append(loss_val.val) .numpy()
if loss_val.val < best_loss: .item()
)
R.log_metrics(val_loss=loss_val, step=step)
R.log_metrics(val_metric=metric_val, step=step)
if self.eval_train_metric:
metric_train = (
self.get_metric(
self._nn_predict(all_t["x"]["train"], return_cpu=False),
all_t["y"]["train"].reshape(-1),
all_df["x"]["train"].index,
)
.detach()
.cpu()
.numpy()
.item()
)
R.log_metrics(train_metric=metric_train, step=step)
else:
metric_train = -1
if verbose: if verbose:
self.logger.info( self.logger.info(
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format( f"[Step {step}]: train_loss {train_loss:.6f}, valid_loss {loss_val:.6f}, train_metric {metric_train:.6f}, valid_metric {metric_val:.6f}"
best_loss, loss_val.val
)
) )
best_loss = loss_val.val evals_result["train"].append(train_loss)
stop_steps = 0 evals_result["valid"].append(loss_val)
torch.save(self.dnn_model.state_dict(), save_path) if loss_val < best_loss:
train_loss = 0 if verbose:
# update learning rate self.logger.info(
self.scheduler.step(cur_loss_val) "\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
best_loss, loss_val
)
)
best_loss = loss_val
self.best_step = step
R.log_metrics(best_step=self.best_step, step=step)
stop_steps = 0
torch.save(self.dnn_model.state_dict(), save_path)
train_loss = 0
# update learning rate
if self.scheduler is not None:
auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=cur_loss_val, epoch=step)
R.log_metrics(lr=self.get_lr(), step=step)
else:
# retraining mode
if self.scheduler is not None:
self.scheduler.step(epoch=step)
# restore the optimal parameters after training if has_valid:
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device)) # restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
if self.use_gpu: if self.use_gpu:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_lr(self):
assert len(self.train_optimizer.param_groups) == 1
return self.train_optimizer.param_groups[0]["lr"]
def get_loss(self, pred, w, target, loss_type): def get_loss(self, pred, w, target, loss_type):
pred, w, target = pred.reshape(-1), w.reshape(-1), target.reshape(-1)
if loss_type == "mse": if loss_type == "mse":
sqr_loss = torch.mul(pred - target, pred - target) sqr_loss = torch.mul(pred - target, pred - target)
loss = torch.mul(sqr_loss, w).mean() loss = torch.mul(sqr_loss, w).mean()
@@ -274,15 +341,54 @@ class DNNModelPytorch(Model):
else: else:
raise NotImplementedError("loss {} is not supported!".format(loss_type)) raise NotImplementedError("loss {} is not supported!".format(loss_type))
def get_metric(self, pred, target, index):
# NOTE: the order of the index must follow <datetime, instrument> sorted order
return -ICLoss()(pred, target, index) # pylint: disable=E1130
def _nn_predict(self, data, return_cpu=True):
"""Reusing predicting NN.
Scenarios
1) test inference (data may come from CPU and expect the output data is on CPU)
2) evaluation on training (data may come from GPU)
"""
if isinstance(data, torch.Tensor) and data.device.type != "cpu":
# GPU data
# CUDA data don't support pin_memory and multi-processing workers
num_workers = 0
pin_memory = False
else:
# CPU data
if not isinstance(data, torch.Tensor):
if isinstance(data, pd.DataFrame):
data = data.values
# else: CPU Tensor
num_workers = 8
pin_memory = True
data_loader = DataLoader(
data,
sampler=SequentialSampler(data),
batch_size=self.batch_size,
drop_last=False,
num_workers=num_workers,
pin_memory=pin_memory,
)
preds = []
self.dnn_model.eval()
with torch.no_grad():
for x in data_loader:
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
if return_cpu:
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
else:
preds = torch.cat(preds, axis=0)
return preds
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted: if not self.fitted:
raise ValueError("model is not fitted yet!") raise ValueError("model is not fitted yet!")
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device) preds = self._nn_predict(x_test_pd)
self.dnn_model.eval() return pd.Series(preds.reshape(-1), index=x_test_pd.index)
with torch.no_grad():
preds = self.dnn_model(x_test).detach().cpu().numpy()
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
def save(self, filename, **kwargs): def save(self, filename, **kwargs):
with save_multiple_parts_file(filename) as model_dir: with save_multiple_parts_file(filename) as model_dir:
@@ -322,16 +428,22 @@ class AverageMeter:
class Net(nn.Module): class Net(nn.Module):
def __init__(self, input_dim, output_dim, layers=(256, 512, 768, 512, 256, 128, 64), loss="mse"): def __init__(self, input_dim, output_dim=1, layers=(256,), act="LeakyReLU"):
super(Net, self).__init__() super(Net, self).__init__()
layers = [input_dim] + list(layers) layers = [input_dim] + list(layers)
dnn_layers = [] dnn_layers = []
drop_input = nn.Dropout(0.05) drop_input = nn.Dropout(0.05)
dnn_layers.append(drop_input) dnn_layers.append(drop_input)
hidden_units = None hidden_units = input_dim
for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])): for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(_input_dim, hidden_units) fc = nn.Linear(_input_dim, hidden_units)
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False) if act == "LeakyReLU":
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
elif act == "SiLU":
activation = nn.SiLU()
else:
raise NotImplementedError(f"This type of input is not supported")
bn = nn.BatchNorm1d(hidden_units) bn = nn.BatchNorm1d(hidden_units)
seq = nn.Sequential(fc, bn, activation) seq = nn.Sequential(fc, bn, activation)
dnn_layers.append(seq) dnn_layers.append(seq)

View File

@@ -13,3 +13,14 @@ class ConcatDataset(Dataset):
def __len__(self): def __len__(self):
return min(len(d) for d in self.datasets) return min(len(d) for d in self.datasets)
class IndexSampler:
def __init__(self, sampler):
self.sampler = sampler
def __getitem__(self, i: int):
return self.sampler[i], i
def __len__(self):
return len(self.sampler)

View File

@@ -823,7 +823,7 @@ def fill_placeholder(config: dict, config_extend: dict):
return config return config
def auto_filter_kwargs(func: Callable) -> Callable: def auto_filter_kwargs(func: Callable, warning=True) -> Callable:
""" """
this will work like a decoration function this will work like a decoration function
@@ -846,7 +846,8 @@ def auto_filter_kwargs(func: Callable) -> Callable:
for k, v in kwargs.items(): for k, v in kwargs.items():
# if `func` don't accept variable keyword arguments like `**kwargs` and have not according named arguments # if `func` don't accept variable keyword arguments like `**kwargs` and have not according named arguments
if spec.varkw is None and k not in spec.args: if spec.varkw is None and k not in spec.args:
log.warning(f"The parameter `{k}` with value `{v}` is ignored.") if warning:
log.warning(f"The parameter `{k}` with value `{v}` is ignored.")
else: else:
new_kwargs[k] = v new_kwargs[k] = v
return func(*args, **new_kwargs) return func(*args, **new_kwargs)

View File

@@ -20,6 +20,9 @@ def experiment_exit_handler():
The `atexit` handler should be put in the last, since, as long as the program ends, it will be called. The `atexit` handler should be put in the last, since, as long as the program ends, it will be called.
Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is
ended, another call of `R.end_exp` will not take effect. ended, another call of `R.end_exp` will not take effect.
Limitations:
- If pdb is used in the your program, excepthook will not be triggered when it ends. The status will be finished
""" """
sys.excepthook = experiment_exception_hook # handle uncaught exception sys.excepthook = experiment_exception_hook # handle uncaught exception
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends

View File

@@ -79,23 +79,14 @@ class TestDataset(TestAutoData):
# 3) get both index and data # 3) get both index and data
# NOTE: We don't want to reply on pytorch, so this test can't be included. It is just a example # NOTE: We don't want to reply on pytorch, so this test can't be included. It is just a example
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from qlib.model.utils import IndexSampler
class IdxSampler:
def __init__(self, sampler):
self.sampler = sampler
def __getitem__(self, i: int):
return self.sampler[i], i
def __len__(self):
return len(self.sampler)
i = len(tsds) - 1 i = len(tsds) - 1
idx = tsds.get_index() idx = tsds.get_index()
tsds[i] tsds[i]
idx[i] idx[i]
s_w_i = IdxSampler(tsds) s_w_i = IndexSampler(tsds)
test_loader = DataLoader(s_w_i) test_loader = DataLoader(s_w_i)
s_w_i[3] s_w_i[3]