diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 28e2c80ba..d3dbfa1f7 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -56,11 +56,10 @@ def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) -> def is_invalid(arr: int | float | bool | np.ndarray | dict | list | tuple) -> bool: - if hasattr(arr, "dtype"): - dtype = getattr(arr, "dtype") - if np.issubdtype(dtype, np.floating): + if isinstance(arr, np.ndarray): + if np.issubdtype(arr.dtype, np.floating): return np.isnan(arr).all() - return (np.iinfo(dtype).max == arr).all() + return (np.iinfo(arr.dtype).max == arr).all() if isinstance(arr, dict): return all(is_invalid(o) for o in arr.values()) if isinstance(arr, (list, tuple)):