mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
30 Commits
optimize_w
...
fix_docume
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5fafba36f2 | ||
|
|
8a087d0db9 | ||
|
|
2ae4be426a | ||
|
|
6ed83f7c04 | ||
|
|
917e3a725e | ||
|
|
b1e0e77c97 | ||
|
|
ea245f5435 | ||
|
|
3779b5186a | ||
|
|
194284b1ac | ||
|
|
1bb8f2fa23 | ||
|
|
39f88daaa7 | ||
|
|
98f569eed2 | ||
|
|
ceff886f49 | ||
|
|
15b64768e2 | ||
|
|
8bf2678676 | ||
|
|
fb80e318e2 | ||
|
|
ecbeeafdc1 | ||
|
|
69e28ceab8 | ||
|
|
4c30e5827b | ||
|
|
5387ea5c1f | ||
|
|
05d67b3828 | ||
|
|
38edac5069 | ||
|
|
b4b7a2fdd4 | ||
|
|
480f233e3f | ||
|
|
953621ac7e | ||
|
|
87a026fef3 | ||
|
|
8676303077 | ||
|
|
1a32ba1806 | ||
|
|
842b8e8563 | ||
|
|
7d7e96a655 |
31
.github/workflows/python-publish.yml
vendored
31
.github/workflows/python-publish.yml
vendored
@@ -19,7 +19,24 @@ jobs:
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python
|
||||
# This is because on macos systems you can install pyqlib using
|
||||
# `pip install pyqlib` installs, it does not recognize the
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_11_0_x86_64.whl`.
|
||||
# So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib-<veresion>-cp38-cp37m
|
||||
# `pyqlib-<version>-cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib-<veresion>-cp38-cp37m-macosx_10_15_x86_64.whl`.
|
||||
# Python 3.7.16, 3.8.16 can build macosx_10_15. But Python 3.7.17, 3.8.17 can build macosx_11_0
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.7'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os == 'macos-11' && matrix.python-version == '3.8'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: "3.8.16"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: matrix.os != 'macos-11'
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -27,15 +44,15 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install setuptools wheel twine
|
||||
- name: Build wheel on Windows
|
||||
- name: Build wheel on ${{ matrix.os }}
|
||||
run: |
|
||||
pip install numpy
|
||||
pip install cython
|
||||
python setup.py bdist_wheel
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/*
|
||||
|
||||
@@ -55,10 +72,10 @@ jobs:
|
||||
python-version: 3.7
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install twine
|
||||
pip install twine
|
||||
- name: Build and publish
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
TWINE_USERNAME: __token__
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
twine upload dist/pyqlib-*-manylinux*.whl
|
||||
|
||||
6
.github/workflows/release-drafter.yml
vendored
6
.github/workflows/release-drafter.yml
vendored
@@ -6,8 +6,14 @@ on:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
update_release_draft:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Drafts your next Release notes as Pull Requests are merged into "master"
|
||||
|
||||
9
.github/workflows/test_qlib_from_pip.yml
vendored
9
.github/workflows/test_qlib_from_pip.yml
vendored
@@ -8,13 +8,15 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: ${{ false }} # FIXME: temporarily disable... Due to we are rushing a feature
|
||||
timeout-minutes: 120
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -44,9 +46,6 @@ jobs:
|
||||
- name: Qlib installation test
|
||||
run: |
|
||||
python -m pip install pyqlib
|
||||
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
|
||||
# and this line of code will be removed when the next version of qlib is released.
|
||||
python -m pip install "numpy<1.23"
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
|
||||
10
.github/workflows/test_qlib_from_source.yml
vendored
10
.github/workflows/test_qlib_from_source.yml
vendored
@@ -14,7 +14,10 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -38,10 +41,8 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -104,6 +105,7 @@ jobs:
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
|
||||
@@ -14,7 +14,10 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
|
||||
# Since macos-latest changed from 12.7.4 to 14.4.1,
|
||||
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
|
||||
# so we limit the macos version to macos-12.
|
||||
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
@@ -38,10 +41,8 @@ jobs:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install --upgrade pip
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -22,10 +22,6 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/finco/prompt_cache.json
|
||||
qlib/finco/finco_workspace/
|
||||
qlib/finco/knowledge/*/knowledge.pkl
|
||||
qlib/finco/knowledge/*/storage.yml
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
@@ -52,4 +48,4 @@ tags
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
.idea/
|
||||
@@ -5,6 +5,12 @@
|
||||
# Required
|
||||
version: 2
|
||||
|
||||
# Set the version of Python and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.7"
|
||||
|
||||
# Build documentation in the docs/ directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
@@ -14,7 +20,6 @@ formats: all
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
python:
|
||||
version: 3.7
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: pip
|
||||
@@ -139,7 +139,7 @@ This table demonstrates the supported Python version of `Qlib`:
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. **Conda** is suggested for managing your Python environment.
|
||||
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
|
||||
@@ -172,6 +172,8 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
|
||||
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
|
||||
|
||||
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
|
||||
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
|
||||
@@ -321,7 +323,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
# Main Challenges & Solutions in Quant Research
|
||||
Quant investment is an very unique scenario with lots of key challenges to be solved.
|
||||
Quant investment is a very unique scenario with lots of key challenges to be solved.
|
||||
Currently, Qlib provides some solutions for several of them.
|
||||
|
||||
## Forecasting: Finding Valuable Signals/Patterns
|
||||
@@ -360,7 +362,7 @@ Here is a list of models built on `Qlib`.
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
|
||||
The performance of each model on the `Alpha158` and `Alpha360` datasets can be found [here](examples/benchmarks/README.md).
|
||||
|
||||
### Run a single model
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
@@ -52,7 +52,7 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
Qlib Format Dataset
|
||||
-------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
|
||||
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
The price volume data look different from the actual dealing price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
|
||||
|
||||
Here are some discussions about the price adjusting of Qlib.
|
||||
@@ -140,12 +140,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
+-----------+-------+
|
||||
| symbol | close |
|
||||
+===========+=======+
|
||||
| SH600000 | 120 |
|
||||
+-----------+-------+
|
||||
|
||||
symbol,close
|
||||
SH600000,120
|
||||
|
||||
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -153,11 +154,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,date,close,open,volume
|
||||
SH600000,2020-11-01,120,121,12300000
|
||||
SH600000,2020-11-02,123,120,12300000
|
||||
+---------+------------+-------+------+----------+
|
||||
| symbol | date | close | open | volume |
|
||||
+=========+============+=======+======+==========+
|
||||
| SH600000| 2020-11-01 | 120 | 121 | 12300000 |
|
||||
+---------+------------+-------+------+----------+
|
||||
| SH600000| 2020-11-02 | 123 | 120 | 12300000 |
|
||||
+---------+------------+-------+------+----------+
|
||||
|
||||
|
||||
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
|
||||
|
||||
@@ -36,7 +36,7 @@ Name Description
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are learnable. They are learned
|
||||
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are trainable. They are trained
|
||||
based on the `Learning Framework` layer and then applied to multiple scenarios
|
||||
in `Workflow` layer. The supported learning paradigms can be categorized into
|
||||
reinforcement learning and supervised learning. The learning framework
|
||||
@@ -51,7 +51,7 @@ Name Description
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders)
|
||||
If RL-based Strategies are adopted, the `Policy` is learned in a end-to-end way,
|
||||
the trading deicsions are generated directly.
|
||||
the trading decisions are generated directly.
|
||||
Decisions will be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Strategy`
|
||||
and `Executor` (e.g. an *order executor trading strategy and intraday order executor*
|
||||
|
||||
@@ -16,7 +16,7 @@ This ``Quick Start`` guide tries to demonstrate
|
||||
Installation
|
||||
============
|
||||
|
||||
Users can easily intsall ``Qlib`` according to the following steps:
|
||||
Users can easily install ``Qlib`` according to the following steps:
|
||||
|
||||
- Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
|
||||
|
||||
@@ -5,3 +5,4 @@ scipy
|
||||
scikit-learn
|
||||
pandas
|
||||
tianshou
|
||||
sphinx_rtd_theme
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
experiment_name: finCo
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -10,7 +9,6 @@ data_handler_config: &data_handler_config
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
label: ["Ref($close, -21) / Ref($close, -1) - 1"]
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
@@ -29,7 +27,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal: <PRED>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -48,8 +48,7 @@ task:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ridge
|
||||
alpha: 0.05
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
@@ -73,7 +72,7 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
- class: MultiPassPortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -136,7 +136,7 @@ If you want to contribute your new models, you can follow the steps below.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -324,7 +324,6 @@ class TRAModel(Model):
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
|
||||
"""LSTM Model
|
||||
|
||||
Args:
|
||||
@@ -414,7 +413,6 @@ class PositionalEncoding(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
@@ -475,7 +473,6 @@ class Transformer(nn.Module):
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction errors & latent representation as inputs,
|
||||
|
||||
@@ -27,13 +27,11 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
|
||||
2. Please follow following steps to download example data
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
|
||||
tar xf highfreq_orderboook_example_data.tar.bz2
|
||||
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
|
||||
```
|
||||
|
||||
3. Please import the example data to your mongo db
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
python create_dataset.py initialize_library # Initialization Libraries
|
||||
python create_dataset.py import_data # Initialization Libraries
|
||||
```
|
||||
@@ -42,7 +40,6 @@ python create_dataset.py import_data # Initialization Libraries
|
||||
|
||||
After importing these data, you run `example.py` to create some high-frequency features.
|
||||
```bash
|
||||
cd examples/orderbook_data/
|
||||
pytest -s --disable-warnings example.py # If you want run all examples
|
||||
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
|
||||
```
|
||||
|
||||
2
pyproject.toml
Normal file
2
pyproject.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "numpy", "Cython"]
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.2.99"
|
||||
__version__ = "0.9.4.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -162,13 +162,15 @@ def create_account_instance(
|
||||
init_cash=init_cash,
|
||||
position_dict=position_dict,
|
||||
pos_type=pos_type,
|
||||
benchmark_config={}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
benchmark_config=(
|
||||
{}
|
||||
if benchmark is None
|
||||
else {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -622,9 +622,11 @@ class Indicator:
|
||||
print(
|
||||
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq,
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
(
|
||||
trade_start_time
|
||||
if isinstance(trade_start_time, str)
|
||||
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
),
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
|
||||
@@ -486,8 +486,5 @@ class QlibConfig(Config):
|
||||
return self._registered
|
||||
|
||||
|
||||
DEFAULT_QLIB_DOT_PATH = Path("~/.qlib/").expanduser()
|
||||
|
||||
|
||||
# global config
|
||||
C = QlibConfig(_default_config)
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
logger = get_module_logger("analysis", logging.INFO)
|
||||
|
||||
|
||||
class AnalyzerTemp:
|
||||
def __init__(self, recorder, output_dir=None, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.output_dir = Path(output_dir) if output_dir else "./"
|
||||
|
||||
def load(self, name: str):
|
||||
"""
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
return self.recorder.load_object(name)
|
||||
|
||||
def analyse(self, **kwargs):
|
||||
"""
|
||||
Analyse data index, distribution .etc
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
||||
Return
|
||||
------
|
||||
The handled data.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class HFAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "HFAnalyzerTable.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self):
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Average Return": long_short_r.mean(),
|
||||
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
|
||||
}
|
||||
)
|
||||
|
||||
table = [[k, v] for (k, v) in metrics.items()]
|
||||
plt.table(cellText=table, loc="center")
|
||||
plt.axis("off")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
|
||||
plt.clf()
|
||||
|
||||
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
|
||||
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
|
||||
plt.title("HFAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
|
||||
return "HFAnalyzer.jpeg"
|
||||
|
||||
|
||||
class SignalAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "signalAnalysis.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self, dataset=None, **kwargs):
|
||||
label = self.load("label.pkl")
|
||||
|
||||
plt.hist(label)
|
||||
plt.title("SignalAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
|
||||
|
||||
return "signalAnalysis.jpeg"
|
||||
@@ -1,8 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from qlib.utils.data import update_config
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -59,13 +57,12 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
_data_loader = {
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -77,14 +74,12 @@ class Alpha360(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=_data_loader,
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
@@ -158,13 +153,12 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
_data_loader = {
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -176,13 +170,11 @@ class Alpha158(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=_data_loader,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
|
||||
@@ -3,6 +3,7 @@ Here is a batch of evaluation functions.
|
||||
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
|
||||
@@ -511,7 +511,6 @@ class TRAModel(Model):
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
|
||||
"""RNN Model
|
||||
|
||||
Args:
|
||||
@@ -601,7 +600,6 @@ class PositionalEncoding(nn.Module):
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
"""Transformer Model
|
||||
|
||||
Args:
|
||||
@@ -649,7 +647,6 @@ class Transformer(nn.Module):
|
||||
|
||||
|
||||
class TRA(nn.Module):
|
||||
|
||||
"""Temporal Routing Adaptor (TRA)
|
||||
|
||||
TRA takes historical prediction errors & latent representation as inputs,
|
||||
|
||||
@@ -373,7 +373,6 @@ class WeightStrategyBase(BaseSignalStrategy):
|
||||
|
||||
|
||||
class EnhancedIndexingStrategy(WeightStrategyBase):
|
||||
|
||||
"""Enhanced Indexing Strategy
|
||||
|
||||
Enhanced indexing combines the arts of active management and passive management,
|
||||
|
||||
@@ -536,7 +536,6 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
if len(fields) == 0:
|
||||
raise ValueError("fields cannot be empty")
|
||||
fields = fields.copy()
|
||||
column_names = [str(f) for f in fields]
|
||||
return column_names
|
||||
|
||||
|
||||
@@ -318,9 +318,13 @@ class CSZScoreNorm(Processor):
|
||||
# try not modify original dataframe
|
||||
if not isinstance(self.fields_group, list):
|
||||
self.fields_group = [self.fields_group]
|
||||
for g in self.fields_group:
|
||||
cols = get_group_columns(df, g)
|
||||
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
|
||||
# depress warning by references:
|
||||
# https://stackoverflow.com/questions/20625582/how-to-deal-with-settingwithcopywarning-in-pandas
|
||||
# https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html#getting-and-setting-options
|
||||
with pd.option_context("mode.chained_assignment", None):
|
||||
for g in self.fields_group:
|
||||
cols = get_group_columns(df, g)
|
||||
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
|
||||
return df
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
|
||||
OPENAI_API_KEY=your_api_key
|
||||
|
||||
# USE_AZURE=True
|
||||
# AZURE_API_BASE=your_api_base
|
||||
# AZURE_API_VERSION=your_api_version
|
||||
|
||||
# use gpt-4 means more token but more wait time
|
||||
# MODEL=gpt-4
|
||||
# MAX_TOKENS=1600
|
||||
# MAX_RETRY=1000
|
||||
|
||||
|
||||
MAX_TOKENS=1600
|
||||
MAX_RETRY=120
|
||||
|
||||
CONTINOUS_MODE=True
|
||||
DEBUG_MODE=True
|
||||
|
||||
# TEMPERATURE=
|
||||
@@ -1,22 +0,0 @@
|
||||
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
|
||||
|
||||
## Installation
|
||||
|
||||
- To run this module, you need to first install Qlib following the instruction in [install-from-source](/README.md#install-from-source) or follow:
|
||||
|
||||
```python
|
||||
python -m pip install git+https://github.com/microsoft/qlib.git@finco
|
||||
```
|
||||
|
||||
- then you need to install other dependencies of finco:
|
||||
```python
|
||||
python -m pip install pydantic openai python-dotenv
|
||||
```
|
||||
|
||||
## Quick run
|
||||
|
||||
To run this module, you can start the workflow easily with one command:
|
||||
|
||||
```sh
|
||||
cd qlib/finco; python cli.py "your prompt"
|
||||
```
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_finco_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
@@ -1,15 +0,0 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
wm = WorkflowManager()
|
||||
wm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
@@ -1,15 +0,0 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import LearnManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
lm = LearnManager()
|
||||
lm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
@@ -1,32 +0,0 @@
|
||||
# TODO: use pydantic for other modules in Qlib
|
||||
# from pydantic_settings import BaseSettings
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class Config(SingletonBaseClass):
|
||||
"""
|
||||
This config is for fast demo purpose.
|
||||
Please use BaseSettings insetead in the future
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
|
||||
self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
|
||||
|
||||
self.openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.azure_api_base = os.getenv("AZURE_API_BASE")
|
||||
self.azure_api_version = os.getenv("AZURE_API_VERSION")
|
||||
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
|
||||
|
||||
self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
|
||||
|
||||
self.continuous_mode = (
|
||||
os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
|
||||
)
|
||||
self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
|
||||
self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
|
||||
self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2
|
||||
@@ -1,97 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
import copy
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from qlib.finco.utils import similarity
|
||||
|
||||
|
||||
@dataclass
|
||||
class Design:
|
||||
plan: str
|
||||
classes: str
|
||||
decision: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exp:
|
||||
"""Experiment"""
|
||||
|
||||
# compoments
|
||||
dataset: Optional[Design] = None
|
||||
datahandler: Optional[Design] = None
|
||||
model: Optional[Design] = None
|
||||
record: Optional[Design] = None
|
||||
strategy: Optional[Design] = None
|
||||
backtest: Optional[Design] = None
|
||||
|
||||
# basic
|
||||
template: Optional[Path] = None
|
||||
|
||||
# rolling strategy. None indicates no rolling
|
||||
rolling: Optional[Literal["base", "ddgda"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructContext:
|
||||
"""Part of the context have clear meaning and structure, so they will be saved here and can be easily retrieved and understood"""
|
||||
|
||||
# TODO: move more content in WorkflowContextManager.context to here
|
||||
workspace: Path
|
||||
exp_list: List[Exp] = field(default_factory=list) # the planned experiments
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
"""Context Manager stores the context of the workflow"""
|
||||
|
||||
"""All context are key value pairs which saves the input, output and status of the whole workflow"""
|
||||
|
||||
def __init__(self, workspace: Path) -> None:
|
||||
self.context = {}
|
||||
self.logger = FinCoLog()
|
||||
# this context is public
|
||||
self.struct_context = StructContext(workspace) # TODO: move more content in context to here
|
||||
self.set_context("workspace", workspace) # TODO: remove me
|
||||
|
||||
def set_context(self, key, value):
|
||||
if key in self.context:
|
||||
self.logger.warning("The key already exists in the context, the value will be overwritten")
|
||||
self.context[key] = value
|
||||
|
||||
def get_context(self, key):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
return None
|
||||
return self.context[key]
|
||||
|
||||
def update_context(self, key, new_value):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
self.context.update({key: new_value})
|
||||
|
||||
def get_all_context(self):
|
||||
"""return a deep copy of the context"""
|
||||
"""TODO: do we need to return a deep copy?"""
|
||||
return copy.deepcopy(self.context)
|
||||
|
||||
def retrieve(self, query: str) -> dict:
|
||||
if query in self.context.keys():
|
||||
return {query: self.context.get(query)}
|
||||
|
||||
# Note: retrieve information from context by string similarity maybe abandon in future
|
||||
scores = {}
|
||||
for k, v in self.context.items():
|
||||
scores.update({k: max(similarity(query, k), similarity(query, v))})
|
||||
max_score_key = max(scores, key=scores.get)
|
||||
return {max_score_key: self.context.get(max_score_key)}
|
||||
|
||||
def clear(self, reserve: list = None):
|
||||
if reserve is None:
|
||||
reserve = []
|
||||
|
||||
_context = {k: self.get_context(k) for k in reserve}
|
||||
self.context = _context
|
||||
@@ -1,539 +0,0 @@
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
from typing import List, Union
|
||||
import pickle
|
||||
import yaml
|
||||
|
||||
from qlib.workflow import R
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.utils import similarity, random_string, SingletonBaseClass
|
||||
|
||||
logger = FinCoLog()
|
||||
|
||||
|
||||
class Storage:
|
||||
"""
|
||||
This class is responsible for storage and loading of Knowledge related data.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[str, Path], name: str = None):
|
||||
self.path = path if isinstance(path, Path) else Path(path)
|
||||
self.name = name if name else self.path.name
|
||||
self.source = None
|
||||
|
||||
# todo: get document by key
|
||||
self.documents = []
|
||||
|
||||
def add(self, documents: List):
|
||||
self.documents.extend(documents)
|
||||
self.save()
|
||||
|
||||
def load(self, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
def save(self, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `save` method.")
|
||||
|
||||
|
||||
class PickleStorage(Storage):
|
||||
"""
|
||||
This class is responsible for storage and loading of Knowledge related data in pickle format.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[str, Path]):
|
||||
super().__init__(path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]):
|
||||
"""use pickle as the default load method"""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""use pickle as the default save method"""
|
||||
Path.mkdir(self.path.parent, exist_ok=True)
|
||||
with open(self.path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
|
||||
class YamlStorage(Storage):
|
||||
"""
|
||||
This class is responsible for storage and loading of Knowledge related data in yaml format.
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_NAME = "storage.yml"
|
||||
|
||||
def __init__(self, path: Union[str, Path]):
|
||||
super().__init__(path)
|
||||
assert self.path.name, "Yaml storage should specify file name."
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"""load data from yaml format file"""
|
||||
try:
|
||||
self.documents = yaml.safe_load(self.path.open())
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"YamlStorage: file {self.path} doesn't exist.")
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""use pickle as the default save method"""
|
||||
Path.mkdir(self.path.parent, exist_ok=True, parents=True)
|
||||
with open(self.path, 'w') as f:
|
||||
yaml.dump(self.documents, f)
|
||||
|
||||
|
||||
class ExperimentStorage(Storage):
|
||||
"""
|
||||
This class is responsible for storage and loading of mlflow related data.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, exp_name, path=None):
|
||||
super().__init__(path=path)
|
||||
self.exp_name = exp_name
|
||||
self.exp = None
|
||||
self.recs = []
|
||||
self.docs = []
|
||||
|
||||
def load(self, exp_name, rec_id=None):
|
||||
recs = []
|
||||
self.exp = R.get_exp(experiment_name=exp_name)
|
||||
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
|
||||
if rec_id is not None and r.id != rec_id:
|
||||
continue
|
||||
recs.append(r)
|
||||
self.recs.extend(recs)
|
||||
|
||||
|
||||
class Knowledge:
|
||||
"""
|
||||
Use to handle knowledge in finCo such as experiment and outside domain information
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[Storage], Storage], name: str = None):
|
||||
self.name = name if name else random_string()
|
||||
self.workdir = Path.cwd().joinpath("knowledge")
|
||||
self.storages = [storages] if isinstance(storages, Storage) else storages
|
||||
self.knowledge = []
|
||||
|
||||
def get_storage(self, name: str):
|
||||
"""
|
||||
return first storage matched given name, else return None
|
||||
"""
|
||||
for storage in self.storages:
|
||||
if storage.name == name:
|
||||
return storage
|
||||
return None
|
||||
|
||||
def summarize(self, **kwargs):
|
||||
"""
|
||||
summarize storage data to knowledge, default knowledge is storage.documents
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
knowledge = []
|
||||
for storage in self.storages:
|
||||
knowledge.extend(storage.documents)
|
||||
self.knowledge = knowledge
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]):
|
||||
"""
|
||||
Load knowledge in memory
|
||||
use pickle as the default file type
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
""""""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def brief(self, **kwargs):
|
||||
"""
|
||||
Return a brief summary of knowledge
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""save knowledge persistently"""
|
||||
# todo: storages save index only
|
||||
Path.mkdir(self.workdir.joinpath(self.name), exist_ok=True)
|
||||
with open(self.workdir.joinpath(self.name).joinpath("knowledge.pkl"), "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
|
||||
class ExperimentKnowledge(Knowledge):
|
||||
"""
|
||||
Handle knowledge from experiments
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[ExperimentStorage], ExperimentStorage]):
|
||||
super().__init__(storages=storages)
|
||||
self.storage = storages
|
||||
|
||||
def brief(self):
|
||||
docs = []
|
||||
for recorder in self.storage.recs:
|
||||
docs.append(
|
||||
{
|
||||
"exp_name": self.storage.exp.name,
|
||||
"record_info": recorder.info,
|
||||
"config": recorder.load_object("config"),
|
||||
"context_summary": recorder.load_object("context_summary"),
|
||||
}
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
class PracticeKnowledge(Knowledge):
|
||||
"""
|
||||
some template sentence for now
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
|
||||
super().__init__(storages=storages, name="practice")
|
||||
|
||||
self.summarize()
|
||||
|
||||
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
|
||||
s = "\n".join(docs)
|
||||
logger.info(f'Add to Practice Knowledge:\n {s}')
|
||||
storage = self.get_storage(storage_name)
|
||||
if storage is None:
|
||||
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
|
||||
storage.add(documents=docs)
|
||||
self.storages.append(storage)
|
||||
else:
|
||||
storage.add(documents=docs)
|
||||
|
||||
self.summarize()
|
||||
self.save()
|
||||
|
||||
|
||||
class FinanceKnowledge(Knowledge):
|
||||
"""
|
||||
Knowledge from articles
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
|
||||
super().__init__(storages=storages, name="finance")
|
||||
|
||||
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
|
||||
if len(storage.documents) == 0:
|
||||
docs = self.read_files_in_directory(self.workdir.joinpath(self.name))
|
||||
self.add(docs)
|
||||
self.summarize()
|
||||
|
||||
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
|
||||
storage = self.get_storage(storage_name)
|
||||
if storage is None:
|
||||
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
|
||||
storage.add(documents=docs)
|
||||
self.storages.append(storage)
|
||||
else:
|
||||
storage.add(documents=docs)
|
||||
|
||||
self.summarize()
|
||||
self.save()
|
||||
|
||||
@staticmethod
|
||||
def read_files_in_directory(directory) -> List:
|
||||
"""
|
||||
read all .txt files under directory
|
||||
"""
|
||||
# todo: split article in trunks
|
||||
file_contents = []
|
||||
for file_path in Path(directory).rglob("*.txt"):
|
||||
if file_path.is_file():
|
||||
file_content = file_path.read_text(encoding="utf-8")
|
||||
file_contents.append(file_content)
|
||||
return file_contents
|
||||
|
||||
|
||||
class ExecuteKnowledge(Knowledge):
|
||||
"""
|
||||
Config and associate execution result(pass or error message). We can regard the example in prompt as pass execution
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
|
||||
super().__init__(storages=storages, name="execute")
|
||||
self.summarize()
|
||||
|
||||
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
|
||||
if len(storage.documents) == 0:
|
||||
docs = [{"content": "[Success]: XXXX, the results looks reasonable # Keywords: supervised learning, data"},
|
||||
{"content": "[Fail]: XXXX, it raise memory error due to YYYYY "
|
||||
"# Keywords: supervised learning, data"}]
|
||||
self.add(docs)
|
||||
self.summarize()
|
||||
|
||||
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
|
||||
storage = self.get_storage(storage_name)
|
||||
if storage is None:
|
||||
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
|
||||
storage.add(documents=docs)
|
||||
self.storages.append(storage)
|
||||
else:
|
||||
storage.add(documents=docs)
|
||||
|
||||
self.summarize()
|
||||
self.save()
|
||||
|
||||
|
||||
class InfrastructureKnowledge(Knowledge):
|
||||
"""
|
||||
Knowledge from sentences, docstring, and code
|
||||
"""
|
||||
|
||||
def __init__(self, storages: Union[List[YamlStorage], YamlStorage]):
|
||||
super().__init__(storages=storages, name="infrastructure")
|
||||
|
||||
storage = self.get_storage(YamlStorage.DEFAULT_NAME)
|
||||
if len(storage.documents) == 0:
|
||||
docs = self.get_functions_and_docstrings(Path(__file__).parent.parent.parent)
|
||||
docs.extend([{"docstring": "All the models can be import from `qlib.contrib.models` "
|
||||
"# Keywords: supervised learning"},
|
||||
{"docstring": "The API to run rolling models can be found in … #Keywords: control"},
|
||||
{"docstring": "Here are a list of Qlib’s available analyzers. #KEYWORDS: analysis"}])
|
||||
self.add(docs)
|
||||
self.summarize()
|
||||
|
||||
def add(self, docs: List, storage_name: str = YamlStorage.DEFAULT_NAME):
|
||||
storage = self.get_storage(storage_name)
|
||||
if storage is None:
|
||||
storage = YamlStorage(path=self.workdir.joinpath(self.name).joinpath(storage_name))
|
||||
storage.add(documents=docs)
|
||||
self.storages.append(storage)
|
||||
else:
|
||||
storage.add(documents=docs)
|
||||
|
||||
self.summarize()
|
||||
self.save()
|
||||
|
||||
def get_functions_and_docstrings(self, directory) -> List:
|
||||
"""
|
||||
get all method and docstring in .py files under directory
|
||||
|
||||
"""
|
||||
functions = []
|
||||
for py_file_path in Path(directory).rglob("*.py"):
|
||||
for _functions in self.get_functions_with_docstrings(py_file_path):
|
||||
functions.append(_functions)
|
||||
|
||||
return functions
|
||||
|
||||
@staticmethod
|
||||
def get_functions_with_docstrings(file_path):
|
||||
"""
|
||||
Extract method name and docstring using string matching method
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
functions = []
|
||||
current_func = None
|
||||
docstring = None
|
||||
for line in lines:
|
||||
if line.strip().startswith("def ") or line.strip().startswith("class "):
|
||||
func = line.strip().split(" ")[1].split("(")[0]
|
||||
if func.startswith("__"):
|
||||
continue
|
||||
if current_func is not None:
|
||||
docstring = docstring.replace('"""', "") if docstring else docstring
|
||||
functions.append({"function": current_func, "docstring": docstring})
|
||||
current_func = f"{file_path.name.split('.')[0]}.{func}"
|
||||
docstring = None
|
||||
elif current_func is not None and docstring is None and line.strip().startswith('"""'):
|
||||
docstring = line
|
||||
elif current_func is not None and docstring is not None:
|
||||
docstring += line.strip()
|
||||
if line.strip().endswith('"""'):
|
||||
docstring = docstring.replace('"""', "") if docstring else docstring
|
||||
functions.append({"function": current_func, "docstring": docstring})
|
||||
current_func = None
|
||||
docstring = None
|
||||
|
||||
return functions
|
||||
|
||||
|
||||
class Topic:
|
||||
def __init__(self, name: str, system: Template, user: Template):
|
||||
self.name = name
|
||||
self.system_prompt_template = system
|
||||
self.user_prompt_template = user
|
||||
self.docs = []
|
||||
self.knowledge = None
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def summarize(self, practice_knowlege, user_intention, target, diffrence, target_metrics):
|
||||
system_prompt = self.system_prompt_template.render(topic=self.name)
|
||||
user_prompt = self.user_prompt_template.render(
|
||||
experiment_1_info = practice_knowlege[0],
|
||||
experiment_2_info = practice_knowlege[1],
|
||||
user_intention=user_intention,
|
||||
target=target,
|
||||
diffrence=diffrence,
|
||||
target_metrics=target_metrics)
|
||||
response = APIBackend().build_messages_and_create_chat_completion(user_prompt=user_prompt, system_prompt=system_prompt)
|
||||
|
||||
self.knowledge = response
|
||||
self.docs = practice_knowlege
|
||||
self.logger.info(f"Summary of {self.name}:\n{self.knowledge}")
|
||||
|
||||
|
||||
class KnowledgeBase(SingletonBaseClass):
|
||||
"""
|
||||
Load knowledge, offer brief information of knowledge and common handle interfaces
|
||||
"""
|
||||
|
||||
KT_EXECUTE = "execute"
|
||||
KT_PRACTICE = "practice"
|
||||
KT_FINANCE = "finance"
|
||||
KT_INFRASTRUCTURE = "infrastructure"
|
||||
|
||||
def __init__(self, workdir=None):
|
||||
self.logger = FinCoLog()
|
||||
self.workdir = Path(workdir) if workdir else Path.cwd()
|
||||
|
||||
if not self.workdir.exists():
|
||||
self.logger.warning(f"{self.workdir} not exist, create empty directory.")
|
||||
Path.mkdir(self.workdir)
|
||||
|
||||
self.practice_knowledge = self.load_practice_knowledge(self.workdir)
|
||||
self.execute_knowledge = self.load_execute_knowledge(self.workdir)
|
||||
self.finance_knowledge = self.load_finance_knowledge(self.workdir)
|
||||
self.infrastructure_knowledge = self.load_infrastructure_knowledge(self.workdir)
|
||||
|
||||
def load_experiment_knowledge(self, path) -> List:
|
||||
# similar to practice knowledge, not use for now
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
knowledge = []
|
||||
path = path if path.name == "mlruns" else path.joinpath("mlruns")
|
||||
# todo: check the influence of set uri
|
||||
R.set_uri(path.as_uri())
|
||||
for exp_name in R.list_experiments():
|
||||
knowledge.append(ExperimentKnowledge(storages=ExperimentStorage(exp_name=exp_name)))
|
||||
|
||||
self.logger.plain_info(f"Load knowledge from: {path} finished.")
|
||||
return knowledge
|
||||
|
||||
def load_practice_knowledge(self, path: Path) -> PracticeKnowledge:
|
||||
self.practice_knowledge = PracticeKnowledge(
|
||||
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_PRACTICE}/{YamlStorage.DEFAULT_NAME}")))
|
||||
return self.practice_knowledge
|
||||
|
||||
def load_execute_knowledge(self, path: Path) -> ExecuteKnowledge:
|
||||
self.execute_knowledge = ExecuteKnowledge(
|
||||
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_EXECUTE}/{YamlStorage.DEFAULT_NAME}")))
|
||||
return self.execute_knowledge
|
||||
|
||||
def load_finance_knowledge(self, path: Path) -> FinanceKnowledge:
|
||||
self.finance_knowledge = FinanceKnowledge(
|
||||
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_FINANCE}/{YamlStorage.DEFAULT_NAME}")))
|
||||
return self.finance_knowledge
|
||||
|
||||
def load_infrastructure_knowledge(self, path: Path) -> InfrastructureKnowledge:
|
||||
self.infrastructure_knowledge = InfrastructureKnowledge(
|
||||
YamlStorage(path.joinpath(Path.cwd().joinpath("knowledge")/f"{self.KT_INFRASTRUCTURE}/{YamlStorage.DEFAULT_NAME}")))
|
||||
return self.infrastructure_knowledge
|
||||
|
||||
def get_knowledge(self, knowledge_type: str = None):
|
||||
if knowledge_type == self.KT_EXECUTE:
|
||||
knowledge = self.execute_knowledge.knowledge
|
||||
elif knowledge_type == self.KT_PRACTICE:
|
||||
knowledge = self.practice_knowledge.knowledge
|
||||
elif knowledge_type == self.KT_FINANCE:
|
||||
knowledge = self.finance_knowledge.knowledge
|
||||
elif knowledge_type == self.KT_INFRASTRUCTURE:
|
||||
knowledge = self.infrastructure_knowledge.knowledge
|
||||
else:
|
||||
knowledge = (
|
||||
self.execute_knowledge.knowledge
|
||||
+ self.practice_knowledge.knowledge
|
||||
+ self.finance_knowledge.knowledge
|
||||
+ self.infrastructure_knowledge.knowledge
|
||||
)
|
||||
return knowledge
|
||||
|
||||
def query(self, knowledge_type: str = None, content: str = None, n: int = 5):
|
||||
"""
|
||||
|
||||
@param knowledge_type: self.KT_EXECUTE, self.KT_PRACTICE or self.KT_FINANCE
|
||||
@param content: content to query KnowledgeBase
|
||||
@param n: top n knowledge to ask ChatGPT
|
||||
@return:
|
||||
"""
|
||||
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
|
||||
# literal search/semantic search
|
||||
|
||||
knowledge = self.get_knowledge(knowledge_type=knowledge_type)
|
||||
if len(knowledge) == 0 or knowledge_type == "infrastructure":
|
||||
return ""
|
||||
|
||||
if knowledge_type == "practice":
|
||||
knowledge = [line for line in knowledge if line.startswith("practice_knowledge on")]
|
||||
|
||||
scores = []
|
||||
for k in knowledge:
|
||||
scores.append(similarity(str(k), content))
|
||||
sorted_indexes = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
|
||||
similar_n_indexes = sorted_indexes[:n]
|
||||
similar_n_docs = "\n".join([knowledge[i] for i in similar_n_indexes])
|
||||
|
||||
user_prompt_template = Template(
|
||||
"""
|
||||
query: '{{query}}'
|
||||
paragraph:
|
||||
{{paragraph}}.
|
||||
"""
|
||||
)
|
||||
user_prompt = user_prompt_template.render(query=content, paragraph=similar_n_docs)
|
||||
system_prompt = """
|
||||
You are an assistant who find relevant sentences from a long paragraph to fit user's query sentence. Relevant means the sentence might provide userful information to explain user's query sentence. People after reading the relevant sentences might have a better understanding of the query sentence.
|
||||
|
||||
Please response no less than ten sentences, if paragraph is not enough, you can return less than ten. Don't pop out irrelevant sentences. Please list the sentences in a number index instead of a whole paragraph.
|
||||
|
||||
Example input:
|
||||
query: what is the best model for image classification?
|
||||
paragraph:
|
||||
Image classification is the process of identifying and categorizing objects within an image into different groups or classes.
|
||||
Machine learning is a type of artificial intelligence that enables computers to learn and make decisions without being explicitly programmed.
|
||||
The solar system is a collection of celestial bodies, including the Sun, planets, moons, and other objects, that orbit around the Sun due to its gravitational pull.
|
||||
A car is a wheeled vehicle, typically powered by an engine or electric motor, used for transportation of people and goods.
|
||||
ResNet, short for Residual Network, is a type of deep learning architecture designed to improve the accuracy and training speed of neural networks for image recognition tasks.
|
||||
|
||||
Example output:
|
||||
1. ResNet, short for Residual Network, is a type of deep learning architecture designed to improve the accuracy and training speed of neural networks for image recognition tasks.
|
||||
2. Image classification is the process of identifying and categorizing objects within an image into different groups or classes.
|
||||
3. Machine learning is a type of artificial intelligence that enables computers to learn and make decisions without being explicitly programmed.
|
||||
"""
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=user_prompt, system_prompt=system_prompt
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# perhaps init KnowledgeBase in other place
|
||||
KnowledgeBase(workdir=Path.cwd().joinpath('knowledge'))
|
||||
@@ -1,47 +0,0 @@
|
||||
Quantitative investment research, often referred to as "quant," is an investment approach that uses mathematical and statistical models to analyze financial data and identify investment opportunities. This method relies heavily on computer algorithms and advanced data analysis techniques to develop trading strategies and make investment decisions.
|
||||
|
||||
One of the key aspects of quant investment research is the development of predictive models to forecast asset prices, market movements, and other financial variables. These models are typically built using historical data and refined through rigorous testing and validation processes.
|
||||
|
||||
In quant investment research, various metrics are used to evaluate the performance of a model or strategy. Some common metrics include annual return, information coefficient, maximum drawdown, and cumulative sum (cumsum) return.
|
||||
|
||||
Annual return is a measure of an investment's performance over the course of a year and is expressed as a percentage. It is an important metric to consider but can be controversial as higher annual returns are often associated with higher risks.
|
||||
|
||||
Maximum drawdown is the largest peak-to-trough decline in an investment's value over a specified period. It is a measure of the strategy's risk and can be controversial since increasing annual return often leads to a more dynamic strategy with larger drawdowns.
|
||||
|
||||
Information coefficient (IC) is a measure of the relationship between predicted returns and actual returns. A higher IC indicates a stronger relationship and suggests a more effective predictive model.
|
||||
|
||||
Cumulative sum return is the total return generated by an investment over a given period. It is useful for evaluating the overall performance of a strategy and is particularly relevant when comparing multiple strategies over the same time frame.
|
||||
|
||||
Another important aspect of quant investment research is portfolio optimization, which involves determining the optimal allocation of assets to maximize returns while minimizing risk.
|
||||
|
||||
Quantitative researchers often use techniques such as factor analysis to identify underlying drivers of asset returns. This helps them to build more robust models and better understand the relationships between various financial variables.
|
||||
|
||||
Machine learning has become increasingly popular in quant investment research, as it offers new ways to identify patterns and relationships in large datasets. Techniques such as neural networks, decision trees, and clustering algorithms are commonly used in this field.
|
||||
|
||||
Backtesting is a critical step in the development of a quantitative investment strategy. It involves applying a model or algorithm to historical data to see how it would have performed under various market conditions.
|
||||
|
||||
Risk management is a crucial component of quant investment research. Quantitative researchers must carefully consider the potential risks associated with their models and strategies and take steps to mitigate these risks.
|
||||
|
||||
In recent years, there has been a growing interest in alternative data sources for quant investment research. These can include social media sentiment, satellite imagery, and other non-traditional data sources that may provide unique insights into market trends and investment opportunities.
|
||||
|
||||
Despite its reliance on data and mathematical models, quant investment research is not without its challenges and limitations. Some critics argue that it is inherently backward-looking, as it relies on historical data to make predictions about future market movements.
|
||||
|
||||
Furthermore, quant models may struggle to adapt to rapidly changing market conditions and may be susceptible to overfitting, which occurs when a model is tailored too closely to the training data and performs poorly on new, unseen data.
|
||||
|
||||
In conclusion, quantitative investment research is a growing field that combines advanced mathematical and statistical techniques with financial expertise to develop sophisticated investment strategies and models. While it offers significant potential for generating returns and managing risk, it also comes with its own set of challenges and limitations that must be carefully considered.
|
||||
|
||||
A rolling meta-controller is a valuable tool in quantitative investment research, providing several benefits such as increased adaptability, enhanced risk management, and improved performance. By continuously updating the parameters of underlying trading models based on recent data, this approach helps strategies stay relevant in dynamic market conditions and reduces the risk of overfitting. Additionally, a rolling meta-controller can manage a diverse portfolio of trading models, leading to a more robust and diversified investment strategy.
|
||||
|
||||
However, incorporating a rolling meta-controller also comes with potential challenges, such as increased complexity, higher computational requirements, and the risk of parameter instability. It is essential for quantitative researchers to carefully weigh these factors when deciding whether to include a rolling meta-controller in their trading strategies. By balancing the potential advantages and disadvantages, researchers can optimize their strategies to achieve their specific investment objectives and risk tolerances.
|
||||
|
||||
Meta controller like rolling or DDGDA is proved to be relevant to the maximum drawdown but effect is still unknown. Because seeing the latest data is always good to avoid some new patterns recently. A static model is not good at doing it.
|
||||
|
||||
If rolling has been proved in practice knowledge that it's useful in minimizing max drawdown or maximize dropdown, always try DDGDA in the following experiments to further minimize the max drawdown.
|
||||
|
||||
DDG-DA, which stands for Data Distribution Generation for Predictable Concept Drift Adaptation, is a research paper focused on addressing the challenge of concept drift in machine learning and data mining. Concept drift occurs when the statistical properties of a target variable change over time, causing the model's performance to degrade. This is a common issue in various domains, including finance, where market conditions and economic factors can change rapidly.
|
||||
|
||||
The DDG-DA paper proposes a framework for generating synthetic datasets that simulate concept drift in a controlled and predictable manner. By creating these datasets, researchers can better understand how concept drift affects the performance of their machine learning models and develop strategies for adapting to these changes.
|
||||
|
||||
The main idea behind DDG-DA is to create synthetic data distributions that mimic the underlying data generating process while controlling the extent of concept drift. This is achieved by using a combination of data transformation techniques, such as scaling, rotation, and translation of the original data distribution.
|
||||
|
||||
By generating synthetic datasets with controlled concept drift, researchers can evaluate and compare the performance of various adaptation techniques in a more systematic and controlled manner. This can lead to the development of more robust and adaptive machine learning models that can better handle changing data distributions, ultimately improving the performance of these models in real-world applications, such as finance and investment.
|
||||
@@ -1,139 +0,0 @@
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import openai
|
||||
import json
|
||||
import yaml
|
||||
from typing import Optional, Tuple, Union
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.config import DEFAULT_QLIB_DOT_PATH
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ConvManager:
|
||||
"""
|
||||
This is a conversation manager of LLM
|
||||
It is for convenience of exporting conversation for debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Union[Path, str] = DEFAULT_QLIB_DOT_PATH / "llm_conv", recent_n: int = 10) -> None:
|
||||
self.path = Path(path)
|
||||
self.path.mkdir(parents=True, exist_ok=True)
|
||||
self.recent_n = recent_n
|
||||
|
||||
def _rotate_files(self):
|
||||
pairs = []
|
||||
for f in self.path.glob("*.json"):
|
||||
m = re.match(r"(\d+).json", f.name)
|
||||
if m is not None:
|
||||
n = int(m.group(1))
|
||||
pairs.append((n, f))
|
||||
pass
|
||||
pairs.sort(key=lambda x: x[0])
|
||||
for n, f in pairs[: self.recent_n][::-1]:
|
||||
f.rename(self.path / f"{n+1}.json")
|
||||
|
||||
def append(self, conv: Tuple[list, str]):
|
||||
self._rotate_files()
|
||||
json.dump(conv, open(self.path / "0.json", "w"))
|
||||
# TODO: reseve line breaks to make it more convient to edit file directly.
|
||||
|
||||
|
||||
class APIBackend(SingletonBaseClass):
|
||||
def __init__(self):
|
||||
self.cfg = Config()
|
||||
openai.api_key = self.cfg.openai_api_key
|
||||
if self.cfg.use_azure:
|
||||
openai.api_type = "azure"
|
||||
openai.api_base = self.cfg.azure_api_base
|
||||
openai.api_version = self.cfg.azure_api_version
|
||||
self.use_azure = self.cfg.use_azure
|
||||
|
||||
self.debug_mode = False
|
||||
if self.cfg.debug_mode:
|
||||
self.debug_mode = True
|
||||
cwd = os.getcwd()
|
||||
self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
|
||||
self.cache = (
|
||||
json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
|
||||
)
|
||||
|
||||
def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
|
||||
"""build the messages to avoid implementing several redundant lines of code"""
|
||||
cfg = Config()
|
||||
# TODO: system prompt should always be provided. In development stage we can use default value
|
||||
if system_prompt is None:
|
||||
try:
|
||||
system_prompt = cfg.system_prompt
|
||||
except AttributeError:
|
||||
FinCoLog().warning("system_prompt is not set, using default value.")
|
||||
system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
]
|
||||
messages.extend(former_messages[-1 * cfg.max_past_message_include :])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
)
|
||||
fcl = FinCoLog()
|
||||
response = self.try_create_chat_completion(messages=messages, **kwargs)
|
||||
fcl.log_message(messages)
|
||||
fcl.log_response(response)
|
||||
if self.debug_mode:
|
||||
ConvManager().append((messages, response))
|
||||
return response
|
||||
|
||||
def try_create_chat_completion(self, max_retry=10, **kwargs):
|
||||
max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
|
||||
for i in range(max_retry):
|
||||
try:
|
||||
response = self.create_chat_completion(**kwargs)
|
||||
return response
|
||||
except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
|
||||
print(e)
|
||||
print(f"Retrying {i+1}th time...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages,
|
||||
model=None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
if self.debug_mode:
|
||||
key = json.dumps(messages)
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
|
||||
if temperature is None:
|
||||
temperature = self.cfg.temperature
|
||||
if max_tokens is None:
|
||||
max_tokens = self.cfg.max_tokens
|
||||
|
||||
if self.cfg.use_azure:
|
||||
response = openai.ChatCompletion.create(
|
||||
engine=self.cfg.model,
|
||||
messages=messages,
|
||||
max_tokens=self.cfg.max_tokens,
|
||||
)
|
||||
else:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.cfg.model,
|
||||
messages=messages,
|
||||
)
|
||||
resp = response.choices[0].message["content"]
|
||||
if self.debug_mode:
|
||||
self.cache[key] = resp
|
||||
json.dump(self.cache, open(self.cache_file_location, "w"))
|
||||
return resp
|
||||
@@ -1,139 +0,0 @@
|
||||
"""
|
||||
This module will base on Qlib's logger module and provides some interactive functions.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
|
||||
from typing import Dict, List
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class LogColors:
|
||||
"""
|
||||
ANSI color codes for use in console output.
|
||||
"""
|
||||
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
GRAY = "\033[90m"
|
||||
BLACK = "\033[30m"
|
||||
|
||||
BOLD = "\033[1m"
|
||||
ITALIC = "\033[3m"
|
||||
|
||||
END = "\033[0m"
|
||||
|
||||
@classmethod
|
||||
def get_all_colors(cls):
|
||||
names = dir(cls)
|
||||
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
|
||||
var_values = [getattr(cls, name) for name in names]
|
||||
return var_values
|
||||
|
||||
def render(self, text: str, color: str = "", style: str = ""):
|
||||
"""
|
||||
render text by input color and style. It's not recommend that input text is already rendered.
|
||||
"""
|
||||
# This method is called too frequently, which is not good.
|
||||
colors = self.get_all_colors()
|
||||
# Perhaps color and font should be distinguished here.
|
||||
if color:
|
||||
assert color in colors, f"color should be in: {colors} but now is: {color}"
|
||||
if style:
|
||||
assert style in colors, f"style should be in: {colors} but now is: {style}"
|
||||
|
||||
text = f"{color}{text}{self.END}"
|
||||
text = f"{style}{text}{self.END}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def formatting_log(logger, title="Info"):
|
||||
"""
|
||||
a context manager, print liens before and after a function
|
||||
"""
|
||||
length = {"Start": 90, "Round": 90, "Task": 90, "Info": 60, "Interact": 60, "End": 90}.get(title, 60)
|
||||
color, bold = (
|
||||
(LogColors.YELLOW, LogColors.BOLD)
|
||||
if title in ["Start", "Round", "Task", "Info", "Interact", "End"]
|
||||
else (LogColors.CYAN, "")
|
||||
)
|
||||
logger.info("")
|
||||
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
|
||||
|
||||
yield
|
||||
if color == LogColors.YELLOW:
|
||||
time.sleep(2)
|
||||
logger.info("")
|
||||
|
||||
|
||||
class FinCoLog(SingletonBaseClass):
|
||||
# TODO:
|
||||
# - config to file logger and save it into workspace
|
||||
def __init__(self) -> None:
|
||||
self.logger = logging.Logger("interactive")
|
||||
# TODO: merge these with Qlib's default logger.
|
||||
# We can do the same thing by changing the default log dict of Qlib.
|
||||
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
def log_message(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
messages is some info like this [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
"""
|
||||
with formatting_log(self.logger, "GPT Messages"):
|
||||
for m in messages:
|
||||
self.logger.info(
|
||||
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
|
||||
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
|
||||
)
|
||||
|
||||
def log_response(self, response: str):
|
||||
with formatting_log(self.logger, "GPT Response"):
|
||||
self.logger.info(f"{LogColors.CYAN}{response}{LogColors.END}\n")
|
||||
time.sleep(1)
|
||||
|
||||
# TODO:
|
||||
# It looks wierd if we only have logger
|
||||
def info(self, *args, plain=False, title="Info"):
|
||||
if plain:
|
||||
return self.plain_info(*args)
|
||||
with formatting_log(self.logger, title):
|
||||
for arg in args:
|
||||
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
|
||||
|
||||
def plain_info(self, *args):
|
||||
for arg in args:
|
||||
self.logger.info(
|
||||
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}"
|
||||
)
|
||||
|
||||
def warning(self, *args):
|
||||
for arg in args:
|
||||
self.logger.warning(f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
|
||||
|
||||
def error(self, *args):
|
||||
for arg in args:
|
||||
self.logger.error(f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")
|
||||
@@ -1,33 +0,0 @@
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
import yaml
|
||||
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco import get_finco_path
|
||||
|
||||
|
||||
class PromptTemplate(SingletonBaseClass):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
_template = yaml.load(
|
||||
open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"), Loader=yaml.FullLoader
|
||||
)
|
||||
for k, v in _template.items():
|
||||
if k == "mods":
|
||||
continue
|
||||
self.__setattr__(k, Template(v))
|
||||
|
||||
def get(self, key: str):
|
||||
return self.__dict__.get(key, Template(""))
|
||||
|
||||
def update(self, key: str, value):
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def save(self, file_path: Union[str, Path]):
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
Path.mkdir(file_path.parent, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(self.__dict__, f)
|
||||
File diff suppressed because it is too large
Load Diff
1328
qlib/finco/task.py
1328
qlib/finco/task.py
File diff suppressed because it is too large
Load Diff
@@ -1,12 +0,0 @@
|
||||
This is a set of templates that should be copied for a new project.
|
||||
|
||||
Here are the explanations for the templates folder.
|
||||
|
||||
| folder | explanations |
|
||||
|--------|------------------------------------------------------------------|
|
||||
| sl | Default configuration for supervised learning |
|
||||
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
|
||||
|
||||
|
||||
# TODO
|
||||
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_tpl_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
File diff suppressed because one or more lines are too long
@@ -1,71 +0,0 @@
|
||||
import json
|
||||
import string
|
||||
import random
|
||||
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from fuzzywuzzy import fuzz
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
_instance = None
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SingletonBaseClass(metaclass=SingletonMeta):
|
||||
"""
|
||||
Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)`
|
||||
This class becomes necessary
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Add move this class to Qlib's general utils.
|
||||
|
||||
|
||||
def parse_json(response):
|
||||
try:
|
||||
return json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
|
||||
|
||||
|
||||
def similarity(text1, text2):
|
||||
text1 = text1 if isinstance(text1, str) else ""
|
||||
text2 = text2 if isinstance(text2, str) else ""
|
||||
|
||||
# Maybe we can use other similarity algorithm such as tfidf
|
||||
return fuzz.ratio(text1, text2)
|
||||
|
||||
|
||||
def random_string(length=10):
|
||||
letters = string.ascii_letters + string.digits
|
||||
return "".join(random.choice(letters) for i in range(length))
|
||||
|
||||
|
||||
def directory_tree(root_dif, max_depth=None):
|
||||
|
||||
def _directory_tree(root_dir, padding="", deep=1, max_d=None) -> List:
|
||||
_output = []
|
||||
if max_d and deep > max_d:
|
||||
return _output
|
||||
|
||||
files = sorted(root_dir.iterdir())
|
||||
for i, file in enumerate(files):
|
||||
if i == len(files) - 1:
|
||||
_output.append(padding + '└── ' + file.name)
|
||||
if file.is_dir():
|
||||
_output.extend(_directory_tree(file, padding + " ", deep=deep + 1, max_d=max_d))
|
||||
else:
|
||||
_output.append(padding + '├── ' + file.name)
|
||||
if file.is_dir():
|
||||
_output.extend(_directory_tree(file, padding + "│ ", deep=deep + 1, max_d=max_d))
|
||||
return _output
|
||||
|
||||
output = _directory_tree(root_dif, max_d=max_depth)
|
||||
return '\n'.join(output)
|
||||
@@ -1,212 +0,0 @@
|
||||
import sys
|
||||
import time
|
||||
import shutil
|
||||
from typing import List
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from qlib.finco.task import IdeaTask, SummarizeTask
|
||||
from qlib.finco.prompt_template import PromptTemplate, Template
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.knowledge import KnowledgeBase, Topic
|
||||
from qlib.finco.context import WorkflowContextManager
|
||||
|
||||
|
||||
# TODO: it is not necessary in current phase
|
||||
# class TaskDAG:
|
||||
# """
|
||||
# This is a Task manager. it maintains a graph and a stack stucture to manager the task
|
||||
# The reason why the DGA relationship is maintained outside instead of inside the task is that
|
||||
# - To make the creating of task simpler(user don't have to care about the relation-ship)
|
||||
# - To manage the relation ship when poping and executing the tasks is relatively easier instead of scattering them everywhere
|
||||
# """
|
||||
# def __init__(self) -> None:
|
||||
# self._finished = []
|
||||
# self._stack = []
|
||||
# self._dag = defaultdict(list) # from id(object) -> list of id(object)
|
||||
#
|
||||
# def pop(self):
|
||||
# return self._stack.pop(0)
|
||||
#
|
||||
# def push(self, task: Union[Task, List[Task]], parent: Optional[Task] = None):
|
||||
# if isinstance(task, Task):
|
||||
# task = [task]
|
||||
# if parent is not None:
|
||||
# self._dag
|
||||
#
|
||||
# def done(self) -> bool:
|
||||
# return len(self._stack) == 0
|
||||
|
||||
|
||||
class WorkflowManager:
|
||||
"""This manage the whole task automation workflow including tasks and actions"""
|
||||
|
||||
def __init__(self, workspace=None) -> None:
|
||||
self.logger = FinCoLog()
|
||||
|
||||
if workspace is None:
|
||||
self._workspace = Path.cwd() / "finco_workspace"
|
||||
else:
|
||||
self._workspace = Path(workspace)
|
||||
self.conf = Config()
|
||||
self._confirm_and_rm()
|
||||
|
||||
self.prompt_template = PromptTemplate()
|
||||
self.context = WorkflowContextManager(workspace=self._workspace)
|
||||
self.context.set_context("workspace", self._workspace)
|
||||
self.default_user_prompt = "build an A-share stock market daily portfolio in quantitative investment and minimize the maximum drawdown while maintaining return."
|
||||
|
||||
def _confirm_and_rm(self):
|
||||
# if workspace exists, please confirm and remove it. Otherwise exit.
|
||||
if self._workspace.exists() and not self.conf.continuous_mode:
|
||||
self.logger.info(title="Interact")
|
||||
flag = input(
|
||||
LogColors().render(
|
||||
f"Will be deleted: \n\t{self._workspace}\n"
|
||||
f"If you do not need to delete {self._workspace},"
|
||||
f" please change the workspace dir or rename existing files\n"
|
||||
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
|
||||
color=LogColors.WHITE,
|
||||
)
|
||||
)
|
||||
if str(flag) not in ["Y", "y"]:
|
||||
sys.exit()
|
||||
else:
|
||||
# remove self._workspace
|
||||
shutil.rmtree(self._workspace)
|
||||
elif self._workspace.exists() and self.conf.continuous_mode:
|
||||
shutil.rmtree(self._workspace)
|
||||
|
||||
def set_context(self, key, value):
|
||||
"""Direct call set_context method of the context manager"""
|
||||
self.context.set_context(key, value)
|
||||
|
||||
def get_context(self) -> WorkflowContextManager:
|
||||
return self.context
|
||||
|
||||
def run(self, prompt: str) -> Path:
|
||||
"""
|
||||
The workflow manager is supposed to generate a codebase based on the prompt
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt: str
|
||||
the prompt user gives
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
|
||||
The path is returned
|
||||
|
||||
The output path should follow a specific format:
|
||||
- TODO: design
|
||||
There is a summarized report where user can start from.
|
||||
"""
|
||||
|
||||
# NOTE: The following items are not designed to make the workflow very flexible.
|
||||
# - The generated tasks can't be changed after geting new information from the execution retuls.
|
||||
# - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...
|
||||
|
||||
# NOTE: default user prompt might be changed in the future and exposed to the user
|
||||
if prompt is None:
|
||||
self.set_context("user_intention", self.default_user_prompt)
|
||||
else:
|
||||
self.set_context("user_intention", prompt)
|
||||
self.logger.info(f"user_intention: {self.get_context().get_context('user_intention')}", title="Start")
|
||||
|
||||
# NOTE: list may not be enough for general task list
|
||||
task_list = [IdeaTask(), SummarizeTask()]
|
||||
task_finished = []
|
||||
while len(task_list):
|
||||
task_list_info = [str(task) for task in task_list]
|
||||
|
||||
# task list is not long, so sort it is not a big problem
|
||||
# TODO: sort the task list based on the priority of the task
|
||||
# task_list = sorted(task_list, key=lambda x: x.task_type)
|
||||
t = task_list.pop(0)
|
||||
self.logger.info(
|
||||
f"Task finished: {[str(task) for task in task_finished]}",
|
||||
f"Task in queue: {task_list_info}",
|
||||
f"Executing task: {str(t)}",
|
||||
title="Task",
|
||||
)
|
||||
|
||||
t.assign_context_manager(self.context)
|
||||
res = t.execute()
|
||||
t.summarize()
|
||||
task_finished.append(t)
|
||||
self.context.set_context("task_finished", task_finished)
|
||||
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
|
||||
|
||||
task_list = res + task_list
|
||||
|
||||
return self._workspace
|
||||
|
||||
|
||||
class LearnManager:
|
||||
__DEFAULT_TOPICS = ["RollingModel"]
|
||||
|
||||
def __init__(self):
|
||||
self.epoch = 0
|
||||
self.wm = WorkflowManager()
|
||||
|
||||
self.topics = [
|
||||
Topic(name=topic, system=self.wm.prompt_template.get(f"Topic_system"), user=self.wm.prompt_template.get(f"Topic_user")) for topic in self.__DEFAULT_TOPICS
|
||||
]
|
||||
self.knowledge_base = KnowledgeBase()
|
||||
|
||||
def run(self, prompt):
|
||||
# todo: add early stop condition
|
||||
for i in range(10):
|
||||
self.wm.logger.info(f"Round: {self.epoch+1}", title="Round")
|
||||
self.wm.run(prompt)
|
||||
self.learn()
|
||||
self.epoch += 1
|
||||
|
||||
def learn(self):
|
||||
workspace = self.wm.context.get_context("workspace")
|
||||
|
||||
def _drop_duplicate_task(_task: List):
|
||||
unique_task = {}
|
||||
for obj in _task:
|
||||
task_name = obj.__class__.__name__
|
||||
if task_name not in unique_task:
|
||||
unique_task[task_name] = obj
|
||||
return list(unique_task.values())
|
||||
|
||||
# one task maybe run several times in workflow
|
||||
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
|
||||
|
||||
user_intention = self.wm.context.get_context("user_intention")
|
||||
summary = self.wm.context.get_context("summary")
|
||||
|
||||
|
||||
target = self.wm.context.get_context(f"target")
|
||||
diffrence = self.wm.context.get_context(f"experiments_difference")
|
||||
target_metrics = self.wm.context.get_context(f"high_level_metrics")
|
||||
|
||||
[topic.summarize(self.knowledge_base.practice_knowledge.knowledge[-2:], user_intention, target, diffrence, target_metrics) for topic in self.topics]
|
||||
[self.knowledge_base.practice_knowledge.add([f"practice_knowledge on {topic.name}:\,{topic.knowledge}"]) for topic in self.topics]
|
||||
# knowledge_of_topics = [{topic.name: topic.knowledge} for topic in self.topics]
|
||||
|
||||
# for task in task_finished:
|
||||
# prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
|
||||
# summary=summary,
|
||||
# brief=knowledge_of_topics,
|
||||
# task_finished=[str(t) for t in task_finished],
|
||||
# task=task.__class__.__name__, system=task.system.render(), user_intention=user_intention
|
||||
# )
|
||||
|
||||
# response = APIBackend().build_messages_and_create_chat_completion(
|
||||
# user_prompt=prompt_workflow_selection,
|
||||
# system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render(),
|
||||
# )
|
||||
|
||||
# # todo: response assertion
|
||||
# task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
|
||||
|
||||
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
|
||||
self.wm.context.clear(reserve=["workspace"])
|
||||
@@ -30,7 +30,6 @@ class Ensemble:
|
||||
|
||||
|
||||
class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
"""
|
||||
Extract the object if there is only one key and value in the dict. Make the result more readable.
|
||||
{Only key: Only value} -> Only value
|
||||
@@ -64,7 +63,6 @@ class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
|
||||
|
||||
@@ -247,9 +247,7 @@ class ShrinkCovEstimator(RiskModel):
|
||||
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
|
||||
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
|
||||
v3 = z.T.dot(z) / t - var_mkt * S
|
||||
roff3 = (
|
||||
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
|
||||
)
|
||||
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
|
||||
roff = 2 * roff1 - roff3
|
||||
rho = rdiag + roff
|
||||
|
||||
|
||||
@@ -168,7 +168,9 @@ class TrainingVessel(TrainingVesselBase):
|
||||
self.policy.train()
|
||||
|
||||
with vector_env.collector_guard():
|
||||
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
|
||||
collector = Collector(
|
||||
self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True
|
||||
)
|
||||
|
||||
# Number of episodes collected in each training iteration can be overridden by fast dev run.
|
||||
if self.trainer.fast_dev_run is not None:
|
||||
|
||||
@@ -25,7 +25,12 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Optional, Callable
|
||||
from packaging import version
|
||||
from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
|
||||
from .file import (
|
||||
get_or_create_path,
|
||||
save_multiple_parts_file,
|
||||
unpack_archive_with_buffer,
|
||||
get_tmp_file_with_buffer,
|
||||
)
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
|
||||
@@ -37,7 +42,12 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
|
||||
#################### Server ####################
|
||||
def get_redis_connection():
|
||||
"""get redis connection instance."""
|
||||
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
|
||||
return redis.StrictRedis(
|
||||
host=C.redis_host,
|
||||
port=C.redis_port,
|
||||
db=C.redis_task_db,
|
||||
password=C.redis_password,
|
||||
)
|
||||
|
||||
|
||||
#################### Data ####################
|
||||
@@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly):
|
||||
return offset
|
||||
|
||||
|
||||
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
|
||||
def read_period_data(
|
||||
index_path,
|
||||
data_path,
|
||||
period,
|
||||
cur_date_int: int,
|
||||
quarterly,
|
||||
last_period_index: int = None,
|
||||
):
|
||||
"""
|
||||
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
|
||||
Only the updating info before cur_date or at cur_date will be used.
|
||||
@@ -273,7 +290,10 @@ def parse_field(field):
|
||||
# \uff09 -> )
|
||||
chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
|
||||
for pattern, new in [
|
||||
(rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $
|
||||
(
|
||||
rf"\$\$([\w{chinese_punctuation_regex}]+)",
|
||||
r'PFeature("\1")',
|
||||
), # $$ must be before $
|
||||
(rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
|
||||
(r"(\w+\s*)\(", r"Operators.\1("),
|
||||
]: # Features # Operators
|
||||
@@ -383,7 +403,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
|
||||
def get_date_by_shift(
|
||||
trading_date,
|
||||
shift,
|
||||
future=False,
|
||||
clip_shift=True,
|
||||
freq="day",
|
||||
align: Optional[str] = None,
|
||||
):
|
||||
"""get trading date with shift bias will cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -569,7 +596,38 @@ def exists_qlib_data(qlib_dir):
|
||||
# check instruments
|
||||
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
|
||||
_instrument = instruments_dir.joinpath("all.txt")
|
||||
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
|
||||
# Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
|
||||
miss_code = set(
|
||||
pd.read_csv(
|
||||
_instrument,
|
||||
sep="\t",
|
||||
header=None,
|
||||
keep_default_na=False,
|
||||
na_values={
|
||||
0: [
|
||||
" ",
|
||||
"#N/A",
|
||||
"#N/A N/A",
|
||||
"#NA",
|
||||
"-1.#IND",
|
||||
"-1.#QNAN",
|
||||
"-NaN",
|
||||
"-nan",
|
||||
"1.#IND",
|
||||
"1.#QNAN",
|
||||
"<NA>",
|
||||
"N/A",
|
||||
"NaN",
|
||||
"None",
|
||||
"n/a",
|
||||
"nan",
|
||||
"null ",
|
||||
]
|
||||
},
|
||||
)
|
||||
.loc[:, 0]
|
||||
.apply(str.lower)
|
||||
) - set(code_names)
|
||||
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
|
||||
return False
|
||||
|
||||
|
||||
@@ -206,9 +206,6 @@ def find_all_classes(module_path: Union[str, ModuleType], cls: type) -> List[typ
|
||||
>>> from qlib.data.dataset.handler import DataHandler
|
||||
>>> find_all_classes("qlib.contrib.data.handler", DataHandler)
|
||||
[<class 'qlib.contrib.data.handler.Alpha158'>, <class 'qlib.contrib.data.handler.Alpha158vwap'>, <class 'qlib.contrib.data.handler.Alpha360'>, <class 'qlib.contrib.data.handler.Alpha360vwap'>, <class 'qlib.data.dataset.handler.DataHandlerLP'>]
|
||||
>>> from qlib.contrib.rolling.base import Rolling
|
||||
>>> find_all_classes("qlib.contrib.rolling", Rolling)
|
||||
[<class 'qlib.contrib.rolling.base.Rolling'>, <class 'qlib.contrib.rolling.ddgda.DDGDA'>]
|
||||
|
||||
TODO:
|
||||
- skip import error
|
||||
@@ -223,7 +220,7 @@ def find_all_classes(module_path: Union[str, ModuleType], cls: type) -> List[typ
|
||||
|
||||
def _append_cls(obj):
|
||||
# Leverage the closure trick to reuse code
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj not in cls_list:
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and cls not in cls_list:
|
||||
cls_list.append(obj)
|
||||
|
||||
for attr in dir(mod):
|
||||
|
||||
@@ -90,7 +90,6 @@ class OnlineStrategy:
|
||||
|
||||
|
||||
class RollingStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always uses the latest rolling model sas online models.
|
||||
"""
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
import logging
|
||||
import warnings
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from pprint import pprint
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List, Optional, Dict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import risk_analysis, indicator_analysis
|
||||
@@ -17,8 +19,9 @@ from ..log import get_module_logger
|
||||
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
|
||||
from ..utils.time import Freq
|
||||
from ..utils.data import deepcopy_basic_type
|
||||
from ..utils.exceptions import QlibException
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -155,9 +158,6 @@ class RecordTemp:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
def analyse(self):
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
"""
|
||||
@@ -233,9 +233,16 @@ class ACRecordTemp(RecordTemp):
|
||||
except FileNotFoundError:
|
||||
logger.warning("The dependent data does not exists. Generation skipped.")
|
||||
return
|
||||
return self._generate(*args, **kwargs)
|
||||
artifact_dict = self._generate(*args, **kwargs)
|
||||
if isinstance(artifact_dict, dict):
|
||||
self.save(**artifact_dict)
|
||||
return artifact_dict
|
||||
|
||||
def _generate(self, *args, **kwargs):
|
||||
def _generate(self, *args, **kwargs) -> Dict[str, object]:
|
||||
"""
|
||||
Run the concrete generating task, return the dictionary of the generated results.
|
||||
The caller method will save the results to the recorder.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_generate` method")
|
||||
|
||||
|
||||
@@ -339,8 +346,8 @@ class SigAnaRecord(ACRecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
return objects
|
||||
|
||||
def list(self):
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
@@ -471,17 +478,18 @@ class PortAnaRecord(ACRecordTemp):
|
||||
if self.backtest_config["end_time"] is None:
|
||||
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
|
||||
|
||||
artifact_objects = {}
|
||||
# custom strategy and get backtest
|
||||
portfolio_metric_dict, indicator_dict = normal_backtest(
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
|
||||
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
|
||||
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -503,7 +511,7 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -528,12 +536,13 @@ class PortAnaRecord(ACRecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
return artifact_objects
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
@@ -556,3 +565,124 @@ class PortAnaRecord(ACRecordTemp):
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
|
||||
class MultiPassPortAnaRecord(PortAnaRecord):
|
||||
"""
|
||||
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
|
||||
|
||||
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
|
||||
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
recorder : Recorder
|
||||
The recorder used to save the backtest results.
|
||||
pass_num : int
|
||||
The number of backtest passes.
|
||||
shuffle_init_score : bool
|
||||
Whether to shuffle the prediction score of the first backtest date.
|
||||
"""
|
||||
self.pass_num = pass_num
|
||||
self.shuffle_init_score = shuffle_init_score
|
||||
|
||||
super().__init__(recorder, **kwargs)
|
||||
|
||||
# Save original strategy so that pred df can be replaced in next generate
|
||||
self.original_strategy = deepcopy_basic_type(self.strategy_config)
|
||||
if not isinstance(self.original_strategy, dict):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
|
||||
if "signal" not in self.original_strategy.get("kwargs", {}):
|
||||
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
|
||||
|
||||
def random_init(self):
|
||||
pred_df = self.load("pred.pkl")
|
||||
|
||||
all_pred_dates = pred_df.index.get_level_values("datetime")
|
||||
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
|
||||
if bt_start_date is None:
|
||||
first_bt_pred_date = all_pred_dates.min()
|
||||
else:
|
||||
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
|
||||
|
||||
# Shuffle the first backtest date's pred score
|
||||
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
|
||||
np.random.shuffle(first_date_score.values)
|
||||
|
||||
# Use shuffled signal as the strategy signal
|
||||
self.strategy_config = deepcopy_basic_type(self.original_strategy)
|
||||
self.strategy_config["kwargs"]["signal"] = pred_df
|
||||
|
||||
def _generate(self, **kwargs):
|
||||
risk_analysis_df_map = {}
|
||||
|
||||
# Collect each frequency's analysis df as df list
|
||||
for i in trange(self.pass_num):
|
||||
if self.shuffle_init_score:
|
||||
self.random_init()
|
||||
|
||||
# Not check for cache file list
|
||||
single_run_artifacts = super()._generate(**kwargs)
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
|
||||
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
|
||||
|
||||
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
|
||||
analysis_df["run_id"] = i
|
||||
risk_analysis_df_list.append(analysis_df)
|
||||
|
||||
result_artifacts = {}
|
||||
# Concat df list
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
|
||||
|
||||
# Calculate return and information ratio's mean, std and mean/std
|
||||
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
|
||||
lambda x: pd.Series(
|
||||
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
|
||||
)
|
||||
)
|
||||
|
||||
# Only look at "annualized_return" and "information_ratio"
|
||||
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
|
||||
(slice(None), ["annualized_return", "information_ratio"]), :
|
||||
]
|
||||
pprint(multi_pass_port_analysis_df)
|
||||
|
||||
# Save new df
|
||||
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
|
||||
|
||||
# Log metrics
|
||||
metrics = flatten_dict(
|
||||
{
|
||||
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
|
||||
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
|
||||
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
return result_artifacts
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
return list_path
|
||||
|
||||
81
scripts/data_collector/baostock_5min/README.md
Normal file
81
scripts/data_collector/baostock_5min/README.md
Normal file
@@ -0,0 +1,81 @@
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*
|
||||
- `version`: dataset version, value from [`v2`], by default `v2`
|
||||
- `v2` end date is *2022-12*
|
||||
- `interval`: `5min`
|
||||
- `region`: `hs300`
|
||||
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
|
||||
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
|
||||
- examples:
|
||||
```bash
|
||||
# hs300 5min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min
|
||||
```
|
||||
|
||||
### Collector *Baostock high frequency* data to qlib
|
||||
> collector *Baostock high frequency* data and *dump* into `qlib` format.
|
||||
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
|
||||
1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`
|
||||
|
||||
This will download the raw data such as date, symbol, open, high, low, close, volume, amount, adjustflag from baostock to a local directory. One file per symbol.
|
||||
- parameters:
|
||||
- `source_dir`: save the directory
|
||||
- `interval`: `5min`
|
||||
- `region`: `HS300`
|
||||
- `start`: start datetime, by default *None*
|
||||
- `end`: end datetime, by default *None*
|
||||
- examples:
|
||||
```bash
|
||||
# cn 5min data
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
|
||||
```
|
||||
2. normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`
|
||||
|
||||
This will:
|
||||
1. Normalize high, low, close, open price using adjclose.
|
||||
2. Normalize the high, low, close, open price so that the first valid trading date's close price is 1.
|
||||
- parameters:
|
||||
- `source_dir`: csv directory
|
||||
- `normalize_dir`: result directory
|
||||
- `interval`: `5min`
|
||||
> if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`
|
||||
- `region`: `HS300`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
|
||||
- `qlib_data_1d_dir`: qlib directory(1d data)
|
||||
if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
|
||||
```
|
||||
# qlib_data_1d can be obtained like this:
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
|
||||
```
|
||||
- examples:
|
||||
```bash
|
||||
# normalize 5min cn
|
||||
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
|
||||
```
|
||||
3. dump data: `python scripts/dump_bin.py dump_all`
|
||||
|
||||
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
|
||||
|
||||
- parameters:
|
||||
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
|
||||
- `qlib_dir`: qlib(dump) data director
|
||||
- `freq`: transaction frequency, by default `day`
|
||||
> `freq_map = {1d:day, 5mih: 5min}`
|
||||
- `max_workers`: number of threads, by default *16*
|
||||
- `include_fields`: dump fields, by default `""`
|
||||
- `exclude_fields`: fields not dumped, by default `"""
|
||||
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
|
||||
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
|
||||
- `date_field_name`: column *name* identifying time in csv files, by default `date`
|
||||
- examples:
|
||||
```bash
|
||||
# dump 5min cn
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
|
||||
```
|
||||
328
scripts/data_collector/baostock_5min/collector.py
Normal file
328
scripts/data_collector/baostock_5min/collector.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import sys
|
||||
import copy
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Iterable, List
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price
|
||||
|
||||
|
||||
class BaostockCollectorHS3005min(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="5min",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: int = None,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [5min], default 5min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: int
|
||||
check data length, by default None
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
bs.login()
|
||||
super(BaostockCollectorHS3005min, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
def get_trade_calendar(self):
|
||||
_format = "%Y-%m-%d"
|
||||
start = self.start_datetime.strftime(_format)
|
||||
end = self.end_datetime.strftime(_format)
|
||||
rs = bs.query_trade_dates(start_date=start, end_date=end)
|
||||
calendar_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
calendar_list.append(rs.get_row_data())
|
||||
calendar_df = pd.DataFrame(calendar_list, columns=rs.fields)
|
||||
trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])]
|
||||
return trade_calendar_df["calendar_date"].values
|
||||
|
||||
@staticmethod
|
||||
def process_interval(interval: str):
|
||||
if interval == "1d":
|
||||
return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"}
|
||||
if interval == "5min":
|
||||
return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"}
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
df = self.get_data_from_remote(
|
||||
symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime
|
||||
)
|
||||
df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"]
|
||||
df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f")
|
||||
df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5))
|
||||
df.drop(["time"], axis=1, inplace=True)
|
||||
df["symbol"] = df["symbol"].map(lambda x: str(x).replace(".", "").upper())
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(
|
||||
symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
df = pd.DataFrame()
|
||||
rs = bs.query_history_k_data_plus(
|
||||
symbol,
|
||||
BaostockCollectorHS3005min.process_interval(interval=interval)["fields"],
|
||||
start_date=str(start_datetime.strftime("%Y-%m-%d")),
|
||||
end_date=str(end_datetime.strftime("%Y-%m-%d")),
|
||||
frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"],
|
||||
adjustflag="3",
|
||||
)
|
||||
if rs.error_code == "0" and len(rs.data) > 0:
|
||||
data_list = rs.data
|
||||
columns = rs.fields
|
||||
df = pd.DataFrame(data_list, columns=columns)
|
||||
return df
|
||||
|
||||
def get_hs300_symbols(self) -> List[str]:
|
||||
hs300_stocks = []
|
||||
trade_calendar = self.get_trade_calendar()
|
||||
with tqdm(total=len(trade_calendar)) as p_bar:
|
||||
for date in trade_calendar:
|
||||
rs = bs.query_hs300_stocks(date=date)
|
||||
while rs.error_code == "0" and rs.next():
|
||||
hs300_stocks.append(rs.get_row_data())
|
||||
p_bar.update()
|
||||
return sorted({e[1] for e in hs300_stocks})
|
||||
|
||||
def get_instrument_list(self):
|
||||
logger.info("get HS stock symbols......")
|
||||
symbols = self.get_hs300_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol: str):
|
||||
return str(symbol).replace(".", "").upper()
|
||||
|
||||
|
||||
class BaostockNormalizeHS3005min(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 5min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
bs.login()
|
||||
qlib.init(provider_uri=qlib_data_1d_dir)
|
||||
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)
|
||||
|
||||
@staticmethod
|
||||
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
|
||||
df = df.copy()
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None:
|
||||
_tmp_shift_series.iloc[0] = float(last_close)
|
||||
change_series = _tmp_series / _tmp_shift_series - 1
|
||||
return change_series
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_5min_from_daily(self.calendar_list_1d)
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
@staticmethod
|
||||
def normalize_baostock(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
last_close: float = None,
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
|
||||
columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS)
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
|
||||
|
||||
df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close)
|
||||
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
df[symbol_field_name] = symbol
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = calc_adjusted_price(
|
||||
df=df,
|
||||
_date_field_name=self._date_field_name,
|
||||
_symbol_field_name=self._symbol_field_name,
|
||||
frequence="5min",
|
||||
_1d_data_all=self.all_1d_data,
|
||||
)
|
||||
return df
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"):
|
||||
"""
|
||||
Changed the default value of: scripts.data_collector.base.BaseRun.
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"BaostockCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"BaostockNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0.5,
|
||||
start=None,
|
||||
end=None,
|
||||
check_data_length=None,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Baostock
|
||||
|
||||
Notes
|
||||
-----
|
||||
check_data_length, example:
|
||||
hs300 5min, a week: 4 * 60 * 5
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get hs300 5min data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
|
||||
"""
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
end_date: str = None,
|
||||
qlib_data_1d_dir: str = None,
|
||||
):
|
||||
"""normalize data
|
||||
|
||||
Attention
|
||||
---------
|
||||
qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
|
||||
or:
|
||||
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
|
||||
"""
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
raise ValueError(
|
||||
"If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
13
scripts/data_collector/baostock_5min/requirements.txt
Normal file
13
scripts/data_collector/baostock_5min/requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
yahooquery
|
||||
joblib
|
||||
beautifulsoup4
|
||||
bs4
|
||||
soupsieve
|
||||
baostock
|
||||
@@ -8,7 +8,7 @@ import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type, Iterable
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
@@ -290,7 +290,7 @@ class Normalize:
|
||||
|
||||
# some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.
|
||||
# manually defines dtype and na_values of the symbol_field.
|
||||
default_na = pd._libs.parsers.STR_NA_VALUES
|
||||
default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101
|
||||
symbol_na = default_na.copy()
|
||||
symbol_na.remove("NA")
|
||||
columns = pd.read_csv(file_path, nrows=0).columns
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from functools import partial
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
import datetime
|
||||
|
||||
import fire
|
||||
@@ -98,7 +97,7 @@ class IBOVIndex(IndexBase):
|
||||
now = datetime.datetime.now()
|
||||
current_year = now.year
|
||||
current_month = now.month
|
||||
for year in [item for item in range(init_year, current_year)]:
|
||||
for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721
|
||||
for el in four_months_period:
|
||||
self.years_4_month_periods.append(str(year) + "_" + el)
|
||||
# For current year the logic must be a little different
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
async-generator==1.10
|
||||
attrs==21.4.0
|
||||
certifi==2021.10.8
|
||||
certifi==2022.12.7
|
||||
cffi==1.15.0
|
||||
charset-normalizer==2.0.12
|
||||
cryptography==36.0.1
|
||||
@@ -8,7 +8,7 @@ fire==0.4.0
|
||||
h11==0.13.0
|
||||
idna==3.3
|
||||
loguru==0.6.0
|
||||
lxml==4.8.0
|
||||
lxml==4.9.1
|
||||
multitasking==0.0.10
|
||||
numpy==1.22.2
|
||||
outcome==1.1.0
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
@@ -39,7 +38,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
if exclude_status is None:
|
||||
exclude_status = []
|
||||
method_func = getattr(requests, method)
|
||||
_resp = method_func(url, headers=REQ_HEADERS)
|
||||
_resp = method_func(url, headers=REQ_HEADERS, timeout=None)
|
||||
_status = _resp.status_code
|
||||
if _status not in exclude_status and _status != 200:
|
||||
raise ValueError(f"response status: {_status}, url={url}")
|
||||
@@ -397,14 +396,7 @@ class CSI500Index(CSIIndex):
|
||||
today = pd.Timestamp.now()
|
||||
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
|
||||
ret_list = []
|
||||
col = ["date", "symbol", "code_name"]
|
||||
for date in tqdm(date_range, desc="Download CSI500"):
|
||||
rs = bs.query_zz500_stocks(date=str(date))
|
||||
zz500_stocks = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
zz500_stocks.append(rs.get_row_data())
|
||||
result = pd.DataFrame(zz500_stocks, columns=col)
|
||||
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
|
||||
result = self.get_data_from_baostock(date)
|
||||
ret_list.append(result[["date", "symbol"]])
|
||||
bs.logout()
|
||||
|
||||
@@ -5,3 +5,5 @@ pandas
|
||||
lxml
|
||||
loguru
|
||||
tqdm
|
||||
yahooquery
|
||||
openpyxl
|
||||
|
||||
@@ -5,7 +5,6 @@ from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
@@ -31,15 +30,15 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
crypto symbols in given exchanges list of coingecko
|
||||
"""
|
||||
global _CG_CRYPTO_SYMBOLS
|
||||
global _CG_CRYPTO_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_coingecko():
|
||||
try:
|
||||
cg = CoinGeckoAPI()
|
||||
resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
|
||||
except:
|
||||
raise ValueError("request error")
|
||||
except Exception as e:
|
||||
raise ValueError("request error") from e
|
||||
try:
|
||||
_symbols = resp["id"].to_list()
|
||||
except Exception as e:
|
||||
|
||||
@@ -107,7 +107,7 @@ class FundCollector(BaseCollector):
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
@@ -116,8 +116,8 @@ class FundCollector(BaseCollector):
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
if SYType in {"每万份收益", "每百份收益", "每百万份收益"}:
|
||||
raise ValueError("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
@@ -53,7 +53,7 @@ class CollectorFutureCalendar:
|
||||
return datetime_d.strftime(self.calendar_format)
|
||||
|
||||
def write_calendar(self, calendar: Iterable):
|
||||
calendars_list = list(map(lambda x: self._format_datetime(x), sorted(set(self.calendar_list + calendar))))
|
||||
calendars_list = [self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar))]
|
||||
np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import abc
|
||||
from functools import partial
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
@@ -113,7 +112,7 @@ class WIKIIndex(IndexBase):
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
resp = requests.get(self._target_url)
|
||||
resp = requests.get(self._target_url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
@@ -164,7 +163,7 @@ class NASDAQ100Index(WIKIIndex):
|
||||
df = pd.read_pickle(cache_path)
|
||||
else:
|
||||
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
|
||||
resp = requests.post(url)
|
||||
resp = requests.post(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {url}")
|
||||
df = pd.DataFrame(resp.json()["aaData"])
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import copy
|
||||
import importlib
|
||||
import time
|
||||
import bisect
|
||||
@@ -68,7 +69,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url).json()["data"]["klines"]
|
||||
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
@@ -85,12 +86,14 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
def _get_calendar(month):
|
||||
_cal = []
|
||||
try:
|
||||
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json()
|
||||
resp = requests.get(
|
||||
SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None
|
||||
).json()
|
||||
for _r in resp["data"]:
|
||||
if int(_r["jybz"]):
|
||||
_cal.append(pd.Timestamp(_r["jyrq"]))
|
||||
except Exception as e:
|
||||
raise ValueError(f"{month}-->{e}")
|
||||
raise ValueError(f"{month}-->{e}") from e
|
||||
return _cal
|
||||
|
||||
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
|
||||
@@ -109,7 +112,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
return sorted([pd.Timestamp(x) for x in date_list])
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
@@ -155,7 +158,7 @@ def get_calendar_list_by_ratio(
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
if date not in _dict_count_trade:
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
@@ -163,7 +166,7 @@ def get_calendar_list_by_ratio(
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
@@ -171,9 +174,7 @@ def get_calendar_list_by_ratio(
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
@@ -186,16 +187,16 @@ def get_hs_stock_symbols() -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
global _HS_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None)
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
@@ -230,12 +231,12 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
global _US_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url)
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -277,7 +278,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"maxResultsPerPage": 10000,
|
||||
"filterToken": "",
|
||||
}
|
||||
resp = requests.post(url, json=_parms)
|
||||
resp = requests.post(url, json=_parms, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
@@ -317,7 +318,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _IN_SYMBOLS
|
||||
global _IN_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_nifty():
|
||||
@@ -358,7 +359,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
B3 stock symbols
|
||||
"""
|
||||
global _BR_SYMBOLS
|
||||
global _BR_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_ibovespa():
|
||||
@@ -367,7 +368,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
|
||||
# Request
|
||||
agent = {"User-Agent": "Mozilla/5.0"}
|
||||
page = requests.get(url, headers=agent)
|
||||
page = requests.get(url, headers=agent, timeout=None)
|
||||
|
||||
# BeautifulSoup
|
||||
soup = BeautifulSoup(page.content, "html.parser")
|
||||
@@ -375,7 +376,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
|
||||
children = tbody.findChildren("a", recursive=True)
|
||||
for child in children:
|
||||
_symbols.append(str(child).split('"')[-1].split(">")[1].split("<")[0])
|
||||
_symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0])
|
||||
|
||||
return _symbols
|
||||
|
||||
@@ -409,12 +410,12 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
global _EN_FUND_SYMBOLS # pylint: disable=W0603
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
resp = requests.get(url, timeout=None)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
@@ -605,5 +606,177 @@ def get_instruments(
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):
|
||||
df = copy.deepcopy(_1d_data_all)
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
|
||||
def get_1d_data(
|
||||
_date_field_name: str,
|
||||
_symbol_field_name: str,
|
||||
symbol: str,
|
||||
start: str,
|
||||
end: str,
|
||||
_1d_data_all: pd.DataFrame,
|
||||
) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
_all_1d_data = _get_all_1d_data(_date_field_name, _symbol_field_name, _1d_data_all)
|
||||
return _all_1d_data[
|
||||
(_all_1d_data[_symbol_field_name] == symbol.upper())
|
||||
& (_all_1d_data[_date_field_name] >= pd.Timestamp(start))
|
||||
& (_all_1d_data[_date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
def calc_adjusted_price(
|
||||
df: pd.DataFrame,
|
||||
_1d_data_all: pd.DataFrame,
|
||||
_date_field_name: str,
|
||||
_symbol_field_name: str,
|
||||
frequence: str,
|
||||
consistent_1d: bool = True,
|
||||
calc_paused: bool = True,
|
||||
) -> pd.DataFrame:
|
||||
"""calc adjusted price
|
||||
This method does 4 things.
|
||||
1. Adds the `paused` field.
|
||||
- The added paused field comes from the paused field of the 1d data.
|
||||
2. Aligns the time of the 1d data.
|
||||
3. The data is reweighted.
|
||||
- The reweighting method:
|
||||
- volume / factor
|
||||
- open * factor
|
||||
- high * factor
|
||||
- low * factor
|
||||
- close * factor
|
||||
4. Called `calc_paused_num` method to add the `paused_num` field.
|
||||
- The `paused_num` is the number of consecutive days of trading suspension.
|
||||
"""
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.drop_duplicates(subset=_date_field_name, inplace=True)
|
||||
df.sort_values(_date_field_name, inplace=True)
|
||||
symbol = df.iloc[0][_symbol_field_name]
|
||||
df[_date_field_name] = pd.to_datetime(df[_date_field_name])
|
||||
# get 1d data from qlib
|
||||
_start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d")
|
||||
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
|
||||
data_1d = data_1d.set_index(_date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
def _calc_factor(df_1d: pd.DataFrame):
|
||||
try:
|
||||
_date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date())
|
||||
df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
|
||||
df_1d["paused"] = data_1d.loc[_date]["paused"]
|
||||
except Exception:
|
||||
df_1d["factor"] = np.nan
|
||||
df_1d["paused"] = np.nan
|
||||
return df_1d
|
||||
|
||||
df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor)
|
||||
if consistent_1d:
|
||||
# the date sequence is consistent with 1d
|
||||
df.set_index(_date_field_name, inplace=True)
|
||||
df = df.reindex(
|
||||
generate_minutes_calendar_from_daily(
|
||||
calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()),
|
||||
freq=frequence,
|
||||
am_range=("09:30:00", "11:29:00"),
|
||||
pm_range=("13:00:00", "14:59:00"),
|
||||
)
|
||||
)
|
||||
df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name]
|
||||
df.index.names = [_date_field_name]
|
||||
df.reset_index(inplace=True)
|
||||
for _col in ["open", "close", "high", "low", "volume"]:
|
||||
if _col not in df.columns:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
if calc_paused:
|
||||
df = calc_paused_num(df, _date_field_name, _symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
|
||||
"""calc paused num
|
||||
This method adds the paused_num field
|
||||
- The `paused_num` is the number of consecutive days of trading suspension.
|
||||
"""
|
||||
_symbol = df.iloc[0][_symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
_date_field_name,
|
||||
_symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -121,7 +121,7 @@ pip install -r requirements.txt
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --end_date <end_date>
|
||||
or:
|
||||
download 1d data from YahooFinance
|
||||
|
||||
@@ -180,9 +180,8 @@ pip install -r requirements.txt
|
||||
|
||||
* Manual update of data
|
||||
```
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>
|
||||
```
|
||||
* `trading_date`: start of trading day
|
||||
* `end_date`: end of trading day(not included)
|
||||
* `check_data_length`: check the number of rows per *symbol*, by default `None`
|
||||
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
|
||||
@@ -191,10 +190,10 @@ pip install -r requirements.txt
|
||||
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
|
||||
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
* `region`: region, value from ["CN", "US"], default "CN"
|
||||
|
||||
* `interval`: interval, default "1d"(Currently only supports 1d data)
|
||||
* `exists_skip`: exists skip, by default False
|
||||
|
||||
## Using qlib data
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from re import I
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
@@ -21,6 +20,8 @@ from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
|
||||
from qlib.constant import REG_CN as REGION_CN
|
||||
@@ -38,6 +39,7 @@ from data_collector.utils import (
|
||||
get_in_stock_symbols,
|
||||
get_br_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
calc_adjusted_price,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
@@ -229,9 +231,9 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
|
||||
"data"
|
||||
]["klines"],
|
||||
requests.get(
|
||||
INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None
|
||||
).json()["data"]["klines"],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -316,7 +318,7 @@ class YahooCollectorIN1min(YahooCollectorIN):
|
||||
|
||||
|
||||
class YahooCollectorBR(YahooCollector, ABC):
|
||||
def retry(cls):
|
||||
def retry(cls): # pylint: disable=E0213
|
||||
"""
|
||||
The reason to use retry=2 is due to the fact that
|
||||
Yahoo Finance unfortunately does not keep track of some
|
||||
@@ -356,12 +358,10 @@ class YahooCollectorBR(YahooCollector, ABC):
|
||||
|
||||
class YahooCollectorBR1d(YahooCollectorBR):
|
||||
retry = 2
|
||||
pass
|
||||
|
||||
|
||||
class YahooCollectorBR1min(YahooCollectorBR):
|
||||
retry = 2
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalize(BaseNormalize):
|
||||
@@ -393,6 +393,7 @@ class YahooNormalize(BaseNormalize):
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df.index = df.index.tz_localize(None)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
@@ -522,78 +523,39 @@ class YahooNormalize1dExtend(YahooNormalize1d):
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
|
||||
self._first_close_field = "first_close"
|
||||
self._ori_close_field = "ori_close"
|
||||
self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"]
|
||||
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
|
||||
|
||||
def _get_old_data(self, qlib_data_dir: [str, Path]):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
|
||||
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
|
||||
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
|
||||
df.columns = [self._ori_close_field, self._first_close_field]
|
||||
df = D.features(D.instruments("all"), ["$" + col for col in self.column_list])
|
||||
df.columns = self.column_list
|
||||
return df
|
||||
|
||||
def _get_close(self, df: pd.DataFrame, field_name: str):
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_close = _df.loc[_df.last_valid_index()][field_name]
|
||||
return _close
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._first_close_field)
|
||||
except KeyError:
|
||||
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
|
||||
return _close
|
||||
|
||||
def _get_last_close(self, df: pd.DataFrame) -> float:
|
||||
try:
|
||||
_close = self._get_close(df, field_name=self._ori_close_field)
|
||||
except KeyError:
|
||||
_close = None
|
||||
return _close
|
||||
|
||||
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
|
||||
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
|
||||
try:
|
||||
_df = self.old_qlib_data.loc(axis=0)[_symbol]
|
||||
_date = _df.index.max()
|
||||
except KeyError:
|
||||
_date = None
|
||||
return _date
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_last_close = self._get_last_close(df)
|
||||
# reindex
|
||||
_last_date = self._get_last_date(df)
|
||||
if _last_date is not None:
|
||||
df = df.set_index(self._date_field_name)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
_max_date = df.index.max()
|
||||
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
|
||||
df = df[df[self._date_field_name] > _last_date]
|
||||
if df.empty:
|
||||
return pd.DataFrame()
|
||||
_si = df["close"].first_valid_index()
|
||||
if _si > df.index[0]:
|
||||
logger.warning(
|
||||
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
|
||||
)
|
||||
# normalize
|
||||
df = self.normalize_yahoo(
|
||||
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
|
||||
)
|
||||
# adjusted price
|
||||
df = self.adjusted_price(df)
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
df = super(YahooNormalize1dExtend, self).normalize(df)
|
||||
df.set_index(self._date_field_name, inplace=True)
|
||||
symbol_name = df[self._symbol_field_name].iloc[0]
|
||||
old_symbol_list = self.old_qlib_data.index.get_level_values("instrument").unique().to_list()
|
||||
if str(symbol_name).upper() not in old_symbol_list:
|
||||
return df.reset_index()
|
||||
old_df = self.old_qlib_data.loc[str(symbol_name).upper()]
|
||||
latest_date = old_df.index[-1]
|
||||
df = df.loc[latest_date:]
|
||||
new_latest_data = df.iloc[0]
|
||||
old_latest_data = old_df.loc[latest_date]
|
||||
for col in self.column_list[:-1]:
|
||||
if col == "volume":
|
||||
df[col] = df[col] / (new_latest_data[col] / old_latest_data[col])
|
||||
else:
|
||||
df[col] = df[col] * (old_latest_data[col] / new_latest_data[col])
|
||||
return df.drop(df.index[0]).reset_index()
|
||||
|
||||
|
||||
class YahooNormalize1min(YahooNormalize, ABC):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
@@ -601,160 +563,6 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
if not (data_1d is None or data_1d.empty):
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
|
||||
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
data_1d = data_1d_obj.normalize(data_1d)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df = df.sort_values(self._date_field_name)
|
||||
symbol = df.iloc[0][self._symbol_field_name]
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
# NOTE: volume is np.nan or volume <= 0, paused = 1
|
||||
# FIXME: find a more accurate data source
|
||||
data_1d["paused"] = 0
|
||||
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
|
||||
data_1d = data_1d.set_index(self._date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: yahoo 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
def _calc_factor(df_1d: pd.DataFrame):
|
||||
try:
|
||||
_date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date())
|
||||
df_1d["factor"] = (
|
||||
data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
|
||||
)
|
||||
df_1d["paused"] = data_1d.loc[_date]["paused"]
|
||||
except Exception:
|
||||
df_1d["factor"] = np.nan
|
||||
df_1d["paused"] = np.nan
|
||||
return df_1d
|
||||
|
||||
df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor)
|
||||
|
||||
if self.CONSISTENT_1d:
|
||||
# the date sequence is consistent with 1d
|
||||
df.set_index(self._date_field_name, inplace=True)
|
||||
df = df.reindex(
|
||||
self.generate_1min_from_daily(
|
||||
pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates())
|
||||
)
|
||||
)
|
||||
df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][
|
||||
self._symbol_field_name
|
||||
]
|
||||
df.index.names = [self._date_field_name]
|
||||
df.reset_index(inplace=True)
|
||||
for _col in self.COLUMNS:
|
||||
if _col not in df.columns:
|
||||
continue
|
||||
if _col == "volume":
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
|
||||
if self.CALC_PAUSED_NUM:
|
||||
df = self.calc_paused_num(df)
|
||||
return df
|
||||
|
||||
def calc_paused_num(self, df: pd.DataFrame):
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
df = df.copy()
|
||||
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
self._date_field_name,
|
||||
self._symbol_field_name,
|
||||
}
|
||||
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
|
||||
all_nan_nums += 1
|
||||
not_nan_nums = 0
|
||||
_df["paused"] = 1
|
||||
if all_data:
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
else:
|
||||
all_nan_nums = 0
|
||||
not_nan_nums += 1
|
||||
_df["paused_num"] = not_nan_nums
|
||||
all_data.append(_df)
|
||||
all_data = all_data[: len(all_data) - all_nan_nums]
|
||||
if all_data:
|
||||
df = pd.concat(all_data, sort=False)
|
||||
else:
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
@@ -769,42 +577,45 @@ class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
|
||||
qlib.init(provider_uri=qlib_data_1d_dir)
|
||||
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
if calendar_list_1d is None:
|
||||
calendar_list_1d = self._get_1d_calendar_list()
|
||||
setattr(self, "_calendar_list_1d", calendar_list_1d)
|
||||
return calendar_list_1d
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
df = calc_adjusted_price(
|
||||
df=df,
|
||||
_date_field_name=self._date_field_name,
|
||||
_symbol_field_name=self._symbol_field_name,
|
||||
frequence="1min",
|
||||
consistent_1d=self.CONSISTENT_1d,
|
||||
calc_paused=self.CALC_PAUSED_NUM,
|
||||
_1d_data_all=self.all_1d_data,
|
||||
)
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
@abc.abstractmethod
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
raise NotImplementedError("rewrite symbol_to_yahoo")
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
@abc.abstractmethod
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
@@ -821,7 +632,7 @@ class YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -844,7 +655,7 @@ class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline):
|
||||
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -872,7 +683,7 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
@@ -899,7 +710,7 @@ class YahooNormalizeBR1d(YahooNormalizeBR, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1minOffline):
|
||||
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1min):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
@@ -1123,10 +934,10 @@ class Run(BaseRun):
|
||||
def update_data_to_bin(
|
||||
self,
|
||||
qlib_data_1d_dir: str,
|
||||
trading_date: str = None,
|
||||
end_date: str = None,
|
||||
check_data_length: int = None,
|
||||
delay: float = 1,
|
||||
exists_skip: bool = False,
|
||||
):
|
||||
"""update yahoo data to bin
|
||||
|
||||
@@ -1135,14 +946,14 @@ class Run(BaseRun):
|
||||
qlib_data_1d_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
|
||||
trading_date: str
|
||||
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
|
||||
end_date: str
|
||||
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
|
||||
check_data_length: int
|
||||
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
|
||||
delay: float
|
||||
time.sleep(delay), default 1
|
||||
exists_skip: bool
|
||||
exists skip, by default False
|
||||
Notes
|
||||
-----
|
||||
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
|
||||
@@ -1150,24 +961,24 @@ class Run(BaseRun):
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
|
||||
# get 1m data
|
||||
"""
|
||||
|
||||
if self.interval.lower() != "1d":
|
||||
logger.warning(f"currently supports 1d data updates: --interval 1d")
|
||||
|
||||
# start/end date
|
||||
if trading_date is None:
|
||||
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
logger.warning(f"trading_date is None, use the current date: {trading_date}")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download qlib 1d data
|
||||
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
|
||||
if not exists_qlib_data(qlib_data_1d_dir):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
GetData().qlib_data(
|
||||
target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip
|
||||
)
|
||||
|
||||
# start/end date
|
||||
calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt"))
|
||||
trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
if end_date is None:
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download data from yahoo
|
||||
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
|
||||
|
||||
@@ -135,7 +135,7 @@ class DumpDataBase:
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype(np.datetime64)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
@@ -146,9 +146,7 @@ class DumpDataBase:
|
||||
return (
|
||||
self._include_fields
|
||||
if self._include_fields
|
||||
else set(df_columns) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else df_columns
|
||||
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -176,7 +174,7 @@ class DumpDataBase:
|
||||
def save_calendars(self, calendars_data: list):
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
|
||||
result_calendars_list = [self._format_datetime(x) for x in calendars_data]
|
||||
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
|
||||
|
||||
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
|
||||
@@ -195,7 +193,7 @@ class DumpDataBase:
|
||||
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
|
||||
# calendars
|
||||
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype("datetime64[ns]")
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
|
||||
@@ -3,24 +3,21 @@
|
||||
"""
|
||||
TODO:
|
||||
- A more well-designed PIT database is required.
|
||||
- seperated insert, delete, update, query operations are required.
|
||||
- separated insert, delete, update, query operations are required.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import struct
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from typing import Iterable
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
|
||||
from qlib.utils import fname_to_code, get_period_offset
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
@@ -135,9 +132,11 @@ class DumpPitData:
|
||||
return (
|
||||
set(self._include_fields)
|
||||
if self._include_fields
|
||||
else set(df[self.field_column_name]) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else set(df[self.field_column_name])
|
||||
else (
|
||||
set(df[self.field_column_name]) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else set(df[self.field_column_name])
|
||||
)
|
||||
)
|
||||
|
||||
def get_filenames(self, symbol, field, interval):
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
|
||||
|
||||
# Requirements
|
||||
|
||||
|
||||
Use following install command to complete the project.
|
||||
```
|
||||
pip install -e '.[finco]'
|
||||
```
|
||||
|
||||
|
||||
# TODOs
|
||||
|
||||
- [ ] Select the appropriate LLM API
|
||||
- Which API is more suitable for meeting our requirements - the original API or an alternative like LangChain?
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -x # show command
|
||||
set -e # Error on exception
|
||||
|
||||
DIR="$(
|
||||
cd "$(dirname "$(readlink -f "$0")")" || exit
|
||||
pwd -P
|
||||
)"
|
||||
# --load the cridentials
|
||||
if [ -e $DIR/cridential.sh ]; then
|
||||
source $DIR/cridential.sh
|
||||
fi
|
||||
|
||||
# run the command
|
||||
python -m qlib.finco.cli "build an A-share stock market daily portfolio in quantitative investment and minimize the maximum drawdown."
|
||||
@@ -1,3 +0,0 @@
|
||||
export OPENAI_API_TYPE=azure # This only necessary for Azure OpenAI
|
||||
export OPENAI_API_KEY=
|
||||
export OPENAI_API_BASE=
|
||||
20
setup.py
20
setup.py
@@ -65,12 +65,18 @@ REQUIRED = [
|
||||
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
|
||||
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
|
||||
"mlflow>=1.12.1, <=1.30.0",
|
||||
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
|
||||
"packaging<22",
|
||||
"tqdm",
|
||||
"loguru",
|
||||
"lightgbm>=3.3.0",
|
||||
"tornado",
|
||||
"joblib>=0.17.0",
|
||||
"ruamel.yaml>=0.16.12",
|
||||
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
|
||||
# which would cause qlib.workflow.cli to not work properly,
|
||||
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
|
||||
# Refs: https://pypi.org/project/ruamel.yaml/
|
||||
"ruamel.yaml<=0.17.36",
|
||||
"pymongo==3.7.2", # For task management
|
||||
"scikit-learn>=0.22",
|
||||
"dill",
|
||||
@@ -140,7 +146,8 @@ setup(
|
||||
"wheel",
|
||||
"setuptools",
|
||||
"black",
|
||||
"pylint",
|
||||
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
|
||||
"pylint<=2.17.6",
|
||||
# Using the latest versions(0.981 and 0.982) of mypy,
|
||||
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
|
||||
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.
|
||||
@@ -174,15 +181,6 @@ setup(
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
],
|
||||
"finco": [
|
||||
# finco is not necessary for all Qlib users; So a single require section is used for it.
|
||||
"openai",
|
||||
"pydantic", # Please add it to basic requirements after the design of pydantic is state.
|
||||
"pydantic-settings",
|
||||
"python-dotenv", # I don't think this is necessary if we use pydantic.
|
||||
"fuzzywuzzy",
|
||||
"python-Levenshtein", # not necessary but accelerate fuzzywuzzy calculation
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import shutil
|
||||
import difflib
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from pathlib import Path
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
from qlib.finco.task import YamlEditTask
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
class FincoTpl(TestAutoData):
|
||||
def test_tpl_consistence(self):
|
||||
"""Motivation: make sure the configuable template is consistent with the default config"""
|
||||
tpl_p = get_tpl_path()
|
||||
with (tpl_p / "sl" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# init_data_handler
|
||||
hd: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
# NOTE: The config in workflow_config.yaml is generated by the following code:
|
||||
# dump in yaml format to file without auto linebreak
|
||||
# print(yaml.dump(hd.data_loader.fields, width=10000, stream=open("_tmp", "w")))
|
||||
|
||||
with (tpl_p / "sl-cfg" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
hd_ds: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
self.assertEqual(hd_ds.data_loader.fields, hd.data_loader.fields)
|
||||
|
||||
check = hd_ds.fetch().fillna(0.0) == hd.fetch().fillna(0.0)
|
||||
self.assertTrue(check.all().all())
|
||||
|
||||
def test_update_yaml(self):
|
||||
p = get_tpl_path() / "sl" / "workflow_config.yaml"
|
||||
p_new = DIRNAME / "_test_config.yaml"
|
||||
shutil.copy(p, p_new)
|
||||
updated_content = """
|
||||
class: LGBModelTest
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 1.8879
|
||||
learning_rate: 0.3
|
||||
subsample: 0.8790
|
||||
lambda_l1: 205.7000
|
||||
lambda_l2: 580.9769
|
||||
max_depth: 9
|
||||
num_leaves: 211
|
||||
num_threads: 21
|
||||
"""
|
||||
t = YamlEditTask(p_new, "task.model", updated_content)
|
||||
t.execute()
|
||||
# NOTE: the formmat is changed by ruamel.yaml, so it can't be compared by text directly..
|
||||
# print the diff between p and p_new with difflib
|
||||
# with p.open("r") as fp:
|
||||
# content = fp.read()
|
||||
# with p_new.open("r") as fp:
|
||||
# content_new = fp.read()
|
||||
# for line in difflib.unified_diff(content, content_new, fromfile="original", tofile="new", lineterm=""):
|
||||
# print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,66 +0,0 @@
|
||||
import unittest
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from dotenv import load_dotenv
|
||||
# pydantic support load_dotenv, so load_dotenv will be deprecated in the future.
|
||||
|
||||
from qlib.finco.task import SummarizeTask
|
||||
from qlib.finco.workflow import WorkflowContextManager
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
|
||||
class TestSummarize(unittest.TestCase):
|
||||
|
||||
def test_chat(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your are a professional financial assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "How to write a perfect quant strategy.",
|
||||
},
|
||||
]
|
||||
response = APIBackend().try_create_chat_completion(messages=messages)
|
||||
print(response)
|
||||
|
||||
def test_execution(self):
|
||||
task = SummarizeTask()
|
||||
context = WorkflowContextManager()
|
||||
context.set_context("workspace", "../../examples/benchmarks/Linear")
|
||||
context.set_context("user_prompt", "My main focus is on the performance of the strategy's return."
|
||||
"Please summarize the information and give me some advice.")
|
||||
task.assign_context_manager(context)
|
||||
resp = task.execute()
|
||||
print(resp)
|
||||
|
||||
def test_generate_batch_result(self):
|
||||
wm = WorkflowManager()
|
||||
|
||||
prompt = wm.default_user_prompt
|
||||
# prompt = ""
|
||||
|
||||
workdir = os.path.dirname(wm.get_context().get_context("workspace"))
|
||||
summaries_path = os.path.join(workdir, "summaries")
|
||||
|
||||
if not os.path.exists(summaries_path):
|
||||
os.makedirs(summaries_path)
|
||||
|
||||
for i in range(10):
|
||||
wm.run(prompt)
|
||||
if os.path.exists(f"{workdir}/finCoReport.md"):
|
||||
shutil.move(f"{workdir}/finCoReport.md", f"{workdir}/summaries/finCoReport{i}.md")
|
||||
|
||||
def test_parse2txt(self):
|
||||
task = SummarizeTask()
|
||||
resp = task.get_info_from_file("")
|
||||
print(resp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,23 +0,0 @@
|
||||
import unittest
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
|
||||
class SingletonTest(unittest.TestCase):
|
||||
|
||||
def test_singleton(self):
|
||||
# self.assertEqual(self.to_str(data.tail()), self.to_str(res))
|
||||
closure_checker = []
|
||||
|
||||
class A(SingletonBaseClass):
|
||||
|
||||
def __init__(self) -> None:
|
||||
closure_checker.append(0)
|
||||
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -13,7 +13,9 @@ from pathlib import Path
|
||||
|
||||
from qlib.data import D
|
||||
from qlib.tests.data import GetData
|
||||
from scripts.dump_pit import DumpPitData
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from dump_pit import DumpPitData
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
|
||||
from collector import Run
|
||||
|
||||
@@ -9,7 +9,9 @@ from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
class WorkflowTest(TestAutoData):
|
||||
TMP_PATH = Path("./.mlruns_tmp/")
|
||||
# Creating the directory manually doesn't work with mlflow,
|
||||
# so we add a subfolder named .trash when we create the directory.
|
||||
TMP_PATH = Path("./.mlruns_tmp/.trash")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if self.TMP_PATH.exists():
|
||||
@@ -17,6 +19,8 @@ class WorkflowTest(TestAutoData):
|
||||
|
||||
def test_get_local_dir(self):
|
||||
""" """
|
||||
self.TMP_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with R.start(uri=str(self.TMP_PATH)):
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user