1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix(security): restrict pickle deserialization to safe classes (#2076)

This commit is contained in:
Linlang
2025-12-30 11:00:51 +08:00
committed by GitHub
parent 3472e82d5c
commit 477160e4ac
3 changed files with 180 additions and 8 deletions

View File

@@ -2,17 +2,18 @@
# Licensed under the MIT License.
from __future__ import annotations
import os
from pathlib import Path
from typing import cast, List
from typing import List, cast
import cachetools
import pandas as pd
import pickle
import os
from qlib.backtest import Exchange, Order
from qlib.backtest.decision import TradeRange, TradeRangeByTime
from qlib.constant import EPS_T
from qlib.utils.pickle_utils import restricted_pickle_load
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
@@ -162,7 +163,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
with open(path, "rb") as fstream:
dataset = pickle.load(fstream)
dataset = restricted_pickle_load(fstream)
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
if index_only:

171
qlib/utils/pickle_utils.py Normal file
View File

@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Secure pickle utilities to prevent arbitrary code execution through deserialization.
This module provides a secure alternative to pickle.load() and pickle.loads()
that restricts deserialization to a whitelist of safe classes.
"""
import io
import pickle
from typing import Any, BinaryIO, Set, Tuple
# Whitelist of safe classes that are allowed to be unpickled
# These are common data types used in qlib that should be safe to deserialize
SAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {
# python builtins
("builtins", "slice"),
("builtins", "range"),
("builtins", "dict"),
("builtins", "list"),
("builtins", "tuple"),
("builtins", "set"),
("builtins", "frozenset"),
("builtins", "bytearray"),
("builtins", "bytes"),
("builtins", "str"),
("builtins", "int"),
("builtins", "float"),
("builtins", "bool"),
("builtins", "complex"),
("builtins", "type"),
("builtins", "property"),
# common utility classes
("datetime", "datetime"),
("datetime", "date"),
("datetime", "time"),
("datetime", "timedelta"),
("datetime", "timezone"),
("decimal", "Decimal"),
("collections", "OrderedDict"),
("collections", "defaultdict"),
("collections", "Counter"),
("collections", "namedtuple"),
("enum", "Enum"),
("pathlib", "Path"),
("pathlib", "PosixPath"),
("pathlib", "WindowsPath"),
("qlib.data.dataset.handler", "DataHandler"),
("qlib.data.dataset.handler", "DataHandlerLP"),
("qlib.data.dataset.loader", "StaticDataLoader"),
}
TRUSTED_MODULE_PREFIXES = (
"pandas",
"numpy",
)
class RestrictedUnpickler(pickle.Unpickler):
"""Custom unpickler that only allows safe classes to be deserialized.
This prevents arbitrary code execution through malicious pickle files by
restricting deserialization to a whitelist of safe classes.
Example:
>>> with open("data.pkl", "rb") as f:
... data = RestrictedUnpickler(f).load()
"""
def find_class(self, module: str, name: str):
"""Override find_class to restrict allowed classes.
Args:
module: Module name of the class
name: Class name
Returns:
The class object if it's in the whitelist
Raises:
pickle.UnpicklingError: If the class is not in the whitelist
"""
if module.startswith(TRUSTED_MODULE_PREFIXES):
return super().find_class(module, name)
# 2. explicit whitelist (qlib internal)
if (module, name) in SAFE_PICKLE_CLASSES:
return super().find_class(module, name)
raise pickle.UnpicklingError(
f"Forbidden class: {module}.{name}. "
f"Only whitelisted classes are allowed for security reasons. "
f"This is to prevent arbitrary code execution through pickle deserialization."
)
def restricted_pickle_load(file: BinaryIO) -> Any:
"""Safely load a pickle file with restricted classes.
This is a drop-in replacement for pickle.load() that prevents
arbitrary code execution by only allowing whitelisted classes.
Args:
file: An opened file object in binary mode
Returns:
The unpickled Python object
Raises:
pickle.UnpicklingError: If the pickle contains forbidden classes
Example:
>>> with open("data.pkl", "rb") as f:
... data = restricted_pickle_load(f)
"""
return RestrictedUnpickler(file).load()
def restricted_pickle_loads(data: bytes) -> Any:
"""Safely load a pickle from bytes with restricted classes.
This is a drop-in replacement for pickle.loads() that prevents
arbitrary code execution by only allowing whitelisted classes.
Args:
data: Bytes object containing pickled data
Returns:
The unpickled Python object
Raises:
pickle.UnpicklingError: If the pickle contains forbidden classes
Example:
>>> data = b'\\x80\\x04\\x95...'
>>> obj = restricted_pickle_loads(data)
"""
file_like = io.BytesIO(data)
return RestrictedUnpickler(file_like).load()
def add_safe_class(module: str, name: str) -> None:
"""Add a class to the whitelist of safe classes for unpickling.
Use this function to extend the whitelist if your code needs to deserialize
additional classes. However, be very careful when adding classes, as this
could potentially introduce security vulnerabilities.
Args:
module: Module name of the class (e.g., 'my_package.my_module')
name: Class name (e.g., 'MyClass')
Warning:
Only add classes that you fully control and trust. Adding arbitrary
classes from external packages could introduce security risks.
Example:
>>> add_safe_class('my_package.models', 'CustomModel')
"""
SAFE_PICKLE_CLASSES.add((module, name))
def get_safe_classes() -> Set[Tuple[str, str]]:
"""Get a copy of the current whitelist of safe classes.
Returns:
A set of (module, name) tuples representing allowed classes
"""
return SAFE_PICKLE_CLASSES.copy()

View File

@@ -1,10 +1,10 @@
import os
import pickle
import shutil
import unittest
from qlib.tests import TestAutoData
from qlib.data import D
from qlib.data.dataset.handler import DataHandlerLP
from qlib.tests import TestAutoData
from qlib.utils.pickle_utils import restricted_pickle_load
class HandlerTests(TestAutoData):
@@ -23,7 +23,7 @@ class HandlerTests(TestAutoData):
dh.to_pickle(fname, dump_all=True)
with open(fname, "rb") as f:
dh_d = pickle.load(f)
dh_d = restricted_pickle_load(f)
self.assertTrue(dh_d._data.equals(df))
self.assertTrue(dh_d._infer is dh_d._data)