1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix(security): use RestrictedUnpickler in load_instance (#2153)

* fix(security): enforce RestrictedUnpickler for load_instance to prevent unsafe pickle deserialization

* fix: lint error
This commit is contained in:
Linlang
2026-03-10 20:45:38 +08:00
committed by GitHub
parent 2fb9380b34
commit 3097dcc995
59 changed files with 38 additions and 57 deletions

View File

@@ -76,8 +76,11 @@ jobs:
run: |
make mypy
# Due to issues that cannot be automatically fixed when running `nbqa black . -l 120 --check --diff` on Jupyter notebooks,
# we reverted to a version of `black` earlier than 26.1.0 before performing the checks.
- name: Check Qlib ipynb with nbqa
run: |
python -m pip install "black<26.1"
make nbqa
- name: Test data downloads

View File

@@ -23,7 +23,6 @@ import sys
from importlib.metadata import version as ver
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.

View File

@@ -19,7 +19,6 @@ from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# To register new datasets, please add them here.
ALLOW_DATASET = ["Alpha158", "Alpha360"]
# To register new datasets, please add their configurations here.

View File

@@ -8,7 +8,6 @@ import pandas as pd
from qlib.data.dataset import DatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -1,9 +1,10 @@
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from qlib.utils.pickle_utils import restricted_pickle_load
sns.set(color_codes=True)
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False
@@ -18,7 +19,7 @@ from tqdm.auto import tqdm
# +
with open("./internal_data_s20.pkl", "rb") as f:
data = pickle.load(f)
data = restricted_pickle_load(f)
data.data_ic_df.columns.names = ["start_date", "end_date"]
@@ -52,7 +53,7 @@ pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean(
# +
with open("./tasks_s20.pkl", "rb") as f:
tasks = pickle.load(f)
tasks = restricted_pickle_load(f)
task_df = {}
for t in tasks:

View File

@@ -4,11 +4,11 @@
import fire
import qlib
import pickle
from qlib.constant import REG_CN
from qlib.config import HIGH_FREQ_CONFIG
from qlib.utils import init_instance_by_config
from qlib.utils.pickle_utils import restricted_pickle_load
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
@@ -125,10 +125,10 @@ class HighfreqWorkflow:
del dataset, dataset_backtest
##=============reload dataset=============
with open("dataset.pkl", "rb") as file_dataset:
dataset = pickle.load(file_dataset)
dataset = restricted_pickle_load(file_dataset)
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
dataset_backtest = pickle.load(file_dataset_backtest)
dataset_backtest = restricted_pickle_load(file_dataset_backtest)
self._prepare_calender_cache()
##=============reinit dataset=============

View File

@@ -9,7 +9,6 @@ from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir

View File

@@ -95,7 +95,6 @@ pos 0.000000
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
"""
from copy import deepcopy
import qlib
import fire

View File

@@ -7,6 +7,7 @@ There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained models to the `online` models.
Next, we will finish updating online predictions.
"""
import copy
import fire
import qlib

View File

@@ -6,6 +6,7 @@ NOTE:
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
- Its structure is not well designed and very ugly, your contribution is welcome to make importing dataset easier
"""
from datetime import date, datetime as dt
import os
from pathlib import Path

View File

@@ -1,13 +1,15 @@
import pickle
import os
import pandas as pd
from tqdm import tqdm
from qlib.utils.pickle_utils import restricted_pickle_load
for tag in ["test", "valid"]:
files = os.listdir(os.path.join("data/orders/", tag))
dfs = []
for f in tqdm(files):
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
with open(os.path.join("data/orders/", tag, f), "rb") as fr:
df = restricted_pickle_load(fr)
df = df.drop(["$close0"], axis=1)
dfs.append(df)

View File

@@ -3,12 +3,12 @@
import qlib
import fire
import pickle
from datetime import datetime
from qlib.constant import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.utils import init_instance_by_config
from qlib.utils.pickle_utils import restricted_pickle_load
from qlib.tests.data import GetData
@@ -42,7 +42,7 @@ class RollingDataWorkflow:
def _load_pre_handler(self, path):
with open(path, "rb") as file_dataset:
pre_handler = pickle.load(file_dataset)
pre_handler = restricted_pickle_load(file_dataset)
return pre_handler
def rolling_process(self):

View File

@@ -7,6 +7,7 @@ Qlib provides two kinds of interfaces.
The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
"""
import qlib
from qlib.constant import REG_CN
from qlib.utils import init_instance_by_config, flatten_dict
@@ -15,7 +16,6 @@ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir

View File

@@ -69,10 +69,9 @@ rl = [
"torch",
"numpy<2.0.0",
]
# We exclude black version 26.1.0 due to known issues with nbqa when formatting Jupyter notebooks,
# which can cause false-positive --check results and inconsistent notebook formatting.
lint = [
"black!=26.1.0",
"black",
"pylint",
"mypy<1.5.0",
"flake8",

View File

@@ -18,7 +18,6 @@ from tqdm.auto import tqdm
from ..utils.time import Freq
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]

View File

@@ -4,6 +4,5 @@
import fire
from qlib.tests.data import GetData
if __name__ == "__main__":
fire.Fire(GetData)

View File

@@ -10,6 +10,7 @@ Two modes are supported
- server
"""
from __future__ import annotations
import os

View File

@@ -11,7 +11,6 @@ from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -20,7 +20,6 @@ from ..data import D
from ..config import C
from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")

View File

@@ -3,5 +3,4 @@
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]

View File

@@ -4,5 +4,4 @@
from .dataset import MetaDatasetDS, MetaTaskDS
from .model import MetaModelDS
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]

View File

@@ -317,7 +317,7 @@ class TabnetModel(Model):
feature = x_train_values.float().to(self.device)
label = y_train_values.float().to(self.device)
priors = 1 - S_mask
(vec, sparse_loss) = self.tabnet_model(feature, priors)
vec, sparse_loss = self.tabnet_model(feature, priors)
f = self.tabnet_decoder(vec)
loss = self.pretrain_loss_fn(label, f, S_mask)
@@ -348,7 +348,7 @@ class TabnetModel(Model):
S_mask = S_mask.to(self.device)
priors = 1 - S_mask
with torch.no_grad():
(vec, sparse_loss) = self.tabnet_model(feature, priors)
vec, sparse_loss = self.tabnet_model(feature, priors)
f = self.tabnet_decoder(vec)
loss = self.pretrain_loss_fn(label, f, S_mask)

View File

@@ -12,6 +12,7 @@ from ...data import D
from ...config import C
from ...log import get_module_logger
from ...utils import get_next_trading_date
from ...utils.pickle_utils import restricted_pickle_load
from ...backtest.exchange import Exchange
log = get_module_logger("utils")
@@ -30,7 +31,7 @@ def load_instance(file_path):
if not file_path.exists():
raise ValueError("Cannot find file {}".format(file_path))
with file_path.open("rb") as fr:
instance = pickle.load(fr)
instance = restricted_pickle_load(fr)
return instance

View File

@@ -3,5 +3,4 @@
from .analysis_model_performance import model_performance_graph
__all__ = ["model_performance_graph"]

View File

@@ -7,5 +7,4 @@ from .report import report_graph
from .rank_label import rank_label_graph
from .risk_analysis import risk_analysis_graph
__all__ = ["cumulative_return_graph", "score_ic_graph", "report_graph", "rank_label_graph", "risk_analysis_graph"]

View File

@@ -12,6 +12,7 @@ Here is an example.
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
"""
import pandas as pd
import numpy as np
from qlib.contrib.report.data.base import FeaAnalyser

View File

@@ -7,6 +7,7 @@ Assumptions
- The analyse each feature individually
"""
import pandas as pd
from qlib.log import TimeInspector
from qlib.contrib.report.utils import sub_fig_generator

View File

@@ -16,7 +16,6 @@ from .rule_strategy import (
from .cost_control import SoftTopkStrategy
__all__ = [
"TopkDropoutStrategy",
"WeightStrategyBase",

View File

@@ -5,5 +5,4 @@ from .base import BaseOptimizer
from .optimizer import PortfolioOptimizer
from .enhanced_indexing import EnhancedIndexingOptimizer
__all__ = ["BaseOptimizer", "PortfolioOptimizer", "EnhancedIndexingOptimizer"]

View File

@@ -9,7 +9,6 @@ from typing import Union, Optional, Dict, Any, List
from qlib.log import get_module_logger
from .base import BaseOptimizer
logger = get_module_logger("EnhancedIndexingOptimizer")

View File

@@ -4,6 +4,7 @@
"""
This order generator is for strategies based on WeightStrategyBase
"""
from ...backtest.position import Position
from ...backtest.exchange import Exchange

View File

@@ -5,6 +5,7 @@ This module is not a necessary part of Qlib.
They are just some tools for convenience
It is should not imported into the core part of qlib
"""
import torch
import numpy as np
import pandas as pd

View File

@@ -13,7 +13,6 @@ import yaml
from .config import TunerConfigManager
args_parser = argparse.ArgumentParser(prog="tuner")
args_parser.add_argument(
"-c",

View File

@@ -6,7 +6,6 @@
from hyperopt import hp
TopkAmountStrategySpace = {
"topk": hp.choice("topk", [30, 35, 40]),
"buffer_margin": hp.choice("buffer_margin", [200, 250, 300]),

View File

@@ -3,5 +3,4 @@
from .record_temp import MultiSegRecord
from .record_temp import SignalMseRecord
__all__ = ["MultiSegRecord", "SignalMseRecord"]

View File

@@ -36,7 +36,6 @@ from .cache import (
MemoryCalendarCache,
)
__all__ = [
"D",
"CalendarProvider",

View File

@@ -19,7 +19,6 @@ from .loader import DataLoader
from . import processor as processor_module
from . import loader as data_loader_module
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]

View File

@@ -13,6 +13,7 @@ The calculation of both <period_time, feature> and <observe_time, feature> data
2) concatenate all th collasped data, we will get data with format <observe_time, feature>.
Qlib will use the operator `P` to perform the collapse.
"""
import numpy as np
import pandas as pd
from qlib.data.ops import ElemOperator

View File

@@ -3,5 +3,4 @@
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
__all__ = ["CalendarStorage", "InstrumentStorage", "FeatureStorage", "CalVT", "InstVT", "InstKT"]

View File

@@ -5,5 +5,4 @@ import warnings
from .base import Model
__all__ = ["Model", "warnings"]

View File

@@ -4,5 +4,4 @@
from .task import MetaTask
from .dataset import MetaTaskDataset
__all__ = ["MetaTask", "MetaTaskDataset"]

View File

@@ -6,7 +6,6 @@ from .poet import POETCovEstimator
from .shrink import ShrinkCovEstimator
from .structured import StructuredCovEstimator
__all__ = [
"RiskModel",
"POETCovEstimator",

View File

@@ -9,7 +9,6 @@ import tempfile
from importlib import import_module
from ruamel.yaml import YAML
DELETE_KEY = "_delete_"

View File

@@ -3,6 +3,7 @@
"""
This module covers some utility functions that operate on data or basic object
"""
from copy import deepcopy
from typing import List, Union

View File

@@ -3,6 +3,7 @@
"""
Time related utils are compiled in this script
"""
import bisect
from datetime import datetime, time, date, timedelta
from typing import List, Optional, Tuple, Union
@@ -14,7 +15,6 @@ import pandas as pd
from qlib.config import C
from qlib.constant import REG_CN, REG_TW, REG_US
CN_TIME = [
datetime.strptime("9:30", "%H:%M"),
datetime.strptime("11:30", "%H:%M"),

View File

@@ -16,7 +16,6 @@ from .recorder import Recorder
from ..log import get_module_logger
from ..utils.exceptions import ExpAlreadyExistError
logger = get_module_logger("workflow")

View File

@@ -22,7 +22,6 @@ from ..utils.data import deepcopy_basic_type
from ..utils.exceptions import QlibException
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
logger = get_module_logger("workflow", logging.INFO)

View File

@@ -3,6 +3,7 @@
"""
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
"""
import abc
import copy
import pandas as pd

View File

@@ -12,6 +12,7 @@ A task in TaskManager consists of 3 parts
- tasks status: the status of the task
- tasks result: A user can get the task with the task description and task result.
"""
import concurrent
import pickle
import time

View File

@@ -22,7 +22,6 @@ from data_collector.index import IndexBase
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
from data_collector.utils import get_instruments
NEW_COMPANIES_URL = (
"https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
)

View File

@@ -19,7 +19,6 @@ from time import mktime
from datetime import datetime as dt
import time
_CG_CRYPTO_SYMBOLS = None

View File

@@ -16,7 +16,6 @@ from tqdm import tqdm
from loguru import logger
from fake_useragent import UserAgent
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
@@ -24,7 +23,6 @@ from data_collector.index import IndexBase
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
from data_collector.utils import get_instruments
WIKI_URL = "https://en.wikipedia.org/wiki"
WIKI_INDEX_NAME_MAP = {

View File

@@ -21,6 +21,8 @@ from functools import partial
from concurrent.futures import ProcessPoolExecutor
from bs4 import BeautifulSoup
from qlib.utils.pickle_utils import restricted_pickle_load
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
@@ -265,7 +267,7 @@ def get_hs_stock_symbols() -> list:
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
if symbol_cache_path.exists():
with symbol_cache_path.open("rb") as fp:
cache_symbols = pickle.load(fp)
cache_symbols = restricted_pickle_load(fp)
symbols |= cache_symbols
with symbol_cache_path.open("wb") as fp:
pickle.dump(symbols, fp)

View File

@@ -4,6 +4,5 @@
import fire
from qlib.tests.data import GetData
if __name__ == "__main__":
fire.Fire(GetData)

View File

@@ -3,7 +3,6 @@ import os
import numpy
from setuptools import Extension, setup
NUMPY_INCLUDE = numpy.get_include()

View File

@@ -17,7 +17,6 @@ from qlib.rl.utils.finite_env import (
generate_nan_observation,
)
_test_space = gym.spaces.Dict(
{
"sensors": gym.spaces.Dict(

View File

@@ -13,7 +13,6 @@ from qlib.workflow import R
from qlib.tests import TestAutoData
from qlib.tests.config import GBDT_MODEL, get_dataset_config, CSI300_MARKET
CSI300_GBDT_TASK = {
"model": GBDT_MODEL,
"dataset": get_dataset_config(

View File

@@ -16,7 +16,6 @@ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
from dump_bin import DumpDataAll, DumpDataFix
DATA_DIR = Path(__file__).parent.joinpath("test_dump_data")
SOURCE_DIR = DATA_DIR.joinpath("source")
SOURCE_DIR.mkdir(exist_ok=True, parents=True)

View File

@@ -19,7 +19,6 @@ from dump_pit import DumpPitData
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
from collector import Run
pd.set_option("display.width", 1000)
pd.set_option("display.max_columns", None)