mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-13 01:11:00 +08:00
Compare commits
2 Commits
main
...
migrate_gy
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2276552c1f | ||
|
|
9eee6d33b0 |
4
.github/workflows/test_qlib_from_pip.yml
vendored
4
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -1,9 +1,5 @@
|
||||
name: Test qlib from pip
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
7
.github/workflows/test_qlib_from_source.yml
vendored
7
.github/workflows/test_qlib_from_source.yml
vendored
@@ -1,9 +1,5 @@
|
||||
name: Test qlib from source
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
@@ -80,11 +76,8 @@ jobs:
|
||||
run: |
|
||||
make mypy
|
||||
|
||||
# Due to issues that cannot be automatically fixed when running `nbqa black . -l 120 --check --diff` on Jupyter notebooks,
|
||||
# we reverted to a version of `black` earlier than 26.1.0 before performing the checks.
|
||||
- name: Check Qlib ipynb with nbqa
|
||||
run: |
|
||||
python -m pip install "black<26.1"
|
||||
make nbqa
|
||||
|
||||
- name: Test data downloads
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
name: Test qlib from source slow
|
||||
|
||||
concurrency:
|
||||
cancel-in-progress: true
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -22,7 +22,6 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/_version.py
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
|
||||
25
Makefile
25
Makefile
@@ -74,37 +74,34 @@ prerequisite:
|
||||
|
||||
# Install the package in editable mode.
|
||||
dependencies:
|
||||
python -m pip install --no-cache-dir -e .
|
||||
python -m pip install -e .
|
||||
|
||||
lightgbm:
|
||||
python -m pip install --no-cache-dir lightgbm --prefer-binary
|
||||
python -m pip install lightgbm --prefer-binary
|
||||
|
||||
rl:
|
||||
python -m pip install --no-cache-dir -e .[rl]
|
||||
python -m pip install -e .[rl]
|
||||
|
||||
develop:
|
||||
python -m pip install --no-cache-dir -e .[dev]
|
||||
python -m pip install -e .[dev]
|
||||
|
||||
lint:
|
||||
python -m pip install --no-cache-dir -e .[lint]
|
||||
python -m pip install -e .[lint]
|
||||
|
||||
docs:
|
||||
python -m pip install --no-cache-dir -e .[docs]
|
||||
python -m pip install -e .[docs]
|
||||
|
||||
package:
|
||||
python -m pip install --no-cache-dir -e .[package]
|
||||
python -m pip install -e .[package]
|
||||
|
||||
test:
|
||||
python -m pip install --no-cache-dir -e .[test]
|
||||
python -m pip install -e .[test]
|
||||
|
||||
analysis:
|
||||
python -m pip install --no-cache-dir -e .[analysis]
|
||||
|
||||
client:
|
||||
python -m pip install --no-cache-dir -e .[client]
|
||||
python -m pip install -e .[analysis]
|
||||
|
||||
all:
|
||||
python -m pip install --no-cache-dir -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
python -m pip install -e .[pywinpty,dev,lint,docs,package,test,analysis,rl]
|
||||
|
||||
install: prerequisite dependencies
|
||||
|
||||
@@ -116,7 +113,7 @@ dev: prerequisite all
|
||||
|
||||
# Check lint with black.
|
||||
black:
|
||||
black . -l 120 --check --diff --exclude qlib/_version.py
|
||||
black . -l 120 --check --diff
|
||||
|
||||
# Check code folder with pylint.
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
|
||||
11
README.md
11
README.md
@@ -17,13 +17,14 @@ We are excited to announce the release of **RD-Agent**📢, a powerful tool that
|
||||
|
||||
RD-Agent is now available on [GitHub](https://github.com/microsoft/RD-Agent), and we welcome your star🌟!
|
||||
|
||||
To learn more, please visit the [RD-Agent repository](https://github.com/microsoft/RD-Agent). We have prepared several public demo videos for you:
|
||||
To learn more, please visit our [♾️Demo page](https://rdagent.azurewebsites.net/). Here, you will find demo videos in both English and Chinese to help you better understand the scenario and usage of RD-Agent.
|
||||
|
||||
We have prepared several demo videos for you:
|
||||
| Scenario | Demo video (English) | Demo video (中文) |
|
||||
| -- | ------ | ------ |
|
||||
| Quant Factor Mining | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) | [YouTube](https://www.youtube.com/watch?v=X4DK2QZKaKY&t=6s) |
|
||||
| Quant Factor Mining from reports | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) | [YouTube](https://www.youtube.com/watch?v=ECLTXVcSx-c) |
|
||||
| Quant Model Optimization | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) | [YouTube](https://www.youtube.com/watch?v=dm0dWL49Bc0&t=104s) |
|
||||
| Quant Factor Mining | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/factor_loop?lang=zh) |
|
||||
| Quant Factor Mining from reports | [Link](https://rdagent.azurewebsites.net/report_factor?lang=en) | [Link](https://rdagent.azurewebsites.net/report_factor?lang=zh) |
|
||||
| Quant Model Optimization | [Link](https://rdagent.azurewebsites.net/model_loop?lang=en) | [Link](https://rdagent.azurewebsites.net/model_loop?lang=zh) |
|
||||
|
||||
- 📃**Paper**: [R&D-Agent-Quant: A Multi-Agent Framework for Data-Centric Factors and Model Joint Optimization](https://arxiv.org/abs/2505.15155)
|
||||
- 👾**Code**: https://github.com/microsoft/RD-Agent/
|
||||
@@ -323,7 +324,7 @@ We recommend users to prepare their own data if they have a high-quality dataset
|
||||
```
|
||||
2. Start a new Docker container
|
||||
```bash
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app pyqlib/qlib_image_stable:stable
|
||||
docker run -it --name <container name> -v <Mounted local directory>:/app qlib_image_stable
|
||||
```
|
||||
3. At this point you are in the docker environment and can run the qlib scripts. An example:
|
||||
```bash
|
||||
|
||||
@@ -42,7 +42,7 @@ Example
|
||||
|
||||
.. math::
|
||||
|
||||
DEA = EMA(DIF, 9)
|
||||
DEA = \frac{EMA(DIF, 9)}{CLOSE}
|
||||
|
||||
Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
@@ -51,7 +51,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.dataset.loader import QlibDataLoader
|
||||
>> MACD_EXP = '2 * ((EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9))'
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
>> labels = ['Ref($close, -2)/Ref($close, -1) - 1'] # label
|
||||
@@ -66,17 +66,17 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
feature label
|
||||
MACD LABEL
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 0.008781 -0.019672
|
||||
SH600004 0.006699 -0.014721
|
||||
SH600006 0.005714 0.002911
|
||||
SH600008 0.000798 0.009818
|
||||
SH600009 0.017015 -0.017758
|
||||
2010-01-04 SH600000 -0.011547 -0.019672
|
||||
SH600004 0.002745 -0.014721
|
||||
SH600006 0.010133 0.002911
|
||||
SH600008 -0.001113 0.009818
|
||||
SH600009 0.025878 -0.017758
|
||||
... ... ...
|
||||
2017-12-29 SZ300124 0.015071 -0.005074
|
||||
SZ300136 -0.015466 0.056352
|
||||
SZ300144 0.013082 0.011853
|
||||
SZ300251 -0.001026 0.021739
|
||||
SZ300315 -0.007559 0.012455
|
||||
2017-12-29 SZ300124 0.007306 -0.005074
|
||||
SZ300136 -0.013492 0.056352
|
||||
SZ300144 -0.000966 0.011853
|
||||
SZ300251 0.004383 0.021739
|
||||
SZ300315 -0.030557 0.012455
|
||||
|
||||
Reference
|
||||
=========
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from importlib.metadata import version as ver
|
||||
import pkg_resources
|
||||
|
||||
|
||||
# -- General configuration ------------------------------------------------
|
||||
|
||||
@@ -62,9 +63,9 @@ author = "Microsoft"
|
||||
# built documents.
|
||||
#
|
||||
# The short X.Y version.
|
||||
version = ver("pyqlib")
|
||||
version = pkg_resources.get_distribution("pyqlib").version
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
release = version
|
||||
release = pkg_resources.get_distribution("pyqlib").version
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
|
||||
@@ -129,7 +129,7 @@ For example, it looks quite long and complicated:
|
||||
|
||||
|
||||
But using string is not the only way to implement the expression. You can also implement expression by code.
|
||||
Here is an example which does the same thing as above examples.
|
||||
Here is an exmaple which does the same thing as above examples.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -71,7 +71,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters must include the parameter `dataset`, which will be used to get the test dataset.
|
||||
- The parameters must include the parameter `dataset`, which will be userd to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
|
||||
@@ -19,6 +19,7 @@ from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158", "Alpha360"]
|
||||
# To register new datasets, please add their configurations here.
|
||||
|
||||
@@ -8,6 +8,7 @@ import pandas as pd
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
sns.set(color_codes=True)
|
||||
plt.rcParams["font.sans-serif"] = "SimHei"
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
@@ -19,7 +18,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
# +
|
||||
with open("./internal_data_s20.pkl", "rb") as f:
|
||||
data = restricted_pickle_load(f)
|
||||
data = pickle.load(f)
|
||||
|
||||
data.data_ic_df.columns.names = ["start_date", "end_date"]
|
||||
|
||||
@@ -53,7 +52,7 @@ pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean(
|
||||
|
||||
# +
|
||||
with open("./tasks_s20.pkl", "rb") as f:
|
||||
tasks = restricted_pickle_load(f)
|
||||
tasks = pickle.load(f)
|
||||
|
||||
task_df = {}
|
||||
for t in tasks:
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
import fire
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.config import HIGH_FREQ_CONFIG
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -125,10 +125,10 @@ class HighfreqWorkflow:
|
||||
del dataset, dataset_backtest
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = restricted_pickle_load(file_dataset)
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
with open("dataset_backtest.pkl", "rb") as file_dataset_backtest:
|
||||
dataset_backtest = restricted_pickle_load(file_dataset_backtest)
|
||||
dataset_backtest = pickle.load(file_dataset_backtest)
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
|
||||
@@ -9,6 +9,7 @@ from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
@@ -95,6 +95,7 @@ pos 0.000000
|
||||
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
|
||||
"""
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
import qlib
|
||||
import fire
|
||||
|
||||
@@ -7,7 +7,6 @@ There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
|
||||
@@ -6,7 +6,6 @@ NOTE:
|
||||
- !!!!!!!!!!!!!!!TODO!!!!!!!!!!!!!!!!!!!:
|
||||
- Its structure is not well designed and is very ugly; contributions that make importing datasets easier are welcome
|
||||
"""
|
||||
|
||||
from datetime import date, datetime as dt
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import pickle
|
||||
import os
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
for tag in ["test", "valid"]:
|
||||
files = os.listdir(os.path.join("data/orders/", tag))
|
||||
dfs = []
|
||||
for f in tqdm(files):
|
||||
with open(os.path.join("data/orders/", tag, f), "rb") as fr:
|
||||
df = restricted_pickle_load(fr)
|
||||
df = pickle.load(open(os.path.join("data/orders/", tag, f), "rb"))
|
||||
df = df.drop(["$close0"], axis=1)
|
||||
dfs.append(df)
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class RollingDataWorkflow:
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = restricted_pickle_load(file_dataset)
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
|
||||
@@ -7,7 +7,6 @@ Qlib provides two kinds of interfaces.
|
||||
|
||||
The interface of (1) is `qrun XXX.yaml`. The interface of (2) is script like this, which nearly does the same thing as `qrun XXX.yaml`
|
||||
"""
|
||||
|
||||
import qlib
|
||||
from qlib.constant import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
@@ -16,6 +15,7 @@ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
|
||||
@@ -45,7 +45,7 @@ dependencies = [
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"gymnasium<=0.26.2",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
@@ -69,7 +69,6 @@ rl = [
|
||||
"torch",
|
||||
"numpy<2.0.0",
|
||||
]
|
||||
|
||||
lint = [
|
||||
"black",
|
||||
"pylint",
|
||||
@@ -102,10 +101,6 @@ analysis = [
|
||||
"plotly",
|
||||
"statsmodels",
|
||||
]
|
||||
client = [
|
||||
"python-socketio<6",
|
||||
"tables",
|
||||
]
|
||||
|
||||
# In the process of releasing a new version, when checking the manylinux package with twine, an error is reported:
|
||||
# InvalidDistribution: Invalid distribution metadata: unrecognized or malformed field 'license-file'
|
||||
@@ -122,4 +117,3 @@ qrun = "qlib.cli.run:run"
|
||||
[tool.setuptools_scm]
|
||||
local_scheme = "no-local-version"
|
||||
version_scheme = "guess-next-dev"
|
||||
write_to = "qlib/_version.py"
|
||||
|
||||
@@ -4,10 +4,7 @@ from pathlib import Path
|
||||
|
||||
from setuptools_scm import get_version
|
||||
|
||||
try:
|
||||
from ._version import version as __version__
|
||||
except ImportError:
|
||||
__version__ = get_version(root="..", relative_to=__file__)
|
||||
__version__ = get_version(root="..", relative_to=__file__)
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import logging
|
||||
import os
|
||||
@@ -143,10 +140,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
_command_log = [line for line in _command_log if _remote_uri in line]
|
||||
if len(_command_log) > 0:
|
||||
for _c in _command_log:
|
||||
if isinstance(_c, str):
|
||||
_temp_mount = _c.split(" ")[2]
|
||||
else:
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
_temp_mount = _c.decode("utf-8").split(" ")[2]
|
||||
_temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount
|
||||
if _temp_mount == _mount_path:
|
||||
_is_mount = True
|
||||
|
||||
@@ -18,6 +18,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
|
||||
PORT_METRIC = Dict[str, Tuple[pd.DataFrame, dict]]
|
||||
INDICATOR_METRIC = Dict[str, Tuple[pd.DataFrame, Indicator]]
|
||||
|
||||
|
||||
@@ -4,5 +4,6 @@
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
|
||||
@@ -87,7 +87,7 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
"""
|
||||
This is a Qlib CLI entrance.
|
||||
Users can run the whole Quant research workflow defined by a configuration file
|
||||
- the code is located here ``qlib/cli/run.py``
|
||||
- the code is located here ``qlib/cli/run.py`
|
||||
|
||||
User can specify a base_config file in your workflow.yml file by adding "BASE_CONFIG_PATH".
|
||||
Qlib will load the configuration in BASE_CONFIG_PATH first, and the user only needs to update the custom fields
|
||||
|
||||
@@ -10,7 +10,6 @@ Two modes are supported
|
||||
- server
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
@@ -11,6 +11,7 @@ from qlib.utils import init_instance_by_config
|
||||
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
|
||||
@@ -3,4 +3,5 @@
|
||||
|
||||
from .data_selection import MetaTaskDS, MetaDatasetDS, MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaTaskDS", "MetaDatasetDS", "MetaModelDS"]
|
||||
|
||||
@@ -4,4 +4,5 @@
|
||||
from .dataset import MetaDatasetDS, MetaTaskDS
|
||||
from .model import MetaModelDS
|
||||
|
||||
|
||||
__all__ = ["MetaDatasetDS", "MetaTaskDS", "MetaModelDS"]
|
||||
|
||||
@@ -51,7 +51,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
w = reweighter.reweight(df)
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w, free_raw_data=False), key))
|
||||
ds_l.append((lgb.Dataset(x.values, label=y, weight=w), key))
|
||||
return ds_l
|
||||
|
||||
def fit(
|
||||
@@ -109,10 +109,8 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = ds_l[0]
|
||||
|
||||
if dtrain.construct().num_data() == 0:
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
|
||||
@@ -10,7 +10,6 @@ import os
|
||||
import gc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from packaging import version
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
@@ -149,7 +148,7 @@ class DNNModelPytorch(Model):
|
||||
if scheduler == "default":
|
||||
# In torch version 2.7.0, the verbose parameter has been removed. Reference Link:
|
||||
# https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313
|
||||
if version.parse(str(torch.__version__).split("+", maxsplit=1)[0]) <= version.parse("2.6.0"):
|
||||
if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0":
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123
|
||||
self.train_optimizer,
|
||||
|
||||
@@ -317,7 +317,7 @@ class TabnetModel(Model):
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
priors = 1 - S_mask
|
||||
vec, sparse_loss = self.tabnet_model(feature, priors)
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
@@ -348,7 +348,7 @@ class TabnetModel(Model):
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
with torch.no_grad():
|
||||
vec, sparse_loss = self.tabnet_model(feature, priors)
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
@@ -12,7 +12,6 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...utils import get_next_trading_date
|
||||
from ...utils.pickle_utils import restricted_pickle_load
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
log = get_module_logger("utils")
|
||||
@@ -31,7 +30,7 @@ def load_instance(file_path):
|
||||
if not file_path.exists():
|
||||
raise ValueError("Cannot find file {}".format(file_path))
|
||||
with file_path.open("rb") as fr:
|
||||
instance = restricted_pickle_load(fr)
|
||||
instance = pickle.load(fr)
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
@@ -3,4 +3,5 @@
|
||||
|
||||
from .analysis_model_performance import model_performance_graph
|
||||
|
||||
|
||||
__all__ = ["model_performance_graph"]
|
||||
|
||||
@@ -7,4 +7,5 @@ from .report import report_graph
|
||||
from .rank_label import rank_label_graph
|
||||
from .risk_analysis import risk_analysis_graph
|
||||
|
||||
|
||||
__all__ = ["cumulative_return_graph", "score_ic_graph", "report_graph", "rank_label_graph", "risk_analysis_graph"]
|
||||
|
||||
@@ -12,7 +12,6 @@ Here is an example.
|
||||
fa.plot_all(wspace=0.3, sub_figsize=(12, 3), col_n=5)
|
||||
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
|
||||
@@ -7,7 +7,6 @@ Assumptions
|
||||
- Each feature is analysed individually
|
||||
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
@@ -14,7 +14,6 @@ from qlib.model.meta.task import MetaTask
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.typehint import Literal
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.utils import replace_task_handler_with_cache
|
||||
|
||||
@@ -299,7 +298,7 @@ class DDGDA(Rolling):
|
||||
# but their task test segments are not aligned! It worked in my previous experiment.
|
||||
# So the misalignment will not affect the effectiveness of the method.
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = restricted_pickle_load(f)
|
||||
internal_data = pickle.load(f)
|
||||
|
||||
md = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
@@ -361,7 +360,7 @@ class DDGDA(Rolling):
|
||||
)
|
||||
|
||||
with self._internal_data_path.open("rb") as f:
|
||||
internal_data = restricted_pickle_load(f)
|
||||
internal_data = pickle.load(f)
|
||||
mds = MetaDatasetDS(exp_name=internal_data, **kwargs)
|
||||
|
||||
# 3) meta model make inference and get new qlib task
|
||||
|
||||
@@ -16,6 +16,7 @@ from .rule_strategy import (
|
||||
|
||||
from .cost_control import SoftTopkStrategy
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TopkDropoutStrategy",
|
||||
"WeightStrategyBase",
|
||||
|
||||
@@ -1,117 +1,101 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This strategy is not well maintained
|
||||
"""
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .signal_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
class SoftTopkStrategy(WeightStrategyBase):
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
dataset=None,
|
||||
topk=None,
|
||||
model,
|
||||
dataset,
|
||||
topk,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
max_sold_weight=1.0,
|
||||
trade_impact_limit=None,
|
||||
risk_degree=0.95,
|
||||
buy_method="first_fill",
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Refactored SoftTopkStrategy with a budget-constrained rebalancing engine.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
topk : int
|
||||
The number of top-N stocks to be held in the portfolio.
|
||||
trade_impact_limit : float
|
||||
Maximum weight change for each stock in one trade. If None, fallback to max_sold_weight.
|
||||
max_sold_weight : float
|
||||
Backward-compatible alias for trade_impact_limit. Use 1.0 to effectively disable the limit.
|
||||
top-N stocks to buy
|
||||
risk_degree : float
|
||||
The target percentage of total value to be invested.
|
||||
position percentage of total value buy_method:
|
||||
|
||||
rank_fill: assign weight to the highest-ranked stocks first (at most 1/topk each)
|
||||
average_fill: distribute the weight evenly among the highest-ranked stocks.
|
||||
"""
|
||||
super(SoftTopkStrategy, self).__init__(
|
||||
model=model, dataset=dataset, order_generator_cls_or_obj=order_generator_cls_or_obj, **kwargs
|
||||
model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs
|
||||
)
|
||||
|
||||
self.topk = topk
|
||||
self.trade_impact_limit = trade_impact_limit if trade_impact_limit is not None else max_sold_weight
|
||||
self.max_sold_weight = max_sold_weight
|
||||
self.risk_degree = risk_degree
|
||||
self.buy_method = buy_method
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value that you will use for investment.
|
||||
Adjusting risk_degree dynamically will result in market timing.
|
||||
"""
|
||||
# By default, 95% of your total value is used
|
||||
return self.risk_degree
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time, **kwargs):
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
"""
|
||||
Generates target position using Proportional Budget Allocation.
|
||||
Ensures deterministic sells and synchronized buys under impact limits.
|
||||
Parameters
|
||||
----------
|
||||
score:
|
||||
pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
|
||||
current:
|
||||
current position, use Position() class
|
||||
trade_date:
|
||||
trade date
|
||||
|
||||
generate target position from score for this date and the current position
|
||||
|
||||
The cache is not considered in the position
|
||||
"""
|
||||
# TODO:
|
||||
# If the current stock list is more than topk(eg. The weights are modified
|
||||
# by risk control), the weight will not be handled correctly.
|
||||
buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index)
|
||||
cur_stock_weight = current.get_stock_weight_dict(only_stock=True)
|
||||
|
||||
if self.topk is None or self.topk <= 0:
|
||||
return {}
|
||||
|
||||
def apply_impact_limit(weight):
|
||||
return weight if self.trade_impact_limit is None else min(weight, self.trade_impact_limit)
|
||||
|
||||
ideal_per_stock = self.risk_degree / self.topk
|
||||
ideal_list = score.sort_values(ascending=False).iloc[: self.topk].index.tolist()
|
||||
|
||||
cur_weights = current.get_stock_weight_dict(only_stock=True)
|
||||
initial_total_weight = sum(cur_weights.values())
|
||||
|
||||
# --- Case A: Cold Start ---
|
||||
if not cur_weights:
|
||||
fill = apply_impact_limit(ideal_per_stock)
|
||||
return {code: fill for code in ideal_list}
|
||||
|
||||
# --- Case B: Rebalancing ---
|
||||
all_tickers = set(cur_weights.keys()) | set(ideal_list)
|
||||
next_weights = {t: cur_weights.get(t, 0.0) for t in all_tickers}
|
||||
|
||||
# Phase 1: Deterministic Sell Phase
|
||||
released_cash = 0.0
|
||||
for t in list(next_weights.keys()):
|
||||
cur = next_weights[t]
|
||||
if cur <= 1e-8:
|
||||
continue
|
||||
|
||||
if t not in ideal_list:
|
||||
sell = apply_impact_limit(cur)
|
||||
next_weights[t] -= sell
|
||||
released_cash += sell
|
||||
elif cur > ideal_per_stock + 1e-8:
|
||||
excess = cur - ideal_per_stock
|
||||
sell = apply_impact_limit(excess)
|
||||
next_weights[t] -= sell
|
||||
released_cash += sell
|
||||
|
||||
# Phase 2: Budget Calculation
|
||||
# Budget = Cash from sells + Available space from target risk degree
|
||||
total_budget = released_cash + (self.risk_degree - initial_total_weight)
|
||||
|
||||
# Phase 3: Proportional Buy Allocation
|
||||
if total_budget > 1e-8:
|
||||
shortfalls = {
|
||||
t: (ideal_per_stock - next_weights.get(t, 0.0))
|
||||
for t in ideal_list
|
||||
if next_weights.get(t, 0.0) < ideal_per_stock - 1e-8
|
||||
}
|
||||
|
||||
if shortfalls:
|
||||
total_shortfall = sum(shortfalls.values())
|
||||
# Normalize total_budget to not exceed total_shortfall
|
||||
available_to_spend = min(total_budget, total_shortfall)
|
||||
|
||||
for t, shortfall in shortfalls.items():
|
||||
# Every stock gets its fair share based on its distance to target
|
||||
share_of_budget = (shortfall / total_shortfall) * available_to_spend
|
||||
|
||||
# Capped by impact limit
|
||||
max_buy_cap = apply_impact_limit(shortfall)
|
||||
|
||||
next_weights[t] += min(share_of_budget, max_buy_cap)
|
||||
|
||||
return {k: v for k, v in next_weights.items() if v > 1e-8}
|
||||
if len(cur_stock_weight) == 0:
|
||||
final_stock_weight = {code: 1 / self.topk for code in buy_signal_stocks}
|
||||
else:
|
||||
final_stock_weight = copy.deepcopy(cur_stock_weight)
|
||||
sold_stock_weight = 0.0
|
||||
for stock_id in final_stock_weight:
|
||||
if stock_id not in buy_signal_stocks:
|
||||
sw = min(self.max_sold_weight, final_stock_weight[stock_id])
|
||||
sold_stock_weight += sw
|
||||
final_stock_weight[stock_id] -= sw
|
||||
if self.buy_method == "first_fill":
|
||||
for stock_id in buy_signal_stocks:
|
||||
add_weight = min(
|
||||
max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0),
|
||||
sold_stock_weight,
|
||||
)
|
||||
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight
|
||||
sold_stock_weight -= add_weight
|
||||
elif self.buy_method == "average_fill":
|
||||
for stock_id in buy_signal_stocks:
|
||||
final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len(
|
||||
buy_signal_stocks
|
||||
)
|
||||
else:
|
||||
raise ValueError("Buy method not found")
|
||||
return final_stock_weight
|
||||
|
||||
@@ -5,4 +5,5 @@ from .base import BaseOptimizer
|
||||
from .optimizer import PortfolioOptimizer
|
||||
from .enhanced_indexing import EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
__all__ = ["BaseOptimizer", "PortfolioOptimizer", "EnhancedIndexingOptimizer"]
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Union, Optional, Dict, Any, List
|
||||
from qlib.log import get_module_logger
|
||||
from .base import BaseOptimizer
|
||||
|
||||
|
||||
logger = get_module_logger("EnhancedIndexingOptimizer")
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
"""
|
||||
This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ This module is not a necessary part of Qlib.
|
||||
They are just some tools for convenience
|
||||
It should not be imported into the core part of qlib
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -13,6 +13,7 @@ import yaml
|
||||
|
||||
from .config import TunerConfigManager
|
||||
|
||||
|
||||
args_parser = argparse.ArgumentParser(prog="tuner")
|
||||
args_parser.add_argument(
|
||||
"-c",
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from hyperopt import hp
|
||||
|
||||
|
||||
TopkAmountStrategySpace = {
|
||||
"topk": hp.choice("topk", [30, 35, 40]),
|
||||
"buffer_margin": hp.choice("buffer_margin", [200, 250, 300]),
|
||||
|
||||
@@ -8,6 +8,7 @@ import os
|
||||
import yaml
|
||||
import json
|
||||
import copy
|
||||
import pickle
|
||||
import logging
|
||||
import importlib
|
||||
import subprocess
|
||||
@@ -17,7 +18,6 @@ import numpy as np
|
||||
from abc import abstractmethod
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils.pickle_utils import restricted_pickle_load
|
||||
from hyperopt import fmin, tpe
|
||||
from hyperopt import STATUS_OK, STATUS_FAIL
|
||||
|
||||
@@ -136,7 +136,7 @@ class QLibTuner(Tuner):
|
||||
exp_result_dir = os.path.join(self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id))
|
||||
exp_result_path = os.path.join(exp_result_dir, QLibTuner.EXP_RESULT_NAME)
|
||||
with open(exp_result_path, "rb") as fp:
|
||||
analysis_df = restricted_pickle_load(fp)
|
||||
analysis_df = pickle.load(fp)
|
||||
|
||||
# 4. Get the backtest factor that the user wants to optimize; if the user wants to maximize the factor, reverse the result
|
||||
res = analysis_df.loc[self.optim_config.report_type].loc[self.optim_config.report_factor]
|
||||
|
||||
@@ -3,4 +3,5 @@
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
|
||||
__all__ = ["MultiSegRecord", "SignalMseRecord"]
|
||||
|
||||
@@ -36,6 +36,7 @@ from .cache import (
|
||||
MemoryCalendarCache,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"D",
|
||||
"CalendarProvider",
|
||||
|
||||
@@ -30,7 +30,6 @@ from ..utils import (
|
||||
normalize_cache_fields,
|
||||
normalize_cache_instruments,
|
||||
)
|
||||
from ..utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .base import Feature
|
||||
@@ -226,7 +225,7 @@ class CacheUtils:
|
||||
cache_path = Path(cache_path)
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
with meta_path.open("rb") as f:
|
||||
d = restricted_pickle_load(f)
|
||||
d = pickle.load(f)
|
||||
with meta_path.open("wb") as f:
|
||||
try:
|
||||
d["meta"]["last_visit"] = str(time.time())
|
||||
@@ -593,7 +592,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:expression-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = restricted_pickle_load(f)
|
||||
d = pickle.load(f)
|
||||
instrument = d["info"]["instrument"]
|
||||
field = d["info"]["field"]
|
||||
freq = d["info"]["freq"]
|
||||
@@ -960,7 +959,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
im = DiskDatasetCache.IndexManager(cp_cache_uri)
|
||||
with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}"):
|
||||
with meta_path.open("rb") as f:
|
||||
d = restricted_pickle_load(f)
|
||||
d = pickle.load(f)
|
||||
instruments = d["info"]["instruments"]
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division, print_function
|
||||
|
||||
import json
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
|
||||
@@ -96,7 +96,7 @@ class Client:
|
||||
self.logger.debug("connected")
|
||||
# The pickle is for passing some parameters with special type(such as
|
||||
# pd.Timestamp)
|
||||
request_content = {"head": head_info, "body": json.dumps(request_content, default=str)}
|
||||
request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)}
|
||||
self.sio.on(request_type + "_response", request_callback)
|
||||
self.logger.debug("try sending")
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
|
||||
@@ -19,6 +19,7 @@ from .loader import DataLoader
|
||||
from . import processor as processor_module
|
||||
from . import loader as data_loader_module
|
||||
|
||||
|
||||
DATA_KEY_TYPE = Literal["raw", "infer", "learn"]
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import pandas as pd
|
||||
@@ -10,7 +11,6 @@ from typing import Tuple, Union, List, Dict
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
@@ -283,7 +283,7 @@ class StaticDataLoader(DataLoader, Serializable):
|
||||
self._data = pd.read_parquet(self._config, engine="pyarrow")
|
||||
else:
|
||||
with Path(self._config).open("rb") as f:
|
||||
self._data = restricted_pickle_load(f)
|
||||
self._data = pickle.load(f)
|
||||
elif isinstance(self._config, pd.DataFrame):
|
||||
self._data = self._config
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
for _ts, _bool in timestamp_series.items():
|
||||
# there is likely to be NAN when the filter series don't have the
|
||||
# bool value, so we just change the NAN into False
|
||||
if np.isnan(_bool):
|
||||
if _bool == np.nan:
|
||||
_bool = False
|
||||
if _lbool is None:
|
||||
_cur_start = _ts
|
||||
|
||||
@@ -13,7 +13,6 @@ The calculation of both <period_time, feature> and <observe_time, feature> data
|
||||
2) concatenate all the collapsed data, and we will get data with format <observe_time, feature>.
|
||||
Qlib will use the operator `P` to perform the collapse.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.data.ops import ElemOperator
|
||||
|
||||
@@ -3,4 +3,5 @@
|
||||
|
||||
from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT
|
||||
|
||||
|
||||
__all__ = ["CalendarStorage", "InstrumentStorage", "FeatureStorage", "CalVT", "InstVT", "InstKT"]
|
||||
|
||||
@@ -156,7 +156,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
def index(self, value: CalVT) -> int:
|
||||
self.check()
|
||||
calendar = self._read_calendar()
|
||||
return calendar.index(value)
|
||||
return int(np.argwhere(calendar == value)[0])
|
||||
|
||||
def insert(self, index: int, value: CalVT):
|
||||
calendar = self._read_calendar()
|
||||
|
||||
@@ -5,4 +5,5 @@ import warnings
|
||||
|
||||
from .base import Model
|
||||
|
||||
|
||||
__all__ = ["Model", "warnings"]
|
||||
|
||||
@@ -4,4 +4,5 @@
|
||||
from .task import MetaTask
|
||||
from .dataset import MetaTaskDataset
|
||||
|
||||
|
||||
__all__ = ["MetaTask", "MetaTaskDataset"]
|
||||
|
||||
@@ -6,6 +6,7 @@ from .poet import POETCovEstimator
|
||||
from .shrink import ShrinkCovEstimator
|
||||
from .structured import StructuredCovEstimator
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RiskModel",
|
||||
"POETCovEstimator",
|
||||
|
||||
@@ -9,6 +9,7 @@ import tempfile
|
||||
from importlib import import_module
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
||||
DELETE_KEY = "_delete_"
|
||||
|
||||
|
||||
|
||||
@@ -2,18 +2,17 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, cast
|
||||
from typing import cast, List
|
||||
|
||||
import cachetools
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from qlib.backtest.decision import TradeRange, TradeRangeByTime
|
||||
from qlib.constant import EPS_T
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
|
||||
|
||||
|
||||
@@ -163,7 +162,7 @@ class HandlerIntradayProcessedData(BaseIntradayProcessedData):
|
||||
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
with open(path, "rb") as fstream:
|
||||
dataset = restricted_pickle_load(fstream)
|
||||
dataset = pickle.load(fstream)
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
|
||||
if index_only:
|
||||
|
||||
@@ -5,9 +5,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from qlib.typehint import final
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, List, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.base import ProcessedDataProvider
|
||||
|
||||
@@ -6,11 +6,11 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from gymnasium.spaces import Discrete
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ from __future__ import annotations
|
||||
import weakref
|
||||
from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
import gymnasium as gym
|
||||
from gymnasium import Space
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
|
||||
|
||||
@@ -13,7 +13,7 @@ import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
"""
|
||||
This module covers some utility functions that operate on data or basic object
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import contextlib
|
||||
import importlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import pkgutil
|
||||
import re
|
||||
import sys
|
||||
@@ -19,7 +20,6 @@ from typing import Any, Dict, List, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from qlib.typehint import InstConf
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
@@ -168,10 +168,10 @@ def init_instance_by_config(
|
||||
|
||||
pr_path = os.path.join(pr.netloc, path) if bool(pr.path) else pr.netloc
|
||||
with open(os.path.normpath(pr_path), "rb") as f:
|
||||
return restricted_pickle_load(f)
|
||||
return pickle.load(f)
|
||||
else:
|
||||
with config.open("rb") as f:
|
||||
return restricted_pickle_load(f)
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
class ObjManager:
|
||||
@@ -117,7 +116,7 @@ class FileManager(ObjManager):
|
||||
|
||||
def load_obj(self, name):
|
||||
with (self.path / name).open("rb") as f:
|
||||
return restricted_pickle_load(f)
|
||||
return pickle.load(f)
|
||||
|
||||
def exists(self, name):
|
||||
return (self.path / name).exists()
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Secure pickle utilities to prevent arbitrary code execution through deserialization.
|
||||
|
||||
This module provides a secure alternative to pickle.load() and pickle.loads()
|
||||
that restricts deserialization to a whitelist of safe classes.
|
||||
"""
|
||||
|
||||
import io
|
||||
import pickle
|
||||
from typing import Any, BinaryIO, Set, Tuple
|
||||
|
||||
# Whitelist of safe classes that are allowed to be unpickled
|
||||
# These are common data types used in qlib that should be safe to deserialize
|
||||
SAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {
    # python builtins
    ("builtins", "slice"),
    ("builtins", "range"),
    ("builtins", "dict"),
    ("builtins", "list"),
    ("builtins", "tuple"),
    ("builtins", "set"),
    ("builtins", "frozenset"),
    ("builtins", "bytearray"),
    ("builtins", "bytes"),
    ("builtins", "str"),
    ("builtins", "int"),
    ("builtins", "float"),
    ("builtins", "bool"),
    ("builtins", "complex"),
    ("builtins", "type"),
    ("builtins", "property"),
    # common utility classes
    ("datetime", "datetime"),
    ("datetime", "date"),
    ("datetime", "time"),
    ("datetime", "timedelta"),
    ("datetime", "timezone"),
    ("decimal", "Decimal"),
    ("collections", "OrderedDict"),
    ("collections", "defaultdict"),
    ("collections", "Counter"),
    ("collections", "namedtuple"),
    ("enum", "Enum"),
    ("pathlib", "Path"),
    ("pathlib", "PosixPath"),
    ("pathlib", "WindowsPath"),
    # qlib internal classes
    ("qlib.data.dataset.handler", "DataHandler"),
    ("qlib.data.dataset.handler", "DataHandlerLP"),
    ("qlib.data.dataset.loader", "StaticDataLoader"),
}

# Top-level packages whose modules are trusted as a whole. An entry matches the
# package itself and its dotted submodules only (e.g. "numpy" matches "numpy"
# and "numpy.core.multiarray", but NOT "numpy_evil").
# NOTE(review): trusting entire packages is broad; consider narrowing this to
# the specific globals that qlib pickles actually reference.
TRUSTED_MODULE_PREFIXES = (
    "pandas",
    "numpy",
)


def _is_trusted_module(module: str) -> bool:
    """Return True if *module* is a trusted package or one of its submodules."""
    return any(module == prefix or module.startswith(prefix + ".") for prefix in TRUSTED_MODULE_PREFIXES)


class RestrictedUnpickler(pickle.Unpickler):
    """Custom unpickler that only allows safe classes to be deserialized.

    Globals are resolved only when they live in a trusted package (see
    ``TRUSTED_MODULE_PREFIXES``) or are listed explicitly in
    ``SAFE_PICKLE_CLASSES``; any other global aborts loading.

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = RestrictedUnpickler(f).load()
    """

    def find_class(self, module: str, name: str):
        """Resolve a pickled global, rejecting anything outside the whitelist.

        Args:
            module: Module name of the class
            name: Class name

        Returns:
            The object resolved by ``pickle.Unpickler.find_class`` when it is allowed

        Raises:
            pickle.UnpicklingError: If the class is not in the whitelist
        """
        # 1. trusted packages, matched on a package boundary so that look-alike
        #    module names such as "pandasx" are not accepted by accident
        # 2. explicit whitelist (builtins, stdlib helpers, qlib internals)
        if _is_trusted_module(module) or (module, name) in SAFE_PICKLE_CLASSES:
            return super().find_class(module, name)

        raise pickle.UnpicklingError(
            f"Forbidden class: {module}.{name}. "
            f"Only whitelisted classes are allowed for security reasons. "
            f"This is to prevent arbitrary code execution through pickle deserialization."
        )
|
||||
|
||||
|
||||
def restricted_pickle_load(file: BinaryIO) -> Any:
    """Deserialize one object from a binary stream through the class whitelist.

    Intended as a drop-in replacement for ``pickle.load()``; class lookups are
    delegated to ``RestrictedUnpickler``, which rejects non-whitelisted classes.

    Args:
        file: An opened file object in binary mode

    Returns:
        The unpickled Python object

    Raises:
        pickle.UnpicklingError: If the pickle contains forbidden classes

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = restricted_pickle_load(f)
    """
    unpickler = RestrictedUnpickler(file)
    return unpickler.load()
|
||||
|
||||
|
||||
def restricted_pickle_loads(data: bytes) -> Any:
    """Deserialize one object from a bytes payload through the class whitelist.

    Intended as a drop-in replacement for ``pickle.loads()``; the payload is
    wrapped in an in-memory stream and read by ``RestrictedUnpickler``.

    Args:
        data: Bytes object containing pickled data

    Returns:
        The unpickled Python object

    Raises:
        pickle.UnpicklingError: If the pickle contains forbidden classes

    Example:
        >>> payload = pickle.dumps({"a": 1})
        >>> restricted_pickle_loads(payload)
        {'a': 1}
    """
    return RestrictedUnpickler(io.BytesIO(data)).load()
|
||||
|
||||
|
||||
def add_safe_class(module: str, name: str) -> None:
    """Register an additional ``(module, name)`` pair as safe to unpickle.

    The pair is added to the module-level ``SAFE_PICKLE_CLASSES`` set, so the
    change affects every later restricted load in this process.

    Args:
        module: Module name of the class (e.g., 'my_package.my_module')
        name: Class name (e.g., 'MyClass')

    Warning:
        Only register classes that you fully control and trust; whitelisting
        arbitrary classes from external packages can reintroduce the
        code-execution risk this module exists to prevent.

    Example:
        >>> add_safe_class('my_package.models', 'CustomModel')
    """
    entry = (module, name)
    SAFE_PICKLE_CLASSES.add(entry)
|
||||
|
||||
|
||||
def get_safe_classes() -> Set[Tuple[str, str]]:
    """Return a snapshot of the current unpickling whitelist.

    Returns:
        A new set of (module, name) tuples; mutating it does not change the
        whitelist itself
    """
    return set(SAFE_PICKLE_CLASSES)
|
||||
@@ -3,7 +3,6 @@
|
||||
"""
|
||||
Time related utils are compiled in this script
|
||||
"""
|
||||
|
||||
import bisect
|
||||
from datetime import datetime, time, date, timedelta
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -15,6 +14,7 @@ import pandas as pd
|
||||
from qlib.config import C
|
||||
from qlib.constant import REG_CN, REG_TW, REG_US
|
||||
|
||||
|
||||
CN_TIME = [
|
||||
datetime.strptime("9:30", "%H:%M"),
|
||||
datetime.strptime("11:30", "%H:%M"),
|
||||
|
||||
@@ -16,6 +16,7 @@ from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
from ..utils.exceptions import ExpAlreadyExistError
|
||||
|
||||
|
||||
logger = get_module_logger("workflow")
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from ..utils.data import deepcopy_basic_type
|
||||
from ..utils.exceptions import QlibException
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
@@ -475,13 +476,7 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if self.backtest_config["start_time"] is None:
|
||||
self.backtest_config["start_time"] = dt_values.min()
|
||||
if self.backtest_config["end_time"] is None:
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), -1)
|
||||
warnings.warn(
|
||||
"No explicit backtest end_time provided. "
|
||||
"Qlib requires one extra calendar step to determine the right boundary of a bar. "
|
||||
"Therefore the end_time is shifted backward by one trading day from "
|
||||
f"{dt_values.max()} -> {self.backtest_config['end_time']}."
|
||||
)
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
|
||||
|
||||
artifact_objects = {}
|
||||
# custom strategy and get backtest
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
"""
|
||||
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import pandas as pd
|
||||
@@ -107,13 +106,15 @@ def handler_mod(task: dict, rolling_gen):
|
||||
rg (RollingGen): an instance of RollingGen
|
||||
"""
|
||||
try:
|
||||
handler_kwargs = task["dataset"]["kwargs"]["handler"]["kwargs"]
|
||||
handler_end_time = handler_kwargs.get("end_time")
|
||||
test_seg_end_time = task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
|
||||
# if the end of test_segments is None (open-ended segment, i.e., "until now") or end_time < the end of test_segments,
|
||||
# then change end_time to allow load more data
|
||||
if test_seg_end_time is None or rolling_gen.ta.cal_interval(handler_end_time, test_seg_end_time) < 0:
|
||||
handler_kwargs["end_time"] = copy.deepcopy(test_seg_end_time)
|
||||
interval = rolling_gen.ta.cal_interval(
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
|
||||
)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if interval < 0:
|
||||
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
|
||||
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
|
||||
)
|
||||
except KeyError:
|
||||
# Maybe dataset do not have handler, then do nothing.
|
||||
pass
|
||||
|
||||
@@ -12,7 +12,6 @@ A task in TaskManager consists of 3 parts
|
||||
- tasks status: the status of the task
|
||||
- tasks result: A user can get the task with the task description and task result.
|
||||
"""
|
||||
|
||||
import concurrent
|
||||
import pickle
|
||||
import time
|
||||
@@ -29,7 +28,6 @@ from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
from ...config import C
|
||||
from ...utils.pickle_utils import restricted_pickle_loads
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -133,7 +131,7 @@ class TaskManager:
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = restricted_pickle_loads(task[k])
|
||||
task[k] = pickle.loads(task[k])
|
||||
return task
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from loguru import logger
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
import qlib
|
||||
from tqdm import tqdm
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ class DataHealthChecker:
|
||||
self.large_step_threshold_price = large_step_threshold_price
|
||||
self.large_step_threshold_volume = large_step_threshold_volume
|
||||
self.missing_data_num = missing_data_num
|
||||
self.qlib_dir = os.path.abspath(os.path.expanduser(qlib_dir))
|
||||
|
||||
if csv_path:
|
||||
assert os.path.isdir(csv_path), f"{csv_path} should be a directory."
|
||||
@@ -69,43 +68,6 @@ class DataHealthChecker:
|
||||
self.data[instrument] = df
|
||||
print(df)
|
||||
|
||||
# NOTE:
|
||||
# This check is added due to a known issue in Qlib where feature paths
|
||||
# are constructed using lowercased instrument names. On case-sensitive
|
||||
# file systems (e.g. Linux), uppercase directory names under `features/`
|
||||
# will cause data loading failures.
|
||||
#
|
||||
# See: https://github.com/microsoft/qlib/issues/2053
|
||||
def check_features_dir_lowercase(self) -> Optional[pd.DataFrame]:
    """Report subdirectories of `<qlib_dir>/features` whose names are not lowercase.

    Returns:
        A DataFrame with a single ``non_lowercase_dir`` column listing the
        offending directory names, or ``None`` when ``qlib_dir`` is unset, the
        ``features`` directory does not exist, or every subdirectory name is
        already lowercase.
    """
    if not self.qlib_dir:
        return None

    features_root = os.path.join(self.qlib_dir, "features")
    if not os.path.isdir(features_root):
        logger.warning(f"`features` directory not found under {self.qlib_dir}")
        return None

    # Only directories are inspected; plain files under `features/` are ignored.
    offenders = [
        entry
        for entry in os.listdir(features_root)
        if os.path.isdir(os.path.join(features_root, entry)) and entry != entry.lower()
    ]

    if not offenders:
        logger.info(
            f"✅ All subdirectories under `{os.path.join(self.qlib_dir, 'features')}` are named in lowercase."
        )
        return None

    return pd.DataFrame({"non_lowercase_dir": offenders})
|
||||
|
||||
def check_missing_data(self) -> Optional[pd.DataFrame]:
|
||||
"""Check if any data is missing in the DataFrame."""
|
||||
result_dict = {
|
||||
@@ -215,13 +177,11 @@ class DataHealthChecker:
|
||||
check_large_step_changes_result = self.check_large_step_changes()
|
||||
check_required_columns_result = self.check_required_columns()
|
||||
check_missing_factor_result = self.check_missing_factor()
|
||||
check_features_dir_case_result = self.check_features_dir_lowercase()
|
||||
if (
|
||||
check_large_step_changes_result is not None
|
||||
or check_large_step_changes_result is not None
|
||||
or check_required_columns_result is not None
|
||||
or check_missing_factor_result is not None
|
||||
or check_features_dir_case_result is not None
|
||||
):
|
||||
print(f"\nSummary of data health check ({len(self.data)} files checked):")
|
||||
print("-------------------------------------------------")
|
||||
@@ -237,11 +197,6 @@ class DataHealthChecker:
|
||||
if isinstance(check_missing_factor_result, pd.DataFrame):
|
||||
logger.warning(f"The factor column does not exist or is empty")
|
||||
print(check_missing_factor_result)
|
||||
if isinstance(check_features_dir_case_result, pd.DataFrame):
|
||||
logger.warning(
|
||||
f"Some subdirectories under `{os.path.join(self.qlib_dir, 'features')}` contain uppercase letters, please rename them to lowercase manually."
|
||||
)
|
||||
print(check_features_dir_case_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -45,7 +45,7 @@ class InfoCollector:
|
||||
"pymongo",
|
||||
"loguru",
|
||||
"lightgbm",
|
||||
"gym",
|
||||
"gymnasium",
|
||||
"cvxpy",
|
||||
"joblib",
|
||||
"matplotlib",
|
||||
|
||||
@@ -280,20 +280,11 @@ class Normalize:
|
||||
self._symbol_field_name = symbol_field_name
|
||||
self._end_date = kwargs.get("end_date", None)
|
||||
self._max_workers = max_workers
|
||||
self.interval = kwargs.get("interval", "1d")
|
||||
|
||||
self._normalize_obj = normalize_class(
|
||||
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
||||
)
|
||||
|
||||
def format_data(self, df: pd.DataFrame):
    """Drop a malformed trailing row from daily source data.

    For the ``"1d"`` interval the last row's date must parse as ``%Y-%m-%d``;
    if it does not, that row is discarded. Other intervals are returned
    untouched.

    The date is read from the configured date column
    (``self._date_field_name``) rather than a hard-coded ``"date"``, so a
    custom ``date_field_name`` no longer causes a valid last row to be dropped.

    Args:
        df: Raw source data for one instrument.

    Returns:
        ``df`` itself, or ``df`` without its last row when that row's date is
        malformed.
    """
    if self.interval != "1d" or df.empty:
        return df

    date_col = self._date_field_name
    if date_col not in df.columns:
        # Nothing to validate against; leave the frame for the normalizer.
        return df

    try:
        pd.to_datetime(df.iloc[-1][date_col], format="%Y-%m-%d", errors="raise")
    except (ValueError, TypeError, OverflowError):
        # Unparseable trailing date: treat the row as garbage and drop it.
        df = df.iloc[:-1]
    return df
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
@@ -309,18 +300,14 @@ class Normalize:
|
||||
keep_default_na=False,
|
||||
na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns},
|
||||
)
|
||||
df = self.format_data(df=df)
|
||||
|
||||
if not df.empty:
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
else:
|
||||
logger.warning(f"{file_path.stem} source data is empty and will not undergo normalization processing.")
|
||||
# NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified.
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if df is not None and not df.empty:
|
||||
if self._end_date is not None:
|
||||
_mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date)
|
||||
df = df[_mask]
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
@@ -22,6 +22,7 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = (
|
||||
"https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from time import mktime
|
||||
from datetime import datetime as dt
|
||||
import time
|
||||
|
||||
|
||||
_CG_CRYPTO_SYMBOLS = None
|
||||
|
||||
|
||||
|
||||
@@ -7,14 +7,13 @@ import sys
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
from io import StringIO
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from fake_useragent import UserAgent
|
||||
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
@@ -23,6 +22,7 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
||||
from data_collector.utils import get_instruments
|
||||
|
||||
|
||||
WIKI_URL = "https://en.wikipedia.org/wiki"
|
||||
|
||||
WIKI_INDEX_NAME_MAP = {
|
||||
@@ -51,7 +51,6 @@ class WIKIIndex(IndexBase):
|
||||
)
|
||||
|
||||
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
|
||||
self._ua = UserAgent()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -113,8 +112,7 @@ class WIKIIndex(IndexBase):
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
headers = {"User-Agent": self._ua.random}
|
||||
resp = requests.get(self._target_url, timeout=None, headers=headers)
|
||||
resp = requests.get(self._target_url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
@@ -130,7 +128,7 @@ class WIKIIndex(IndexBase):
|
||||
def get_new_companies(self):
|
||||
logger.info(f"get new companies {self.index_name} ......")
|
||||
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
|
||||
df_list = pd.read_html(StringIO(_data.text))
|
||||
df_list = pd.read_html(_data.text)
|
||||
for _df in df_list:
|
||||
_df = self.filter_df(_df)
|
||||
if (_df is not None) and (not _df.empty):
|
||||
@@ -228,11 +226,7 @@ class SP500Index(WIKIIndex):
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
logger.info(f"get sp500 history changes......")
|
||||
# NOTE: may update the index of the table
|
||||
# Add headers to avoid 403 Forbidden error from Wikipedia
|
||||
headers = {"User-Agent": self._ua.random}
|
||||
response = requests.get(self.WIKISP500_CHANGES_URL, headers=headers, timeout=None)
|
||||
response.raise_for_status()
|
||||
changes_df = pd.read_html(StringIO(response.text))[-1]
|
||||
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
|
||||
changes_df = changes_df.iloc[:, [0, 1, 3]]
|
||||
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
|
||||
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
|
||||
|
||||
@@ -3,4 +3,3 @@ requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
fake-useragent
|
||||
|
||||
@@ -7,6 +7,7 @@ import importlib
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
@@ -20,9 +21,6 @@ from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from bs4 import BeautifulSoup
|
||||
import baostock as bs
|
||||
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -69,16 +67,9 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(end_date):
|
||||
bs.login()
|
||||
rs = bs.query_trade_dates(start_date="2005-01-01", end_date=end_date)
|
||||
data_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
data_list.append(rs.get_row_data())
|
||||
bs.logout()
|
||||
df = pd.DataFrame(data_list, columns=rs.fields)
|
||||
trade_days = df[df["is_trading_day"] == "1"]["calendar_date"]
|
||||
return sorted(map(pd.Timestamp, trade_days.to_list()))
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
@@ -89,17 +80,30 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
if bench_code.upper() == "ALL":
|
||||
import akshare as ak # pylint: disable=C0415
|
||||
|
||||
trade_date_df = ak.tool_trade_date_hist_sina()
|
||||
trade_date_list = trade_date_df["trade_date"].tolist()
|
||||
trade_date_list = [pd.Timestamp(d) for d in trade_date_list]
|
||||
dates = pd.DatetimeIndex(trade_date_list)
|
||||
filtered_dates = dates[(dates >= "2000-01-04") & (dates <= pd.Timestamp.today().normalize())]
|
||||
calendar = filtered_dates.tolist()
|
||||
@deco_retry
|
||||
def _get_calendar(month):
|
||||
_cal = []
|
||||
try:
|
||||
resp = requests.get(
|
||||
SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None
|
||||
).json()
|
||||
for _r in resp["data"]:
|
||||
if int(_r["jybz"]):
|
||||
_cal.append(pd.Timestamp(_r["jyrq"]))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{month}-->{e}") from e
|
||||
return _cal
|
||||
|
||||
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
|
||||
calendar = []
|
||||
for _m in month_range:
|
||||
cal = _get_calendar(_m.strftime("%Y-%m"))
|
||||
if cal:
|
||||
calendar += cal
|
||||
calendar = list(filter(lambda x: x <= pd.Timestamp.now(), calendar))
|
||||
else:
|
||||
end_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||
calendar = _get_calendar(end_date=end_date)
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
_CALENDAR_MAP[bench_code] = calendar
|
||||
logger.info(f"end of get calendar list: {bench_code}.")
|
||||
return calendar
|
||||
@@ -276,7 +280,7 @@ def get_hs_stock_symbols() -> list:
|
||||
symbol_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if symbol_cache_path.exists():
|
||||
with symbol_cache_path.open("rb") as fp:
|
||||
cache_symbols = restricted_pickle_load(fp)
|
||||
cache_symbols = pickle.load(fp)
|
||||
symbols |= cache_symbols
|
||||
with symbol_cache_path.open("wb") as fp:
|
||||
pickle.dump(symbols, fp)
|
||||
@@ -293,14 +297,20 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
import akshare as ak # pylint: disable=C0415
|
||||
|
||||
global _US_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
df = ak.get_us_stock_name()
|
||||
_symbols = df["symbol"].to_list()
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -613,6 +613,10 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
|
||||
@@ -9,5 +9,4 @@ yahooquery
|
||||
joblib
|
||||
beautifulsoup4
|
||||
bs4
|
||||
soupsieve
|
||||
akshare
|
||||
soupsieve
|
||||
@@ -4,5 +4,6 @@
|
||||
import fire
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(GetData)
|
||||
|
||||
11
setup.py
11
setup.py
@@ -2,11 +2,22 @@ import os
|
||||
|
||||
import numpy
|
||||
from setuptools import Extension, setup
|
||||
from setuptools_scm import get_version
|
||||
|
||||
|
||||
def read(rel_path: str) -> str:
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
with open(os.path.join(here, rel_path), encoding="utf-8") as fp:
|
||||
return fp.read()
|
||||
|
||||
|
||||
NUMPY_INCLUDE = numpy.get_include()
|
||||
|
||||
|
||||
VERSION = get_version(root=".", relative_to=__file__)
|
||||
|
||||
setup(
|
||||
version=VERSION,
|
||||
ext_modules=[
|
||||
Extension(
|
||||
"qlib.data._libs.rolling",
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from qlib.contrib.strategy.cost_control import SoftTopkStrategy
|
||||
|
||||
|
||||
class MockPosition:
|
||||
def __init__(self, weights):
|
||||
self.weights = weights
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=True):
|
||||
return self.weights
|
||||
|
||||
|
||||
def test_soft_topk_logic():
|
||||
# Initial: A=0.8, B=0.2 (Total=1.0). Target Risk=0.95.
|
||||
# Scores: A and B are low, C and D are topk.
|
||||
scores = pd.Series({"C": 0.9, "D": 0.8, "A": 0.1, "B": 0.1})
|
||||
current_pos = MockPosition({"A": 0.8, "B": 0.2})
|
||||
|
||||
topk = 2
|
||||
risk_degree = 0.95
|
||||
impact_limit = 0.1 # Max change per step
|
||||
|
||||
def create_test_strategy(impact_limit_value):
|
||||
strat = SoftTopkStrategy.__new__(SoftTopkStrategy)
|
||||
strat.topk = topk
|
||||
strat.risk_degree = risk_degree
|
||||
strat.trade_impact_limit = impact_limit_value
|
||||
return strat
|
||||
|
||||
# 1. With impact limit: Expect deterministic sell and limited buy
|
||||
strat_i = create_test_strategy(impact_limit)
|
||||
res_i = strat_i.generate_target_weight_position(scores, current_pos, None, None)
|
||||
|
||||
# A should be exactly 0.8 - 0.1 = 0.7
|
||||
assert abs(res_i["A"] - 0.7) < 1e-8
|
||||
# B should be exactly 0.2 - 0.1 = 0.1
|
||||
assert abs(res_i["B"] - 0.1) < 1e-8
|
||||
# Total sells = 0.2 released. New budget = 0.2 + (0.95 - 1.0) = 0.15.
|
||||
# C and D share 0.15 -> 0.075 each.
|
||||
assert abs(res_i["C"] - 0.075) < 1e-8
|
||||
assert abs(res_i["D"] - 0.075) < 1e-8
|
||||
|
||||
# 2. Without impact limit: Expect full liquidation and full target fill
|
||||
strat_c = create_test_strategy(1.0)
|
||||
res_c = strat_c.generate_target_weight_position(scores, current_pos, None, None)
|
||||
|
||||
# A, B not in topk -> Liquidated
|
||||
assert "A" not in res_c and "B" not in res_c
|
||||
# C, D should reach ideal_per_stock (0.95/2 = 0.475)
|
||||
assert abs(res_c["C"] - 0.475) < 1e-8
|
||||
assert abs(res_c["D"] - 0.475) < 1e-8
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -1,38 +0,0 @@
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from qlib.contrib.strategy.cost_control import SoftTopkStrategy
|
||||
|
||||
|
||||
class MockPosition:
|
||||
def __init__(self, weights):
|
||||
self.weights = weights
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=True):
|
||||
return self.weights
|
||||
|
||||
|
||||
def create_test_strategy(topk, risk_degree, impact_limit):
|
||||
strat = SoftTopkStrategy.__new__(SoftTopkStrategy)
|
||||
strat.topk = topk
|
||||
strat.risk_degree = risk_degree
|
||||
strat.trade_impact_limit = impact_limit
|
||||
return strat
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("impact_limit", "expected_fill"),
|
||||
[
|
||||
(0.1, 0.1),
|
||||
(1.0, 0.475),
|
||||
],
|
||||
)
|
||||
def test_soft_topk_cold_start_impact_limit(impact_limit, expected_fill):
|
||||
scores = pd.Series({"C": 0.9, "D": 0.8, "A": 0.1, "B": 0.1})
|
||||
current_pos = MockPosition({})
|
||||
|
||||
strat = create_test_strategy(topk=2, risk_degree=0.95, impact_limit=impact_limit)
|
||||
res = strat.generate_target_weight_position(scores, current_pos, None, None)
|
||||
|
||||
assert abs(res["C"] - expected_fill) < 1e-8
|
||||
assert abs(res["D"] - expected_fill) < 1e-8
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.utils.pickle_utils import restricted_pickle_load
|
||||
|
||||
|
||||
class HandlerTests(TestAutoData):
|
||||
@@ -23,7 +23,7 @@ class HandlerTests(TestAutoData):
|
||||
dh.to_pickle(fname, dump_all=True)
|
||||
|
||||
with open(fname, "rb") as f:
|
||||
dh_d = restricted_pickle_load(f)
|
||||
dh_d = pickle.load(f)
|
||||
|
||||
self.assertTrue(dh_d._data.equals(df))
|
||||
self.assertTrue(dh_d._infer is dh_d._data)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from collections import Counter
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from tianshou.data import Batch, Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
@@ -17,6 +17,7 @@ from qlib.rl.utils.finite_env import (
|
||||
generate_nan_observation,
|
||||
)
|
||||
|
||||
|
||||
_test_space = gym.spaces.Dict(
|
||||
{
|
||||
"sensors": gym.spaces.Dict(
|
||||
|
||||
@@ -7,10 +7,10 @@ import logging
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from tianshou.data import Collector, Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym import spaces
|
||||
from gymnasium import spaces
|
||||
from tianshou.policy import PPOPolicy
|
||||
|
||||
from qlib.config import C
|
||||
|
||||
@@ -13,6 +13,7 @@ from qlib.workflow import R
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.config import GBDT_MODEL, get_dataset_config, CSI300_MARKET
|
||||
|
||||
|
||||
CSI300_GBDT_TASK = {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(
|
||||
|
||||
@@ -16,6 +16,7 @@ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
from dump_bin import DumpDataAll, DumpDataFix
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).parent.joinpath("test_dump_data")
|
||||
SOURCE_DIR = DATA_DIR.joinpath("source")
|
||||
SOURCE_DIR.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -19,6 +19,7 @@ from dump_pit import DumpPitData
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
|
||||
from collector import Run
|
||||
|
||||
|
||||
pd.set_option("display.width", 1000)
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user