1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 14:01:28 +08:00

Compare commits

...

28 Commits

Author SHA1 Message Date
Linlang Lv (iSoftStone Information)
5fafba36f2 fix docs 2024-05-21 04:12:44 +08:00
Linlang
8a087d0db9 fix docs (#1721)
* fix docs

* modify file extension

* modify file extension

---------

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 19:19:45 +08:00
playfund
2ae4be426a Delete redundant copy() code to speed up (#1732)
Delete redundant copy() code to speed up

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 18:45:07 +08:00
fei long
6ed83f7c04 data_collector: cn_index: fix missing dependencies package in requirements.txt (#1770)
add yahooquery and openpyxl in requirements.txt

Signed-off-by: YuLong Yao <feilongphone@gmail.com>
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-17 18:43:12 +08:00
Ikko Eltociear Ashimine
917e3a725e Update dump_pit.py (#1759)
seperated -> separated

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-10 14:42:41 +08:00
Chuan Xu
b1e0e77c97 Fix the bug of reading string NA as NaN in the function exists_qlib_data. (#1736)
* Fix the bug of reading NA string as NaN in exists_qlib_data.

* Fix the .gitignore file.

* Update the fix and add some comments.

* format with black

---------

Co-authored-by: Chuan Xu <chuan.xu@sas.com>
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-10 13:09:39 +08:00
Linlang
ea245f5435 Fix issue 1729 (#1776)
* fix issue 1729

* fix issue 1729

* fix issue 1729

---------

Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-10 11:04:59 +08:00
Linlang
3779b5186a bump version (#1784)
Co-authored-by: Linlang Lv (iSoftStone Information) <v-lvlinlang@microsoft.com>
2024-05-08 13:50:55 +08:00
Young
194284b1ac Update version 2024-05-07 14:15:35 +08:00
Xisen Wang
1bb8f2fa23 Enhance README with LightGBM Installation Guidance for Mac M1 Users (#1766)
* Update README.md

* Update README.md

* Update README.md
2024-03-20 20:48:52 +08:00
Linlang
39f88daaa7 download orderbook data (#1754)
* download orderbook data

* fix CI error

* fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* test fix CI error

* optimize get_data code

* optimize get_data code

* optimize get_data code

* optimize README

---------

Co-authored-by: Linlang <v-linlanglv@microsoft.com>
2024-03-07 14:41:21 +08:00
Linlang
98f569eed2 add_baostock_collector (#1641)
* add_baostock_collector

* modify_comments

* fix_pylint_error

* solve_duplication_methods

* modified the logic of update_data_to_bin

* modified the logic of update_data_to_bin

* optimize code

* optimize pylint issue

* fix pylint error

* changes suggested by the review

* fix CI faild

* fix CI faild

* fix issue 1121

* format with black

* optimize code logic

* optimize code logic

* fix error code

* drop warning during code runs

* optimize code

* format with black

* fix bug

* format with black

* optimize code

* optimize code

* add comments
2023-11-21 20:31:47 +08:00
JJ
ceff886f49 Update data.rst (#1679)
Fixed a couple of small spelling errors.
2023-11-16 18:11:29 +08:00
Ikko Eltociear Ashimine
15b64768e2 Update README.md (#1637)
an -> a
2023-11-15 17:03:26 +08:00
Andy li
8bf2678676 fix the warning (#1656) 2023-11-03 17:03:11 +08:00
JJ
fb80e318e2 Update quick.rst (#1667)
Fixed small spelling error.
2023-10-20 17:23:34 +08:00
zhuan
ecbeeafdc1 Update requirements.txt (#1521) 2023-09-15 17:18:04 +08:00
Fivele-Li
69e28ceab8 suppress the SettingWithCopyWarning of pandas (#1513)
* df value is set as expected, suppress the warning;

* depress warning with pandas option_context

---------

Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-09-01 18:12:49 +08:00
Fivele-Li
4c30e5827b Troubleshooting pip version issues in CI (#1504)
* CI failed to run on 23.1 and 23.1.1

* add pyproject.toml

* upgrade pip in slow.yml

* upgrade build-system requires

* troubleshooting pytest problem

* troubleshooting pytest problem

* troubleshooting pytest problem

* troubleshooting pytest problem

* add qlib root path to python sys.path

* add qlib root path to $PYTHONPATH

* add qlib root path to $PYTHONPATH

* add qlib root path to $PYTHONPATH

* modify pytest root;

* remove set env

* change_pytest_command_CI

* change_pytest_command_CI

* fix_ci

* fix_ci

* fix_ci

* fix_ci

* fix_ci

* fix_ci

* fix_ci

* remove_toml

* recover_toml

---------

Co-authored-by: lijinhui <362237642@qq.com>
Co-authored-by: linlang <Lv.Linlang@hotmail.com>
2023-08-24 21:24:50 +08:00
Di
5387ea5c1f Add exploration noise to rl training collector (#1481)
* Update vessel.py

Add exploration_noise=True  to training collector

* Update vessel.py

Reformat
2023-08-18 17:41:02 +08:00
Di
05d67b3828 Add multi pass portfolio analysis record (#1546)
* Add multi pass port ana record

* Add list function

* Add documentation and support <MODEL> tag

* Add drop in replacement example

* reformat

* Change according to comments

* update format

* Update record_temp.py

Fix type hint

* Update record_temp.py
2023-08-04 17:41:12 +08:00
Linlang
38edac5069 fix docs (#1618)
Co-authored-by: Linlang <v-linlanglv@microsoft.com>
2023-08-02 20:14:54 +08:00
Fivele-Li
b4b7a2fdd4 depress warning with pandas option_context (#1524)
Co-authored-by: Cadenza-Li <362237642@qq.com>
2023-08-01 19:02:04 +08:00
JJ
480f233e3f Update introduction.rst (#1578) 2023-07-26 16:42:53 +08:00
Gene
953621ac7e Update README.md (#1553) 2023-07-26 16:38:22 +08:00
JJ
87a026fef3 Update introduction.rst (#1579)
Fixed a spelling mistake. I changed deicsions to decisions.
2023-07-26 16:37:59 +08:00
Linlang
8676303077 fix_ci (#1608)
Co-authored-by: Linlang <v-linlanglv@microsoft.com>
2023-07-19 17:33:47 +08:00
you-n-g
1a32ba1806 Bump Version & Fix CI (#1606)
* Bump Version & Fix CI

* Update test_qlib_from_pip.yml
2023-07-18 20:54:15 +08:00
51 changed files with 1118 additions and 427 deletions

View File

@@ -51,8 +51,8 @@ jobs:
python setup.py bdist_wheel
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
twine upload dist/*
@@ -72,10 +72,10 @@ jobs:
python-version: 3.7
- name: Install dependencies
run: |
pip install twine
pip install twine
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
twine upload dist/pyqlib-*-manylinux*.whl

View File

@@ -6,8 +6,14 @@ on:
branches:
- main
permissions:
contents: read
jobs:
update_release_draft:
permissions:
contents: write
pull-requests: read
runs-on: ubuntu-latest
steps:
# Drafts your next Release notes as Pull Requests are merged into "master"

View File

@@ -8,13 +8,15 @@ on:
jobs:
build:
if: ${{ false }} # FIXME: temporarily disable... Due to we are rushing a feature
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -44,9 +46,6 @@ jobs:
- name: Qlib installation test
run: |
python -m pip install pyqlib
# Specify the numpy version because the numpy upgrade caused the CI test to fail,
# and this line of code will be removed when the next version of qlib is released.
python -m pip install "numpy<1.23"
- name: Install Lightgbm for MacOS
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}

View File

@@ -14,7 +14,10 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -38,10 +41,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip to the latest version
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0
run: |
python -m pip install pip==23.0
python -m pip install --upgrade pip
- name: Installing pytorch for macos
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
@@ -104,6 +105,7 @@ jobs:
- name: Check Qlib with pylint
run: |
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)"
# The following flake8 error codes were ignored:
# E501 line too long

View File

@@ -14,7 +14,10 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-latest]
# Since macos-latest changed from 12.7.4 to 14.4.1,
# the minimum python version that matches a 14.4.1 version of macos is 3.10,
# so we limit the macos version to macos-12.
os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-11, macos-12]
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
python-version: [3.7, 3.8]
@@ -38,10 +41,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Set up Python tools
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
# The pip version has been temporarily fixed to 23.0
run: |
python -m pip install pip==23.0
python -m pip install --upgrade pip
pip install --upgrade cython numpy
pip install -e .[dev]

2
.gitignore vendored
View File

@@ -48,4 +48,4 @@ tags
*.swp
./pretrain
.idea/
.idea/

View File

@@ -5,6 +5,12 @@
# Required
version: 2
# Set the version of Python and other tools you might need
build:
os: ubuntu-22.04
tools:
python: "3.7"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
@@ -14,7 +20,6 @@ formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
install:
- requirements: docs/requirements.txt
- method: pip

View File

@@ -139,7 +139,7 @@ This table demonstrates the supported Python version of `Qlib`:
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
**Note**:
1. **Conda** is suggested for managing your Python environment.
1. **Conda** is suggested for managing your Python environment. In some cases, using Python outside of a `conda` environment may result in missing header files, causing the installation failure of certain packages.
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
1. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
1. `Qlib`Requires `tables` package, `hdf5` in tables does not support python3.9.
@@ -172,6 +172,8 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
**Tips**: If you fail to install `Qlib` or run the examples in your environment, comparing your steps and the [CI workflow](.github/workflows/test_qlib_from_source.yml) may help you find the problem.
**Tips for Mac**: If you are using Mac with M1, you might encounter issues in building the wheel for LightGBM, which is due to missing dependencies from OpenMP. To solve the problem, install openmp first with ``brew install libomp`` and then run ``pip install .`` to build it successfully.
## Data Preparation
Load and prepare data by running the following code:
@@ -321,7 +323,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# Main Challenges & Solutions in Quant Research
Quant investment is an very unique scenario with lots of key challenges to be solved.
Quant investment is a very unique scenario with lots of key challenges to be solved.
Currently, Qlib provides some solutions for several of them.
## Forecasting: Finding Valuable Signals/Patterns
@@ -360,7 +362,7 @@ Here is a list of models built on `Qlib`.
Your PR of new Quant models is highly welcomed.
The performance of each model on the `Alpha158` and `Alpha360` dataset can be found [here](examples/benchmarks/README.md).
The performance of each model on the `Alpha158` and `Alpha360` datasets can be found [here](examples/benchmarks/README.md).
### Run a single model
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.

View File

@@ -52,7 +52,7 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
Qlib Format Dataset
-------------------
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows. User can also use numpy to load `.bin` file to validate data.
The price volume data look different from the actual dealling price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
The price volume data look different from the actual dealing price because of they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). And then you may find that the adjusted price may be different from different data sources. This is because different data sources may vary in the way of adjusting prices. Qlib normalize the price on first trading day of each stock to 1 when adjusting them.
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
Here are some discussions about the price adjusting of Qlib.
@@ -140,12 +140,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
where the data are in the following format:
.. code-block::
+-----------+-------+
| symbol | close |
+===========+=======+
| SH600000 | 120 |
+-----------+-------+
symbol,close
SH600000,120
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
- CSV file **must** include a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
.. code-block:: bash
@@ -153,11 +154,13 @@ Users can also provide their own data in CSV format. However, the CSV data **mus
where the data are in the following format:
.. code-block::
symbol,date,close,open,volume
SH600000,2020-11-01,120,121,12300000
SH600000,2020-11-02,123,120,12300000
+---------+------------+-------+------+----------+
| symbol | date | close | open | volume |
+=========+============+=======+======+==========+
| SH600000| 2020-11-01 | 120 | 121 | 12300000 |
+---------+------------+-------+------+----------+
| SH600000| 2020-11-02 | 123 | 120 | 12300000 |
+---------+------------+-------+------+----------+
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.

View File

@@ -36,7 +36,7 @@ Name Description
the training process of models which enable algorithms controlling the
training process.
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are learnable. They are learned
`Learning Framework` layer The `Forecast Model` and `Trading Agent` are trainable. They are trained
based on the `Learning Framework` layer and then applied to multiple scenarios
in `Workflow` layer. The supported learning paradigms can be categorized into
reinforcement learning and supervised learning. The learning framework
@@ -51,7 +51,7 @@ Name Description
modules. With these signals `Decision Generator` will generate the target
trading decisions(i.e. portfolio, orders)
If RL-based Strategies are adopted, the `Policy` is learned in a end-to-end way,
the trading deicsions are generated directly.
the trading decisions are generated directly.
Decisions will be executed by `Execution Env`
(i.e. the trading market). There may be multiple levels of `Strategy`
and `Executor` (e.g. an *order executor trading strategy and intraday order executor*

View File

@@ -16,7 +16,7 @@ This ``Quick Start`` guide tries to demonstrate
Installation
============
Users can easily intsall ``Qlib`` according to the following steps:
Users can easily install ``Qlib`` according to the following steps:
- Before installing ``Qlib`` from source, users need to install some dependencies:

View File

@@ -5,3 +5,4 @@ scipy
scikit-learn
pandas
tianshou
sphinx_rtd_theme

View File

@@ -0,0 +1,78 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LinearModel
module_path: qlib.contrib.model.linear
kwargs:
estimator: ols
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: True
ann_scaler: 252
- class: MultiPassPortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -136,7 +136,7 @@ If you want to contribute your new models, you can follow the steps below.
- `README.md`: a brief introduction to your models
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))

View File

@@ -324,7 +324,6 @@ class TRAModel(Model):
class LSTM(nn.Module):
"""LSTM Model
Args:
@@ -414,7 +413,6 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -475,7 +473,6 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -27,13 +27,11 @@ pip install arctic # NOTE: pip may fail to resolve the right package dependency
2. Please follow following steps to download example data
```bash
cd examples/orderbook_data/
wget http://fintech.msra.cn/stock_data/downloads/highfreq_orderboook_example_data.tar.bz2
tar xf highfreq_orderboook_example_data.tar.bz2
python ../../scripts/get_data.py download_data --target_dir . --file_name highfreq_orderbook_example_data.zip
```
3. Please import the example data to your mongo db
```bash
cd examples/orderbook_data/
python create_dataset.py initialize_library # Initialization Libraries
python create_dataset.py import_data # Initialization Libraries
```
@@ -42,7 +40,6 @@ python create_dataset.py import_data # Initialization Libraries
After importing these data, you run `example.py` to create some high-frequency features.
```bash
cd examples/orderbook_data/
pytest -s --disable-warnings example.py # If you want run all examples
pytest -s --disable-warnings example.py::TestClass::test_exp_10 # If you want to run specific example
```

2
pyproject.toml Normal file
View File

@@ -0,0 +1,2 @@
[build-system]
requires = ["setuptools", "numpy", "Cython"]

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from pathlib import Path
__version__ = "0.9.3"
__version__ = "0.9.4.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
from typing import Union

View File

@@ -162,13 +162,15 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
benchmark_config=(
{}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
}
),
)

View File

@@ -622,9 +622,11 @@ class Indicator:
print(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
(
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
),
fulfill_rate,
price_advantage,
positive_rate,

View File

@@ -3,6 +3,7 @@ Here is a batch of evaluation functions.
The interface should be redesigned carefully in the future.
"""
import pandas as pd
from typing import Tuple
from qlib import get_module_logger

View File

@@ -511,7 +511,6 @@ class TRAModel(Model):
class RNN(nn.Module):
"""RNN Model
Args:
@@ -601,7 +600,6 @@ class PositionalEncoding(nn.Module):
class Transformer(nn.Module):
"""Transformer Model
Args:
@@ -649,7 +647,6 @@ class Transformer(nn.Module):
class TRA(nn.Module):
"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,

View File

@@ -373,7 +373,6 @@ class WeightStrategyBase(BaseSignalStrategy):
class EnhancedIndexingStrategy(WeightStrategyBase):
"""Enhanced Indexing Strategy
Enhanced indexing combines the arts of active management and passive management,

View File

@@ -536,7 +536,6 @@ class DatasetProvider(abc.ABC):
"""
if len(fields) == 0:
raise ValueError("fields cannot be empty")
fields = fields.copy()
column_names = [str(f) for f in fields]
return column_names

View File

@@ -318,9 +318,13 @@ class CSZScoreNorm(Processor):
# try not modify original dataframe
if not isinstance(self.fields_group, list):
self.fields_group = [self.fields_group]
for g in self.fields_group:
cols = get_group_columns(df, g)
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
# depress warning by references:
# https://stackoverflow.com/questions/20625582/how-to-deal-with-settingwithcopywarning-in-pandas
# https://pandas.pydata.org/pandas-docs/stable/user_guide/options.html#getting-and-setting-options
with pd.option_context("mode.chained_assignment", None):
for g in self.fields_group:
cols = get_group_columns(df, g)
df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func)
return df

View File

@@ -30,7 +30,6 @@ class Ensemble:
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
@@ -64,7 +63,6 @@ class SingleKeyEnsemble(Ensemble):
class RollingEnsemble(Ensemble):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".

View File

@@ -247,9 +247,7 @@ class ShrinkCovEstimator(RiskModel):
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
v3 = z.T.dot(z) / t - var_mkt * S
roff3 = (
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
)
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
roff = 2 * roff1 - roff3
rho = rdiag + roff

View File

@@ -168,7 +168,9 @@ class TrainingVessel(TrainingVesselBase):
self.policy.train()
with vector_env.collector_guard():
collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)))
collector = Collector(
self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True
)
# Number of episodes collected in each training iteration can be overridden by fast dev run.
if self.trainer.fast_dev_run is not None:

View File

@@ -25,7 +25,12 @@ import pandas as pd
from pathlib import Path
from typing import List, Union, Optional, Callable
from packaging import version
from .file import get_or_create_path, save_multiple_parts_file, unpack_archive_with_buffer, get_tmp_file_with_buffer
from .file import (
get_or_create_path,
save_multiple_parts_file,
unpack_archive_with_buffer,
get_tmp_file_with_buffer,
)
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -37,7 +42,12 @@ is_deprecated_lexsorted_pandas = version.parse(pd.__version__) > version.parse("
#################### Server ####################
def get_redis_connection():
"""get redis connection instance."""
return redis.StrictRedis(host=C.redis_host, port=C.redis_port, db=C.redis_task_db, password=C.redis_password)
return redis.StrictRedis(
host=C.redis_host,
port=C.redis_port,
db=C.redis_task_db,
password=C.redis_password,
)
#################### Data ####################
@@ -96,7 +106,14 @@ def get_period_offset(first_year, period, quarterly):
return offset
def read_period_data(index_path, data_path, period, cur_date_int: int, quarterly, last_period_index: int = None):
def read_period_data(
index_path,
data_path,
period,
cur_date_int: int,
quarterly,
last_period_index: int = None,
):
"""
At `cur_date`(e.g. 20190102), read the information at `period`(e.g. 201803).
Only the updating info before cur_date or at cur_date will be used.
@@ -273,7 +290,10 @@ def parse_field(field):
# \uff09 -> )
chinese_punctuation_regex = r"\u3001\uff1a\uff08\uff09"
for pattern, new in [
(rf"\$\$([\w{chinese_punctuation_regex}]+)", r'PFeature("\1")'), # $$ must be before $
(
rf"\$\$([\w{chinese_punctuation_regex}]+)",
r'PFeature("\1")',
), # $$ must be before $
(rf"\$([\w{chinese_punctuation_regex}]+)", r'Feature("\1")'),
(r"(\w+\s*)\(", r"Operators.\1("),
]: # Features # Operators
@@ -383,7 +403,14 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
def get_date_by_shift(
trading_date,
shift,
future=False,
clip_shift=True,
freq="day",
align: Optional[str] = None,
):
"""get trading date with shift bias will cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -569,7 +596,38 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
# Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0
miss_code = set(
pd.read_csv(
_instrument,
sep="\t",
header=None,
keep_default_na=False,
na_values={
0: [
" ",
"#N/A",
"#N/A N/A",
"#NA",
"-1.#IND",
"-1.#QNAN",
"-NaN",
"-nan",
"1.#IND",
"1.#QNAN",
"<NA>",
"N/A",
"NaN",
"None",
"n/a",
"nan",
"null ",
]
},
)
.loc[:, 0]
.apply(str.lower)
) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False

View File

@@ -90,7 +90,6 @@ class OnlineStrategy:
class RollingStrategy(OnlineStrategy):
"""
This example strategy always uses the latest rolling model sas online models.
"""

View File

@@ -4,8 +4,10 @@
import logging
import warnings
import pandas as pd
import numpy as np
from tqdm import trange
from pprint import pprint
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import risk_analysis, indicator_analysis
@@ -17,6 +19,7 @@ from ..log import get_module_logger
from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_shift
from ..utils.time import Freq
from ..utils.data import deepcopy_basic_type
from ..utils.exceptions import QlibException
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
@@ -230,9 +233,16 @@ class ACRecordTemp(RecordTemp):
except FileNotFoundError:
logger.warning("The dependent data does not exists. Generation skipped.")
return
return self._generate(*args, **kwargs)
artifact_dict = self._generate(*args, **kwargs)
if isinstance(artifact_dict, dict):
self.save(**artifact_dict)
return artifact_dict
def _generate(self, *args, **kwargs):
def _generate(self, *args, **kwargs) -> Dict[str, object]:
"""
Run the concrete generating task, return the dictionary of the generated results.
The caller method will save the results to the recorder.
"""
raise NotImplementedError(f"Please implement the `_generate` method")
@@ -336,8 +346,8 @@ class SigAnaRecord(ACRecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.save(**objects)
pprint(metrics)
return objects
def list(self):
paths = ["ic.pkl", "ric.pkl"]
@@ -468,17 +478,18 @@ class PortAnaRecord(ACRecordTemp):
if self.backtest_config["end_time"] is None:
self.backtest_config["end_time"] = get_date_by_shift(dt_values.max(), 1)
artifact_objects = {}
# custom strategy and get backtest
portfolio_metric_dict, indicator_dict = normal_backtest(
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal})
artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
self.save(**{f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]})
artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -500,7 +511,7 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -525,12 +536,13 @@ class PortAnaRecord(ACRecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
pprint(analysis_df)
return artifact_objects
def list(self):
list_path = []
@@ -553,3 +565,124 @@ class PortAnaRecord(ACRecordTemp):
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path
class MultiPassPortAnaRecord(PortAnaRecord):
"""
This is the Multiple Pass Portfolio Analysis Record class that run backtest multiple times and generates the analysis results such as those of backtest. This class inherits the ``PortAnaRecord`` class.
If shuffle_init_score enabled, the prediction score of the first backtest date will be shuffled, so that initial position will be random.
The shuffle_init_score will only works when the signal is used as <PRED> placeholder. The placeholder will be replaced by pred.pkl saved in recorder.
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
depend_cls = SignalRecord
def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs):
"""
Parameters
----------
recorder : Recorder
The recorder used to save the backtest results.
pass_num : int
The number of backtest passes.
shuffle_init_score : bool
Whether to shuffle the prediction score of the first backtest date.
"""
self.pass_num = pass_num
self.shuffle_init_score = shuffle_init_score
super().__init__(recorder, **kwargs)
# Save original strategy so that pred df can be replaced in next generate
self.original_strategy = deepcopy_basic_type(self.strategy_config)
if not isinstance(self.original_strategy, dict):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict")
if "signal" not in self.original_strategy.get("kwargs", {}):
raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter")
def random_init(self):
pred_df = self.load("pred.pkl")
all_pred_dates = pred_df.index.get_level_values("datetime")
bt_start_date = pd.to_datetime(self.backtest_config.get("start_time"))
if bt_start_date is None:
first_bt_pred_date = all_pred_dates.min()
else:
first_bt_pred_date = all_pred_dates[all_pred_dates >= bt_start_date].min()
# Shuffle the first backtest date's pred score
first_date_score = pred_df.loc[first_bt_pred_date]["score"]
np.random.shuffle(first_date_score.values)
# Use shuffled signal as the strategy signal
self.strategy_config = deepcopy_basic_type(self.original_strategy)
self.strategy_config["kwargs"]["signal"] = pred_df
def _generate(self, **kwargs):
risk_analysis_df_map = {}
# Collect each frequency's analysis df as df list
for i in trange(self.pass_num):
if self.shuffle_init_score:
self.random_init()
# Not check for cache file list
single_run_artifacts = super()._generate(**kwargs)
for _analysis_freq in self.risk_analysis_freq:
risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, [])
risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list
analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"]
analysis_df["run_id"] = i
risk_analysis_df_list.append(analysis_df)
result_artifacts = {}
# Concat df list
for _analysis_freq in self.risk_analysis_freq:
combined_df = pd.concat(risk_analysis_df_map[_analysis_freq])
# Calculate return and information ratio's mean, std and mean/std
multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1]).apply(
lambda x: pd.Series(
{"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()}
)
)
# Only look at "annualized_return" and "information_ratio"
multi_pass_port_analysis_df = multi_pass_port_analysis_df.loc[
(slice(None), ["annualized_return", "information_ratio"]), :
]
pprint(multi_pass_port_analysis_df)
# Save new df
result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df})
# Log metrics
metrics = flatten_dict(
{
"mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(),
"std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(),
"mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(),
}
)
self.recorder.log_metrics(**metrics)
return result_artifacts
def list(self):
list_path = []
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(f"multi_pass_port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
return list_path

View File

@@ -0,0 +1,81 @@
## Collector Data
### Get Qlib data(`bin file`)
- get data: `python scripts/get_data.py qlib_data`
- parameters:
- `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*
- `version`: dataset version, value from [`v2`], by default `v2`
- `v2` end date is *2022-12*
- `interval`: `5min`
- `region`: `hs300`
- `delete_old`: delete existing data from `target_dir`(*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
- `exists_skip`: traget_dir data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
- examples:
```bash
# hs300 5min
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min
```
### Collector *Baostock high frequency* data to qlib
> collector *Baostock high frequency* data and *dump* into `qlib` format.
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`
This will download the raw data such as date, symbol, open, high, low, close, volume, amount, adjustflag from baostock to a local directory. One file per symbol.
- parameters:
- `source_dir`: save the directory
- `interval`: `5min`
- `region`: `HS300`
- `start`: start datetime, by default *None*
- `end`: end datetime, by default *None*
- examples:
```bash
# cn 5min data
python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
```
2. normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`
This will:
1. Normalize high, low, close, open price using adjclose.
2. Normalize the high, low, close, open price so that the first valid trading date's close price is 1.
- parameters:
- `source_dir`: csv directory
- `normalize_dir`: result directory
- `interval`: `5min`
> if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`
- `region`: `HS300`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None`
- `qlib_data_1d_dir`: qlib directory(1d data)
if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
```
# qlib_data_1d can be obtained like this:
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
```
- examples:
```bash
# normalize 5min cn
python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
```
3. dump data: `python scripts/dump_bin.py dump_all`
This will convert the normalized csv in `feature` directory as numpy array and store the normalized data one file per column and one symbol per directory.
- parameters:
- `csv_path`: stock data path or directory, **normalize result(normalize_dir)**
- `qlib_dir`: qlib(dump) data director
- `freq`: transaction frequency, by default `day`
> `freq_map = {1d:day, 5mih: 5min}`
- `max_workers`: number of threads, by default *16*
- `include_fields`: dump fields, by default `""`
- `exclude_fields`: fields not dumped, by default `"""
> dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) exclude_fields else symbol_df.columns`
- `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
- `date_field_name`: column *name* identifying time in csv files, by default `date`
- examples:
```bash
# dump 5min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
```

View File

@@ -0,0 +1,328 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
import copy
import fire
import numpy as np
import pandas as pd
import baostock as bs
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from typing import Iterable, List
import qlib
from qlib.data import D
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price
class BaostockCollectorHS3005min(BaseCollector):
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="5min",
max_workers=4,
max_collector_count=2,
delay=0,
check_data_length: int = None,
limit_nums: int = None,
):
"""
Parameters
----------
save_dir: str
stock save dir
max_workers: int
workers, default 4
max_collector_count: int
default 2
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [5min], default 5min
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: int
check data length, by default None
limit_nums: int
using for debug, by default None
"""
bs.login()
super(BaostockCollectorHS3005min, self).__init__(
save_dir=save_dir,
start=start,
end=end,
interval=interval,
max_workers=max_workers,
max_collector_count=max_collector_count,
delay=delay,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
def get_trade_calendar(self):
_format = "%Y-%m-%d"
start = self.start_datetime.strftime(_format)
end = self.end_datetime.strftime(_format)
rs = bs.query_trade_dates(start_date=start, end_date=end)
calendar_list = []
while (rs.error_code == "0") & rs.next():
calendar_list.append(rs.get_row_data())
calendar_df = pd.DataFrame(calendar_list, columns=rs.fields)
trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])]
return trade_calendar_df["calendar_date"].values
@staticmethod
def process_interval(interval: str):
if interval == "1d":
return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"}
if interval == "5min":
return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"}
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
df = self.get_data_from_remote(
symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime
)
df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"]
df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f")
df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S")
df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5))
df.drop(["time"], axis=1, inplace=True)
df["symbol"] = df["symbol"].map(lambda x: str(x).replace(".", "").upper())
return df
@staticmethod
def get_data_from_remote(
symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
df = pd.DataFrame()
rs = bs.query_history_k_data_plus(
symbol,
BaostockCollectorHS3005min.process_interval(interval=interval)["fields"],
start_date=str(start_datetime.strftime("%Y-%m-%d")),
end_date=str(end_datetime.strftime("%Y-%m-%d")),
frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"],
adjustflag="3",
)
if rs.error_code == "0" and len(rs.data) > 0:
data_list = rs.data
columns = rs.fields
df = pd.DataFrame(data_list, columns=columns)
return df
def get_hs300_symbols(self) -> List[str]:
hs300_stocks = []
trade_calendar = self.get_trade_calendar()
with tqdm(total=len(trade_calendar)) as p_bar:
for date in trade_calendar:
rs = bs.query_hs300_stocks(date=date)
while rs.error_code == "0" and rs.next():
hs300_stocks.append(rs.get_row_data())
p_bar.update()
return sorted({e[1] for e in hs300_stocks})
def get_instrument_list(self):
logger.info("get HS stock symbols......")
symbols = self.get_hs300_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def normalize_symbol(self, symbol: str):
return str(symbol).replace(".", "").upper()
class BaostockNormalizeHS3005min(BaseNormalize):
COLUMNS = ["open", "close", "high", "low", "volume"]
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
the qlib data to be updated for yahoo, usually from: Normalised to 5min using local 1d data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
bs.login()
qlib.init(provider_uri=qlib_data_1d_dir)
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name)
@staticmethod
def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series:
df = df.copy()
_tmp_series = df["close"].fillna(method="ffill")
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
_tmp_shift_series.iloc[0] = float(last_close)
change_series = _tmp_series / _tmp_shift_series - 1
return change_series
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return self.generate_5min_from_daily(self.calendar_list_1d)
@property
def calendar_list_1d(self):
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
if calendar_list_1d is None:
calendar_list_1d = self._get_1d_calendar_list()
setattr(self, "_calendar_list_1d", calendar_list_1d)
return calendar_list_1d
@staticmethod
def normalize_baostock(
df: pd.DataFrame,
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
last_close: float = None,
):
if df.empty:
return df
symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS)
df = df.copy()
df.set_index(date_field_name, inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
if calendar_list is not None:
df = df.reindex(
pd.DataFrame(index=calendar_list)
.loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)]
.index
)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close)
columns += ["change"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
df[symbol_field_name] = symbol
df.index.names = [date_field_name]
return df.reset_index()
def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index:
return generate_minutes_calendar_from_daily(
calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
df = calc_adjusted_price(
df=df,
_date_field_name=self._date_field_name,
_symbol_field_name=self._symbol_field_name,
frequence="5min",
_1d_data_all=self.all_1d_data,
)
return df
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
return list(D.calendar(freq="day"))
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
# adjusted price
df = self.adjusted_price(df)
return df
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"):
"""
Changed the default value of: scripts.data_collector.base.BaseRun.
"""
super().__init__(source_dir, normalize_dir, max_workers, interval)
self.region = region
@property
def collector_class_name(self):
return f"BaostockCollector{self.region.upper()}{self.interval}"
@property
def normalize_class_name(self):
return f"BaostockNormalize{self.region.upper()}{self.interval}"
@property
def default_base_dir(self) -> [Path, str]:
return CUR_DIR
def download_data(
self,
max_collector_count=2,
delay=0.5,
start=None,
end=None,
check_data_length=None,
limit_nums=None,
):
"""download data from Baostock
Notes
-----
check_data_length, example:
hs300 5min, a week: 4 * 60 * 5
Examples
---------
# get hs300 5min data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
"""
super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums)
def normalize_data(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
end_date: str = None,
qlib_data_1d_dir: str = None,
):
"""normalize data
Attention
---------
qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data;
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
or:
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
Examples
---------
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
"""
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
"If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
)
super(Run, self).normalize_data(
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
)
if __name__ == "__main__":
fire.Fire(Run)

View File

@@ -0,0 +1,13 @@
loguru
fire
requests
numpy
pandas
tqdm
lxml
yahooquery
joblib
beautifulsoup4
bs4
soupsieve
baostock

View File

@@ -8,7 +8,7 @@ import datetime
import importlib
from pathlib import Path
from typing import Type, Iterable
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor
import pandas as pd
from tqdm import tqdm
@@ -290,7 +290,7 @@ class Normalize:
# some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing.
# manually defines dtype and na_values of the symbol_field.
default_na = pd._libs.parsers.STR_NA_VALUES
default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101
symbol_na = default_na.copy()
symbol_na.remove("NA")
columns = pd.read_csv(file_path, nrows=0).columns

View File

@@ -3,7 +3,6 @@
from functools import partial
import sys
from pathlib import Path
import importlib
import datetime
import fire
@@ -98,7 +97,7 @@ class IBOVIndex(IndexBase):
now = datetime.datetime.now()
current_year = now.year
current_month = now.month
for year in [item for item in range(init_year, current_year)]:
for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721
for el in four_months_period:
self.years_4_month_periods.append(str(year) + "_" + el)
# For current year the logic must be a little different

View File

@@ -1,6 +1,6 @@
async-generator==1.10
attrs==21.4.0
certifi==2021.10.8
certifi==2022.12.7
cffi==1.15.0
charset-normalizer==2.0.12
cryptography==36.0.1
@@ -8,7 +8,7 @@ fire==0.4.0
h11==0.13.0
idna==3.3
loguru==0.6.0
lxml==4.8.0
lxml==4.9.1
multitasking==0.0.10
numpy==1.22.2
outcome==1.1.0

View File

@@ -4,7 +4,6 @@
import re
import abc
import sys
import datetime
from io import BytesIO
from typing import List, Iterable
from pathlib import Path
@@ -39,7 +38,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
if exclude_status is None:
exclude_status = []
method_func = getattr(requests, method)
_resp = method_func(url, headers=REQ_HEADERS)
_resp = method_func(url, headers=REQ_HEADERS, timeout=None)
_status = _resp.status_code
if _status not in exclude_status and _status != 200:
raise ValueError(f"response status: {_status}, url={url}")
@@ -397,14 +396,7 @@ class CSI500Index(CSIIndex):
today = pd.Timestamp.now()
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
ret_list = []
col = ["date", "symbol", "code_name"]
for date in tqdm(date_range, desc="Download CSI500"):
rs = bs.query_zz500_stocks(date=str(date))
zz500_stocks = []
while (rs.error_code == "0") & rs.next():
zz500_stocks.append(rs.get_row_data())
result = pd.DataFrame(zz500_stocks, columns=col)
result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper())
result = self.get_data_from_baostock(date)
ret_list.append(result[["date", "symbol"]])
bs.logout()

View File

@@ -5,3 +5,5 @@ pandas
lxml
loguru
tqdm
yahooquery
openpyxl

View File

@@ -5,7 +5,6 @@ from abc import ABC
from pathlib import Path
import fire
import requests
import pandas as pd
from loguru import logger
from dateutil.tz import tzlocal
@@ -31,15 +30,15 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
-------
crypto symbols in given exchanges list of coingecko
"""
global _CG_CRYPTO_SYMBOLS
global _CG_CRYPTO_SYMBOLS # pylint: disable=W0603
@deco_retry
def _get_coingecko():
try:
cg = CoinGeckoAPI()
resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd"))
except:
raise ValueError("request error")
except Exception as e:
raise ValueError("request error") from e
try:
_symbols = resp["id"].to_list()
except Exception as e:

View File

@@ -107,7 +107,7 @@ class FundCollector(BaseCollector):
url = INDEX_BENCH_URL.format(
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
)
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None)
if resp.status_code != 200:
raise ValueError("request error")
@@ -116,8 +116,8 @@ class FundCollector(BaseCollector):
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
SYType = data["Data"]["SYType"]
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
raise Exception("The fund contains 每*份收益")
if SYType in {"每万份收益", "每百份收益", "每百万份收益"}:
raise ValueError("The fund contains 每*份收益")
# TODO: should we sort the value by datetime?
_resp = pd.DataFrame(data["Data"]["LSJZList"])

View File

@@ -53,7 +53,7 @@ class CollectorFutureCalendar:
return datetime_d.strftime(self.calendar_format)
def write_calendar(self, calendar: Iterable):
calendars_list = list(map(lambda x: self._format_datetime(x), sorted(set(self.calendar_list + calendar))))
calendars_list = [self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar))]
np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8")
@abc.abstractmethod

View File

@@ -4,7 +4,6 @@
import abc
from functools import partial
import sys
import importlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List
@@ -113,7 +112,7 @@ class WIKIIndex(IndexBase):
return _calendar_list
def _request_new_companies(self) -> requests.Response:
resp = requests.get(self._target_url)
resp = requests.get(self._target_url, timeout=None)
if resp.status_code != 200:
raise ValueError(f"request error: {self._target_url}")
@@ -164,7 +163,7 @@ class NASDAQ100Index(WIKIIndex):
df = pd.read_pickle(cache_path)
else:
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
resp = requests.post(url)
resp = requests.post(url, timeout=None)
if resp.status_code != 200:
raise ValueError(f"request error: {url}")
df = pd.DataFrame(resp.json()["aaData"])

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
import copy
import importlib
import time
import bisect
@@ -68,7 +69,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
logger.info(f"get calendar list: {bench_code}......")
def _get_calendar(url):
_value_list = requests.get(url).json()["data"]["klines"]
_value_list = requests.get(url, timeout=None).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
calendar = _CALENDAR_MAP.get(bench_code, None)
@@ -85,12 +86,14 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
def _get_calendar(month):
_cal = []
try:
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json()
resp = requests.get(
SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None
).json()
for _r in resp["data"]:
if int(_r["jybz"]):
_cal.append(pd.Timestamp(_r["jyrq"]))
except Exception as e:
raise ValueError(f"{month}-->{e}")
raise ValueError(f"{month}-->{e}") from e
return _cal
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
@@ -109,7 +112,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
def return_date_list(date_field_name: str, file_path: Path):
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
return sorted(map(lambda x: pd.Timestamp(x), date_list))
return sorted([pd.Timestamp(x) for x in date_list])
def get_calendar_list_by_ratio(
@@ -155,7 +158,7 @@ def get_calendar_list_by_ratio(
if date_list:
all_oldest_list.append(date_list[0])
for date in date_list:
if date not in _dict_count_trade.keys():
if date not in _dict_count_trade:
_dict_count_trade[date] = 0
_dict_count_trade[date] += 1
@@ -163,7 +166,7 @@ def get_calendar_list_by_ratio(
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
for oldest_date in all_oldest_list:
for date in _dict_count_founding.keys():
@@ -171,9 +174,7 @@ def get_calendar_list_by_ratio(
_dict_count_founding[date] -= 1
calendar = [
date
for date in _dict_count_trade
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count)
]
return calendar
@@ -186,16 +187,16 @@ def get_hs_stock_symbols() -> list:
-------
stock symbols
"""
global _HS_SYMBOLS
global _HS_SYMBOLS # pylint: disable=W0603
def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None)
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101
)
)
time.sleep(3)
@@ -230,12 +231,12 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
-------
stock symbols
"""
global _US_SYMBOLS
global _US_SYMBOLS # pylint: disable=W0603
@deco_retry
def _get_eastmoney():
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
resp = requests.get(url)
resp = requests.get(url, timeout=None)
if resp.status_code != 200:
raise ValueError("request error")
@@ -277,7 +278,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
"maxResultsPerPage": 10000,
"filterToken": "",
}
resp = requests.post(url, json=_parms)
resp = requests.post(url, json=_parms, timeout=None)
if resp.status_code != 200:
raise ValueError("request error")
@@ -317,7 +318,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
-------
stock symbols
"""
global _IN_SYMBOLS
global _IN_SYMBOLS # pylint: disable=W0603
@deco_retry
def _get_nifty():
@@ -358,7 +359,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
-------
B3 stock symbols
"""
global _BR_SYMBOLS
global _BR_SYMBOLS # pylint: disable=W0603
@deco_retry
def _get_ibovespa():
@@ -367,7 +368,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
# Request
agent = {"User-Agent": "Mozilla/5.0"}
page = requests.get(url, headers=agent)
page = requests.get(url, headers=agent, timeout=None)
# BeautifulSoup
soup = BeautifulSoup(page.content, "html.parser")
@@ -375,7 +376,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
children = tbody.findChildren("a", recursive=True)
for child in children:
_symbols.append(str(child).split('"')[-1].split(">")[1].split("<")[0])
_symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0])
return _symbols
@@ -409,12 +410,12 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
-------
fund symbols in China
"""
global _EN_FUND_SYMBOLS
global _EN_FUND_SYMBOLS # pylint: disable=W0603
@deco_retry
def _get_eastmoney():
url = "http://fund.eastmoney.com/js/fundcode_search.js"
resp = requests.get(url)
resp = requests.get(url, timeout=None)
if resp.status_code != 200:
raise ValueError("request error")
try:
@@ -605,5 +606,177 @@ def get_instruments(
getattr(obj, method)()
def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame):
df = copy.deepcopy(_1d_data_all)
df.reset_index(inplace=True)
df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
def get_1d_data(
_date_field_name: str,
_symbol_field_name: str,
symbol: str,
start: str,
end: str,
_1d_data_all: pd.DataFrame,
) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"]
"""
_all_1d_data = _get_all_1d_data(_date_field_name, _symbol_field_name, _1d_data_all)
return _all_1d_data[
(_all_1d_data[_symbol_field_name] == symbol.upper())
& (_all_1d_data[_date_field_name] >= pd.Timestamp(start))
& (_all_1d_data[_date_field_name] < pd.Timestamp(end))
]
def calc_adjusted_price(
df: pd.DataFrame,
_1d_data_all: pd.DataFrame,
_date_field_name: str,
_symbol_field_name: str,
frequence: str,
consistent_1d: bool = True,
calc_paused: bool = True,
) -> pd.DataFrame:
"""calc adjusted price
This method does 4 things.
1. Adds the `paused` field.
- The added paused field comes from the paused field of the 1d data.
2. Aligns the time of the 1d data.
3. The data is reweighted.
- The reweighting method:
- volume / factor
- open * factor
- high * factor
- low * factor
- close * factor
4. Called `calc_paused_num` method to add the `paused_num` field.
- The `paused_num` is the number of consecutive days of trading suspension.
"""
# TODO: using daily data factor
if df.empty:
return df
df = df.copy()
df.drop_duplicates(subset=_date_field_name, inplace=True)
df.sort_values(_date_field_name, inplace=True)
symbol = df.iloc[0][_symbol_field_name]
df[_date_field_name] = pd.to_datetime(df[_date_field_name])
# get 1d data from qlib
_start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d")
_end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
# TODO: np.nan or 1 or 0
df["paused"] = np.nan
else:
# NOTE: volume is np.nan or volume <= 0, paused = 1
# FIXME: find a more accurate data source
data_1d["paused"] = 0
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
data_1d = data_1d.set_index(_date_field_name)
# add factor from 1d data
# NOTE: 1d data info:
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
def _calc_factor(df_1d: pd.DataFrame):
try:
_date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date())
df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
df_1d["paused"] = data_1d.loc[_date]["paused"]
except Exception:
df_1d["factor"] = np.nan
df_1d["paused"] = np.nan
return df_1d
df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor)
if consistent_1d:
# the date sequence is consistent with 1d
df.set_index(_date_field_name, inplace=True)
df = df.reindex(
generate_minutes_calendar_from_daily(
calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()),
freq=frequence,
am_range=("09:30:00", "11:29:00"),
pm_range=("13:00:00", "14:59:00"),
)
)
df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name]
df.index.names = [_date_field_name]
df.reset_index(inplace=True)
for _col in ["open", "close", "high", "low", "volume"]:
if _col not in df.columns:
continue
if _col == "volume":
df[_col] = df[_col] / df["factor"]
else:
df[_col] = df[_col] * df["factor"]
if calc_paused:
df = calc_paused_num(df, _date_field_name, _symbol_field_name)
return df
def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
"""calc paused num
This method adds the paused_num field
- The `paused_num` is the number of consecutive days of trading suspension.
"""
_symbol = df.iloc[0][_symbol_field_name]
df = df.copy()
df["_tmp_date"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
all_nan_nums = 0
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
not_nan_nums = 0
for _date, _df in df.groupby("_tmp_date"):
_df["paused"] = 0
if not _df.loc[_df["volume"] < 0].empty:
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
_df.loc[_df["volume"] < 0, "volume"] = np.nan
check_fields = set(_df.columns) - {
"_tmp_date",
"paused",
"factor",
_date_field_name,
_symbol_field_name,
}
if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all():
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1
if all_data:
_df["paused_num"] = not_nan_nums
all_data.append(_df)
else:
all_nan_nums = 0
not_nan_nums += 1
_df["paused_num"] = not_nan_nums
all_data.append(_df)
all_data = all_data[: len(all_data) - all_nan_nums]
if all_data:
df = pd.concat(all_data, sort=False)
else:
logger.warning(f"data is empty: {_symbol}")
df = pd.DataFrame()
return df
del df["_tmp_date"]
return df
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -121,7 +121,7 @@ pip install -r requirements.txt
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --end_date <end_date>
or:
download 1d data from YahooFinance
@@ -180,9 +180,8 @@ pip install -r requirements.txt
* Manual update of data
```
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>
```
* `trading_date`: start of trading day
* `end_date`: end of trading day(not included)
* `check_data_length`: check the number of rows per *symbol*, by default `None`
> if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
@@ -191,10 +190,10 @@ pip install -r requirements.txt
* `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
* `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
* `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
* `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
* `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
* `region`: region, value from ["CN", "US"], default "CN"
* `interval`: interval, default "1d"(Currently only supports 1d data)
* `exists_skip`: exists skip, by default False
## Using qlib data

View File

@@ -2,7 +2,6 @@
# Licensed under the MIT License.
import abc
from re import I
import sys
import copy
import time
@@ -21,6 +20,8 @@ from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
import qlib
from qlib.data import D
from qlib.tests.data import GetData
from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data
from qlib.constant import REG_CN as REGION_CN
@@ -38,6 +39,7 @@ from data_collector.utils import (
get_in_stock_symbols,
get_br_stock_symbols,
generate_minutes_calendar_from_daily,
calc_adjusted_price,
)
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
@@ -229,9 +231,9 @@ class YahooCollectorCN1d(YahooCollectorCN):
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
"data"
]["klines"],
requests.get(
INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None
).json()["data"]["klines"],
)
)
except Exception as e:
@@ -316,7 +318,7 @@ class YahooCollectorIN1min(YahooCollectorIN):
class YahooCollectorBR(YahooCollector, ABC):
def retry(cls):
def retry(cls): # pylint: disable=E0213
"""
The reason to use retry=2 is due to the fact that
Yahoo Finance unfortunately does not keep track of some
@@ -356,12 +358,10 @@ class YahooCollectorBR(YahooCollector, ABC):
class YahooCollectorBR1d(YahooCollectorBR):
retry = 2
pass
class YahooCollectorBR1min(YahooCollectorBR):
retry = 2
pass
class YahooNormalize(BaseNormalize):
@@ -393,6 +393,7 @@ class YahooNormalize(BaseNormalize):
df = df.copy()
df.set_index(date_field_name, inplace=True)
df.index = pd.to_datetime(df.index)
df.index = df.index.tz_localize(None)
df = df[~df.index.duplicated(keep="first")]
if calendar_list is not None:
df = df.reindex(
@@ -522,78 +523,39 @@ class YahooNormalize1dExtend(YahooNormalize1d):
symbol field name, default is symbol
"""
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
self._first_close_field = "first_close"
self._ori_close_field = "ori_close"
self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"]
self.old_qlib_data = self._get_old_data(old_qlib_data_dir)
def _get_old_data(self, qlib_data_dir: [str, Path]):
import qlib
from qlib.data import D
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"])
df.columns = [self._ori_close_field, self._first_close_field]
df = D.features(D.instruments("all"), ["$" + col for col in self.column_list])
df.columns = self.column_list
return df
def _get_close(self, df: pd.DataFrame, field_name: str):
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_close = _df.loc[_df.last_valid_index()][field_name]
return _close
def _get_first_close(self, df: pd.DataFrame) -> float:
try:
_close = self._get_close(df, field_name=self._first_close_field)
except KeyError:
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
return _close
def _get_last_close(self, df: pd.DataFrame) -> float:
try:
_close = self._get_close(df, field_name=self._ori_close_field)
except KeyError:
_close = None
return _close
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
try:
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_date = _df.index.max()
except KeyError:
_date = None
return _date
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
_last_close = self._get_last_close(df)
# reindex
_last_date = self._get_last_date(df)
if _last_date is not None:
df = df.set_index(self._date_field_name)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
_max_date = df.index.max()
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
df = df[df[self._date_field_name] > _last_date]
if df.empty:
return pd.DataFrame()
_si = df["close"].first_valid_index()
if _si > df.index[0]:
logger.warning(
f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}"
)
# normalize
df = self.normalize_yahoo(
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
)
# adjusted price
df = self.adjusted_price(df)
df = self._manual_adj_data(df)
return df
df = super(YahooNormalize1dExtend, self).normalize(df)
df.set_index(self._date_field_name, inplace=True)
symbol_name = df[self._symbol_field_name].iloc[0]
old_symbol_list = self.old_qlib_data.index.get_level_values("instrument").unique().to_list()
if str(symbol_name).upper() not in old_symbol_list:
return df.reset_index()
old_df = self.old_qlib_data.loc[str(symbol_name).upper()]
latest_date = old_df.index[-1]
df = df.loc[latest_date:]
new_latest_data = df.iloc[0]
old_latest_data = old_df.loc[latest_date]
for col in self.column_list[:-1]:
if col == "volume":
df[col] = df[col] / (new_latest_data[col] / old_latest_data[col])
else:
df[col] = df[col] * (old_latest_data[col] / new_latest_data[col])
return df.drop(df.index[0]).reset_index()
class YahooNormalize1min(YahooNormalize, ABC):
"""Normalised to 1min using local 1d data"""
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
@@ -601,160 +563,6 @@ class YahooNormalize1min(YahooNormalize, ABC):
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
@property
def calendar_list_1d(self):
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
if calendar_list_1d is None:
calendar_list_1d = self._get_1d_calendar_list()
setattr(self, "_calendar_list_1d", calendar_list_1d)
return calendar_list_1d
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
return generate_minutes_calendar_from_daily(
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
if not (data_1d is None or data_1d.empty):
_class_name = self.__class__.__name__.replace("min", "d")
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
data_1d = data_1d_obj.normalize(data_1d)
return data_1d
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
# TODO: using daily data factor
if df.empty:
return df
df = df.copy()
df = df.sort_values(self._date_field_name)
symbol = df.iloc[0][self._symbol_field_name]
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
# TODO: np.nan or 1 or 0
df["paused"] = np.nan
else:
# NOTE: volume is np.nan or volume <= 0, paused = 1
# FIXME: find a more accurate data source
data_1d["paused"] = 0
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
data_1d = data_1d.set_index(self._date_field_name)
# add factor from 1d data
# NOTE: yahoo 1d data info:
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
def _calc_factor(df_1d: pd.DataFrame):
try:
_date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date())
df_1d["factor"] = (
data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
)
df_1d["paused"] = data_1d.loc[_date]["paused"]
except Exception:
df_1d["factor"] = np.nan
df_1d["paused"] = np.nan
return df_1d
df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor)
if self.CONSISTENT_1d:
# the date sequence is consistent with 1d
df.set_index(self._date_field_name, inplace=True)
df = df.reindex(
self.generate_1min_from_daily(
pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates())
)
)
df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][
self._symbol_field_name
]
df.index.names = [self._date_field_name]
df.reset_index(inplace=True)
for _col in self.COLUMNS:
if _col not in df.columns:
continue
if _col == "volume":
df[_col] = df[_col] / df["factor"]
else:
df[_col] = df[_col] * df["factor"]
if self.CALC_PAUSED_NUM:
df = self.calc_paused_num(df)
return df
def calc_paused_num(self, df: pd.DataFrame):
_symbol = df.iloc[0][self._symbol_field_name]
df = df.copy()
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
all_nan_nums = 0
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
not_nan_nums = 0
for _date, _df in df.groupby("_tmp_date"):
_df["paused"] = 0
if not _df.loc[_df["volume"] < 0].empty:
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
_df.loc[_df["volume"] < 0, "volume"] = np.nan
check_fields = set(_df.columns) - {
"_tmp_date",
"paused",
"factor",
self._date_field_name,
self._symbol_field_name,
}
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1
if all_data:
_df["paused_num"] = not_nan_nums
all_data.append(_df)
else:
all_nan_nums = 0
not_nan_nums += 1
_df["paused_num"] = not_nan_nums
all_data.append(_df)
all_data = all_data[: len(all_data) - all_nan_nums]
if all_data:
df = pd.concat(all_data, sort=False)
else:
logger.warning(f"data is empty: {_symbol}")
df = pd.DataFrame()
return df
del df["_tmp_date"]
return df
@abc.abstractmethod
def symbol_to_yahoo(self, symbol):
raise NotImplementedError("rewrite symbol_to_yahoo")
@abc.abstractmethod
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
raise NotImplementedError("rewrite _get_1d_calendar_list")
class YahooNormalize1minOffline(YahooNormalize1min):
"""Normalised to 1min using local 1d data"""
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
@@ -769,42 +577,45 @@ class YahooNormalize1minOffline(YahooNormalize1min):
symbol_field_name: str
symbol field name, default is symbol
"""
self.qlib_data_1d_dir = qlib_data_1d_dir
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
self._all_1d_data = self._get_all_1d_data()
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
qlib.init(provider_uri=qlib_data_1d_dir)
self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
return list(D.calendar(freq="day"))
def _get_all_1d_data(self):
import qlib
from qlib.data import D
@property
def calendar_list_1d(self):
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
if calendar_list_1d is None:
calendar_list_1d = self._get_1d_calendar_list()
setattr(self, "_calendar_list_1d", calendar_list_1d)
return calendar_list_1d
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
df.reset_index(inplace=True)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
return generate_minutes_calendar_from_daily(
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
df = calc_adjusted_price(
df=df,
_date_field_name=self._date_field_name,
_symbol_field_name=self._symbol_field_name,
frequence="1min",
consistent_1d=self.CONSISTENT_1d,
calc_paused=self.CALC_PAUSED_NUM,
_1d_data_all=self.all_1d_data,
)
return df
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
@abc.abstractmethod
def symbol_to_yahoo(self, symbol):
raise NotImplementedError("rewrite symbol_to_yahoo")
Returns
------
data_1d: pd.DataFrame
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
return self._all_1d_data[
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
]
@abc.abstractmethod
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
raise NotImplementedError("rewrite _get_1d_calendar_list")
class YahooNormalizeUS:
@@ -821,7 +632,7 @@ class YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend):
pass
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
CALC_PAUSED_NUM = False
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
@@ -844,7 +655,7 @@ class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d):
pass
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline):
class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1min):
CALC_PAUSED_NUM = False
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
@@ -872,7 +683,7 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
pass
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
@@ -899,7 +710,7 @@ class YahooNormalizeBR1d(YahooNormalizeBR, YahooNormalize1d):
pass
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1minOffline):
class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1min):
CALC_PAUSED_NUM = False
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
@@ -1123,10 +934,10 @@ class Run(BaseRun):
def update_data_to_bin(
self,
qlib_data_1d_dir: str,
trading_date: str = None,
end_date: str = None,
check_data_length: int = None,
delay: float = 1,
exists_skip: bool = False,
):
"""update yahoo data to bin
@@ -1135,14 +946,14 @@ class Run(BaseRun):
qlib_data_1d_dir: str
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
trading_date: str
trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
end_date: str
end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
check_data_length: int
check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None.
delay: float
time.sleep(delay), default 1
exists_skip: bool
exists skip, by default False
Notes
-----
If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day
@@ -1150,24 +961,24 @@ class Run(BaseRun):
Examples
-------
$ python collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
# get 1m data
"""
if self.interval.lower() != "1d":
logger.warning(f"currently supports 1d data updates: --interval 1d")
# start/end date
if trading_date is None:
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
logger.warning(f"trading_date is None, use the current date: {trading_date}")
if end_date is None:
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
# download qlib 1d data
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
if not exists_qlib_data(qlib_data_1d_dir):
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
GetData().qlib_data(
target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip
)
# start/end date
calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt"))
trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
if end_date is None:
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
# download data from yahoo
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1

View File

@@ -135,7 +135,7 @@ class DumpDataBase:
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
df[self.date_field_name] = df[self.date_field_name].astype(str).astype(np.datetime64)
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
# df.drop_duplicates([self.date_field_name], inplace=True)
return df
@@ -146,9 +146,7 @@ class DumpDataBase:
return (
self._include_fields
if self._include_fields
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else df_columns
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
)
@staticmethod
@@ -176,7 +174,7 @@ class DumpDataBase:
def save_calendars(self, calendars_data: list):
self._calendars_dir.mkdir(parents=True, exist_ok=True)
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
result_calendars_list = [self._format_datetime(x) for x in calendars_data]
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
@@ -195,7 +193,7 @@ class DumpDataBase:
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
# calendars
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype("datetime64[ns]")
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())

View File

@@ -3,24 +3,21 @@
"""
TODO:
- A more well-designed PIT database is required.
- seperated insert, delete, update, query operations are required.
- separated insert, delete, update, query operations are required.
"""
import abc
import shutil
import struct
import traceback
from pathlib import Path
from typing import Iterable, List, Union
from typing import Iterable
from functools import partial
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor
import fire
import numpy as np
import pandas as pd
from tqdm import tqdm
from loguru import logger
from qlib.utils import fname_to_code, code_to_fname, get_period_offset
from qlib.utils import fname_to_code, get_period_offset
from qlib.config import C
@@ -135,9 +132,11 @@ class DumpPitData:
return (
set(self._include_fields)
if self._include_fields
else set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
else (
set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
)
)
def get_filenames(self, symbol, field, interval):

View File

@@ -65,12 +65,18 @@ REQUIRED = [
# To ensure stable operation of the experiment manager, we have limited the version of mlflow,
# and we need to verify whether version 2.0 of mlflow can serve qlib properly.
"mlflow>=1.12.1, <=1.30.0",
# mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail.
"packaging<22",
"tqdm",
"loguru",
"lightgbm>=3.3.0",
"tornado",
"joblib>=0.17.0",
"ruamel.yaml>=0.16.12",
# With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated,
# which would cause qlib.workflow.cli to not work properly,
# and no good replacement has been found, so the version of ruamel.yaml has been restricted for now.
# Refs: https://pypi.org/project/ruamel.yaml/
"ruamel.yaml<=0.17.36",
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
@@ -140,7 +146,8 @@ setup(
"wheel",
"setuptools",
"black",
"pylint",
# Version 3.0 of pylint had problems with the build process, so we limited the version of pylint.
"pylint<=2.17.6",
# Using the latest versions(0.981 and 0.982) of mypy,
# the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py",
# If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy.

View File

@@ -13,7 +13,9 @@ from pathlib import Path
from qlib.data import D
from qlib.tests.data import GetData
from scripts.dump_pit import DumpPitData
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from dump_pit import DumpPitData
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
from collector import Run

View File

@@ -9,7 +9,9 @@ from qlib.tests import TestAutoData
class WorkflowTest(TestAutoData):
TMP_PATH = Path("./.mlruns_tmp/")
# Creating the directory manually doesn't work with mlflow,
# so we add a subfolder named .trash when we create the directory.
TMP_PATH = Path("./.mlruns_tmp/.trash")
def tearDown(self) -> None:
if self.TMP_PATH.exists():
@@ -17,6 +19,8 @@ class WorkflowTest(TestAutoData):
def test_get_local_dir(self):
""" """
self.TMP_PATH.mkdir(parents=True, exist_ok=True)
with R.start(uri=str(self.TMP_PATH)):
pass