mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix(security): restrict pickle deserialization to safe classes (#2076)
This commit is contained in:
@@ -2,17 +2,18 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast, List
|
||||
from typing import List, cast
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -162,7 +163,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = pickle.load(fstream)
|
||||
dataset = restricted_pickle_load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
|
||||
171
qlib/utils/pickle_utils.py
Normal file
171
qlib/utils/pickle_utils.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Secure pickle utilities to prevent arbitrary code execution through deserialization.
|
||||
|
||||
This module provides a secure alternative to pickle.load() and pickle.loads()
|
||||
that restricts deserialization to a whitelist of safe classes.
|
||||
"""
|
||||
|
||||
import io
|
||||
import pickle
|
||||
from typing import Any, BinaryIO, Set, Tuple
|
||||
|
||||
# Whitelist of safe classes that are allowed to be unpickled
|
||||
# These are common data types used in qlib that should be safe to deserialize
|
||||
SAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {
|
||||
# python builtins
|
||||
("builtins", "slice"),
|
||||
("builtins", "range"),
|
||||
("builtins", "dict"),
|
||||
("builtins", "list"),
|
||||
("builtins", "tuple"),
|
||||
("builtins", "set"),
|
||||
("builtins", "frozenset"),
|
||||
("builtins", "bytearray"),
|
||||
("builtins", "bytes"),
|
||||
("builtins", "str"),
|
||||
("builtins", "int"),
|
||||
("builtins", "float"),
|
||||
("builtins", "bool"),
|
||||
("builtins", "complex"),
|
||||
("builtins", "type"),
|
||||
("builtins", "property"),
|
||||
# common utility classes
|
||||
("datetime", "datetime"),
|
||||
("datetime", "date"),
|
||||
("datetime", "time"),
|
||||
("datetime", "timedelta"),
|
||||
("datetime", "timezone"),
|
||||
("decimal", "Decimal"),
|
||||
("collections", "OrderedDict"),
|
||||
("collections", "defaultdict"),
|
||||
("collections", "Counter"),
|
||||
("collections", "namedtuple"),
|
||||
("enum", "Enum"),
|
||||
("pathlib", "Path"),
|
||||
("pathlib", "PosixPath"),
|
||||
("pathlib", "WindowsPath"),
|
||||
("qlib.data.dataset.handler", "DataHandler"),
|
||||
("qlib.data.dataset.handler", "DataHandlerLP"),
|
||||
("qlib.data.dataset.loader", "StaticDataLoader"),
|
||||
}
|
||||
|
||||
|
||||
TRUSTED_MODULE_PREFIXES = (
|
||||
"pandas",
|
||||
"numpy",
|
||||
)
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
"""Custom unpickler that only allows safe classes to be deserialized.
|
||||
|
||||
This prevents arbitrary code execution through malicious pickle files by
|
||||
restricting deserialization to a whitelist of safe classes.
|
||||
|
||||
Example:
|
||||
>>> with open("data.pkl", "rb") as f:
|
||||
... data = RestrictedUnpickler(f).load()
|
||||
"""
|
||||
|
||||
def find_class(self, module: str, name: str):
|
||||
"""Override find_class to restrict allowed classes.
|
||||
|
||||
Args:
|
||||
module: Module name of the class
|
||||
name: Class name
|
||||
|
||||
Returns:
|
||||
The class object if it's in the whitelist
|
||||
|
||||
Raises:
|
||||
pickle.UnpicklingError: If the class is not in the whitelist
|
||||
"""
|
||||
if module.startswith(TRUSTED_MODULE_PREFIXES):
|
||||
return super().find_class(module, name)
|
||||
|
||||
# 2. explicit whitelist (qlib internal)
|
||||
if (module, name) in SAFE_PICKLE_CLASSES:
|
||||
return super().find_class(module, name)
|
||||
|
||||
raise pickle.UnpicklingError(
|
||||
f"Forbidden class: {module}.{name}. "
|
||||
f"Only whitelisted classes are allowed for security reasons. "
|
||||
f"This is to prevent arbitrary code execution through pickle deserialization."
|
||||
)
|
||||
|
||||
|
||||
def restricted_pickle_load(file: BinaryIO) -> Any:
|
||||
"""Safely load a pickle file with restricted classes.
|
||||
|
||||
This is a drop-in replacement for pickle.load() that prevents
|
||||
arbitrary code execution by only allowing whitelisted classes.
|
||||
|
||||
Args:
|
||||
file: An opened file object in binary mode
|
||||
|
||||
Returns:
|
||||
The unpickled Python object
|
||||
|
||||
Raises:
|
||||
pickle.UnpicklingError: If the pickle contains forbidden classes
|
||||
|
||||
Example:
|
||||
>>> with open("data.pkl", "rb") as f:
|
||||
... data = restricted_pickle_load(f)
|
||||
"""
|
||||
return RestrictedUnpickler(file).load()
|
||||
|
||||
|
||||
def restricted_pickle_loads(data: bytes) -> Any:
|
||||
"""Safely load a pickle from bytes with restricted classes.
|
||||
|
||||
This is a drop-in replacement for pickle.loads() that prevents
|
||||
arbitrary code execution by only allowing whitelisted classes.
|
||||
|
||||
Args:
|
||||
data: Bytes object containing pickled data
|
||||
|
||||
Returns:
|
||||
The unpickled Python object
|
||||
|
||||
Raises:
|
||||
pickle.UnpicklingError: If the pickle contains forbidden classes
|
||||
|
||||
Example:
|
||||
>>> data = b'\\x80\\x04\\x95...'
|
||||
>>> obj = restricted_pickle_loads(data)
|
||||
"""
|
||||
file_like = io.BytesIO(data)
|
||||
return RestrictedUnpickler(file_like).load()
|
||||
|
||||
|
||||
def add_safe_class(module: str, name: str) -> None:
|
||||
"""Add a class to the whitelist of safe classes for unpickling.
|
||||
|
||||
Use this function to extend the whitelist if your code needs to deserialize
|
||||
additional classes. However, be very careful when adding classes, as this
|
||||
could potentially introduce security vulnerabilities.
|
||||
|
||||
Args:
|
||||
module: Module name of the class (e.g., 'my_package.my_module')
|
||||
name: Class name (e.g., 'MyClass')
|
||||
|
||||
Warning:
|
||||
Only add classes that you fully control and trust. Adding arbitrary
|
||||
classes from external packages could introduce security risks.
|
||||
|
||||
Example:
|
||||
>>> add_safe_class('my_package.models', 'CustomModel')
|
||||
"""
|
||||
SAFE_PICKLE_CLASSES.add((module, name))
|
||||
|
||||
|
||||
def get_safe_classes() -> Set[Tuple[str, str]]:
|
||||
"""Get a copy of the current whitelist of safe classes.
|
||||
|
||||
Returns:
|
||||
A set of (module, name) tuples representing allowed classes
|
||||
"""
|
||||
return SAFE_PICKLE_CLASSES.copy()
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import unittest
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
class HandlerTests(TestAutoData):
|
||||
@@ -23,7 +23,7 @@ class HandlerTests(TestAutoData):
|
||||
dh.to_pickle(fname, dump_all=True)
|
||||
|
||||
with open(fname, "rb") as f:
|
||||
dh_d = pickle.load(f)
|
||||
dh_d = restricted_pickle_load(f)
|
||||
|
||||
self.assertTrue(dh_d._data.equals(df))
|
||||
self.assertTrue(dh_d._infer is dh_d._data)
|
||||
|
||||
Reference in New Issue
Block a user