mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix: use semantic version comparison for PyTorch scheduler compatibility (#2094)
This commit is contained in:
@@ -10,6 +10,7 @@ import os
|
||||
import gc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from packaging import version
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
@@ -148,7 +149,7 @@ class DNNModelPytorch(Model):
|
||||
if scheduler == "default":
|
||||
# In torch version 2.7.0, the verbose parameter has been removed. Reference Link:
|
||||
# https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313
|
||||
if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0":
|
||||
if version.parse(str(torch.__version__).split("+", maxsplit=1)[0]) <= version.parse("2.6.0"):
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123
|
||||
self.train_optimizer,
|
||||
|
||||
Reference in New Issue
Block a user