mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
46 Commits
v0.8.0a1
...
backtest_i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ac9dd7221 | ||
|
|
7efec6bbc4 | ||
|
|
3fa48d7017 | ||
|
|
4f2d6b0d84 | ||
|
|
3943b7001f | ||
|
|
2593185721 | ||
|
|
7a884fa9f2 | ||
|
|
d929d4bb21 | ||
|
|
e54b019ee2 | ||
|
|
426b98a3bc | ||
|
|
82f8ff9066 | ||
|
|
31e9d529de | ||
|
|
5fa56703ae | ||
|
|
c6bb11fe56 | ||
|
|
3d7ebd1fe0 | ||
|
|
7313b4dad0 | ||
|
|
b70caff522 | ||
|
|
96b422a906 | ||
|
|
64130d9407 | ||
|
|
a58bc03a8e | ||
|
|
f537222ce3 | ||
|
|
c427c64845 | ||
|
|
22ff8fdc44 | ||
|
|
4efb0a75c1 | ||
|
|
052aad7982 | ||
|
|
12f05c7182 | ||
|
|
ac08468330 | ||
|
|
df9745f134 | ||
|
|
2e49a5f7c0 | ||
|
|
3ab5721448 | ||
|
|
6a94b45503 | ||
|
|
7c31012b50 | ||
|
|
334b92ace7 | ||
|
|
9a175d7507 | ||
|
|
17ea44e0cf | ||
|
|
c0ce712be9 | ||
|
|
8e81a017c1 | ||
|
|
706727988c | ||
|
|
e99224e5c2 | ||
|
|
8c8d1336de | ||
|
|
d01de411a8 | ||
|
|
28fe4d4bb4 | ||
|
|
873129aa9b | ||
|
|
3a152f9b8b | ||
|
|
2b75b41a08 | ||
|
|
00d17f0a52 |
8
.github/workflows/python-publish.yml
vendored
8
.github/workflows/python-publish.yml
vendored
@@ -12,8 +12,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
os: [windows-latest, macos-latest, macos-11]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -44,7 +45,8 @@ jobs:
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||
with:
|
||||
python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38'
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-versions: 'cp37-cp37m cp38-cp38'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
|
||||
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
@@ -13,7 +13,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -49,15 +50,6 @@ jobs:
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while instsalling locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
6
.github/workflows/test_macos.yml
vendored
6
.github/workflows/test_macos.yml
vendored
@@ -10,10 +10,12 @@ on:
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
os: [macos-11, macos-latest]
|
||||
# not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
17
CHANGES.rst
17
CHANGES.rst
@@ -159,6 +159,21 @@ Version 0.5.0
|
||||
- Add baselines
|
||||
- public data crawler
|
||||
|
||||
Version greater than Version 0.5.0
|
||||
|
||||
Version 0.8.0
|
||||
--------------------
|
||||
- The backtest is greatly refactored.
|
||||
- Nested decision execution framework is supported
|
||||
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
|
||||
- The trading limitation is more accurate;
|
||||
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
|
||||
- In `current verison <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between loging and shorting action.
|
||||
- The constant is different when calculating annualized metrics.
|
||||
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
|
||||
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
|
||||
- Users could chec kout the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
|
||||
|
||||
|
||||
Other Versions
|
||||
----------------------------------
|
||||
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_
|
||||
|
||||
@@ -100,7 +100,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
@@ -307,7 +306,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py run --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
|
||||
|
||||
## Run multiple models
|
||||
@@ -317,7 +316,7 @@ The script will create a unique virtual environment for each model, and delete t
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
python run_all_model.py 10
|
||||
python run_all_model.py run 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -86,4 +87,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -100,4 +101,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -94,4 +95,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
|
||||
@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
@@ -80,4 +83,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -76,4 +77,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -31,18 +31,22 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -98,4 +99,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -3,49 +3,58 @@
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` dataset with China's A shared-stock & CSI300 data respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
<!--
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
|
||||
|
||||
> NOTE:
|
||||
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
|
||||
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha158(with selected 20 features) | 0.0318±0.00 | 0.2367±0.04 | 0.0435±0.00 | 0.3389±0.03 | 0.0381±0.03 | 0.5561±0.46 | -0.1207±0.04 |
|
||||
| Localformer(Juyong Jiang, et al.) | Alpha158 | 0.0356±0.00 | 0.2756±0.03 | 0.0468±0.00 | 0.3784±0.03 | 0.0438±0.02 | 0.6600±0.33 | -0.0952±0.02 |
|
||||
| SFM(Liheng Zhang, et al.) | Alpha158 | 0.0379±0.00 | 0.2959±0.04 | 0.0464±0.00 | 0.3825±0.04 | 0.0465±0.02 | 0.5672±0.29 | -0.1282±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158(with selected 20 features) | 0.0362±0.01 | 0.2789±0.06 | 0.0463±0.01 | 0.3661±0.05 | 0.0470±0.03 | 0.6992±0.47 | -0.1072±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158(with selected 20 features) | 0.0349±0.00 | 0.2511±0.01 | 0.0462±0.00 | 0.3564±0.01 | 0.0497±0.01 | 0.7338±0.19 | -0.0777±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha158(with selected 20 features) | 0.0404±0.00 | 0.3197±0.05 | 0.0490±0.00 | 0.4047±0.04 | 0.0649±0.02 | 1.0091±0.30 | -0.0860±0.02 |
|
||||
| Linear | Alpha158 | 0.0397±0.00 | 0.3000±0.00 | 0.0472±0.00 | 0.3531±0.00 | 0.0692±0.00 | 0.9209±0.00 | -0.1509±0.00 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha158 | 0.0440±0.00 | 0.3535±0.05 | 0.0540±0.00 | 0.4451±0.03 | 0.0718±0.02 | 1.0835±0.35 | -0.0760±0.02 |
|
||||
| CatBoost(Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0481±0.00 | 0.3366±0.00 | 0.0454±0.00 | 0.3311±0.00 | 0.0765±0.00 | 0.8032±0.01 | -0.1092±0.00 |
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha158 | 0.0498±0.00 | 0.3779±0.00 | 0.0505±0.00 | 0.4131±0.00 | 0.0780±0.00 | 0.9070±0.00 | -0.1168±0.00 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02 | 0.8131±0.19 | -0.1824±0.03 |
|
||||
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
|
||||
|
||||
## Alpha360 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03 | -0.3378±0.37 | -0.1653±0.05 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00 | -0.3892±0.00 | -0.2145±0.00 |
|
||||
| MLP | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02 | 0.0274±0.23 | -0.1385±0.03 |
|
||||
| Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 |
|
||||
| CatBoost((Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 |
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -1,52 +1,38 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. This work introduces a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in the current minibatch) and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
At step <img src="https://latex.codecogs.com/png.latex?s" title="s" />, with training data <img src="https://latex.codecogs.com/png.latex?x_s,y_s" title="x_s,y_s" />, the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> chooses a suitable task <img src="https://latex.codecogs.com/png.latex?T_{i_s}" title="T_{i_s}" /> (green solid lines) to update the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> (blue solid lines). After <img src="https://latex.codecogs.com/png.latex?S" title="S" /> steps, we evaluate the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> on the validation set and update the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> (green dashed lines).
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
Due to different data versions and different Qlib versions, the original data and data preprocessing methods of the experimental settings in the paper are different from those experimental settings in the existing Qlib version. Therefore, we provide two versions of the code according to the two kinds of settings, 1) the [code](https://github.com/lwwang1995/tcts) that can be used to reproduce the experimental results and 2) the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) in the current Qlib baseline.
|
||||
|
||||
#### Setting1
|
||||
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t+k}}{price_i^{t+k-1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+k}}{price_i^{t+k-1}}-1" />
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
* Temporally correlated task sets <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_k&space;=&space;\{T_1,&space;T_2,&space;...&space;,&space;T_k\}" title="\mathcal{T}_k = \{T_1, T_2, ... , T_k\}" />, in this paper, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_5" title="\mathcal{T}_5" /> and <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_{10}" title="\mathcal{T}_{10}" /> are used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />, <img src="https://latex.codecogs.com/png.latex?T_2" title="T_2" />, and <img src="https://latex.codecogs.com/png.latex?T_3" title="T_3" />.
|
||||
|
||||
#### Setting2
|
||||
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2014), validation (01/01/2015-12/31/2016), and test sets (01/01/2017-08/01/2020) based on the transaction time.
|
||||
|
||||
* The main tasks <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as following,
|
||||
<div align=center>
|
||||
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t+1+k}}{price_i^{t+1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+1+k}}{price_i^{t+1}}-1" />
|
||||
</div>
|
||||
|
||||
* In Qlib baseline, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, is used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />.
|
||||
|
||||
### Experimental Result
|
||||
You can find the experimental result of setting1 in the [paper](http://proceedings.mlr.press/v139/wu21e/wu21e.pdf) and the experimental result of setting2 in this [page](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
@@ -22,16 +22,17 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -1) / $close - 1",
|
||||
"Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -2) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -1) - 1",
|
||||
"Ref($close, -4) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -53,9 +54,8 @@ task:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
dropout: 0.3
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
@@ -64,10 +64,10 @@ task:
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 3
|
||||
fore_lr: 5e-4
|
||||
weight_lr: 5e-4
|
||||
fore_lr: 2e-3
|
||||
weight_lr: 2e-3
|
||||
steps: 3
|
||||
target_label: 1
|
||||
target_label: 0
|
||||
lowest_valid_performance: 0.993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -92,7 +92,6 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
|
||||
@@ -195,7 +195,8 @@ class Alpha158Formatter(GenericDataFormatter):
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[col])
|
||||
# Using [col] is for aligning with the format when fitting
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[[col]])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -311,5 +311,11 @@ class TFTModel(ModelFT):
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
drop_attrs = ["model", "tf_graph", "sess", "data_formatter"]
|
||||
orig_attr = {}
|
||||
for attr in drop_attrs:
|
||||
orig_attr[attr] = getattr(self, attr)
|
||||
setattr(self, attr, None)
|
||||
super(TFTModel, self).to_pickle(path)
|
||||
for attr in drop_attrs:
|
||||
setattr(self, attr, orig_attr[attr])
|
||||
|
||||
@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -38,7 +38,7 @@ class TRAModel(Model):
|
||||
model_init_state=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
seed=0,
|
||||
seed=None,
|
||||
logdir=None,
|
||||
eval_train=True,
|
||||
eval_test=False,
|
||||
|
||||
@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -50,6 +51,7 @@ task:
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
seed: 993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -50,6 +51,7 @@ task:
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
seed: 993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -151,10 +151,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -189,10 +188,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
backtest_config["benchmark"] = self.benchmark
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
|
||||
@@ -151,6 +151,9 @@ def get_all_results(folders) -> dict:
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
if "1day.excess_return_with_cost.annualized_return" not in metrics:
|
||||
print(f"{recorder_id} is skipped due to incomplete result")
|
||||
continue
|
||||
result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
|
||||
@@ -200,174 +203,183 @@ def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
return temp_path
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
class ModelRunner:
|
||||
def _init_qlib(self, exp_folder_name):
|
||||
# init qlib
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
qlib.init(
|
||||
exp_manager={
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
# function to run the all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
self,
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
.. code-block:: bash
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 mlp
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py run 3
|
||||
|
||||
# Case 3 - run specific models multiple times with specific dataset
|
||||
python run_all_model.py 3 mlp Alpha158
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py run 3 mlp
|
||||
|
||||
# Case 4 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
|
||||
# Case 3 - run specific models multiple times with specific dataset
|
||||
python run_all_model.py run 3 mlp Alpha158
|
||||
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py --models=[mlp,lightgbm]
|
||||
# Case 4 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py run 3 [mlp,tft,lstm] --exclude=True
|
||||
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py run --models=[mlp,lightgbm]
|
||||
|
||||
"""
|
||||
# init qlib
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
qlib.init(
|
||||
exp_manager={
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
sys.stderr.write("\n")
|
||||
# create env by anaconda
|
||||
temp_dir, env_path, python_path, conda_activate = create_env()
|
||||
"""
|
||||
self._init_qlib(exp_folder_name)
|
||||
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
with open(req_path) as f:
|
||||
content = f.read()
|
||||
if "torch" in content:
|
||||
# automatically install pytorch according to nvidia's version
|
||||
execute(
|
||||
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
|
||||
) # for automatically installing torch according to the nvidia driver
|
||||
execute(
|
||||
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
else:
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
# create env by anaconda
|
||||
temp_dir, env_path, python_path, conda_activate = create_env()
|
||||
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
with open(req_path) as f:
|
||||
content = f.read()
|
||||
if "torch" in content:
|
||||
# automatically install pytorch according to nvidia's version
|
||||
execute(
|
||||
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
|
||||
) # for automatically installing torch according to the nvidia driver
|
||||
execute(
|
||||
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
else:
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# print errors
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
self._collect_results(exp_folder_name, dataset)
|
||||
|
||||
def _collect_results(self, exp_folder_name, dataset):
|
||||
folders = get_all_folders(exp_folder_name, dataset)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
if len(results) > 0:
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
if len(results) > 0:
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# print errors
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
# move results folder
|
||||
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
# move results folder
|
||||
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run) # run all the model
|
||||
fire.Fire(ModelRunner) # run all the model
|
||||
|
||||
@@ -31,10 +31,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ _version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
@@ -151,14 +152,17 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
:param conf_path: A path to the qlib config in yml format
|
||||
"""
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
if conf_path is None:
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
def get_project_path(config_name="config.yaml", cur_path: Union[Path, str, None] = None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
@@ -187,6 +191,7 @@ def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
cur_path = Path(cur_path)
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
@@ -202,6 +207,40 @@ def auto_init(**kwargs):
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
|
||||
:**kwargs: it may contain following parameters
|
||||
cur_path: the start path to find the project path
|
||||
|
||||
Here are two examples of the configuration
|
||||
|
||||
Example 1)
|
||||
If you want create a new project-specific config based on a shared configure, you can use `conf_type: ref`
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
conf_type: ref
|
||||
qlib_cfg: '<shared_yaml_config_path>' # this could be null reference no config from other files
|
||||
# following configs in `qlib_cfg_update` is project=specific
|
||||
qlib_cfg_update:
|
||||
exp_manager:
|
||||
class: "MLflowExpManager"
|
||||
module_path: "qlib.workflow.expm"
|
||||
kwargs:
|
||||
uri: "file://<your mlflow experiment path>"
|
||||
default_exp_name: "Experiment"
|
||||
|
||||
Example 2)
|
||||
If you wan to create simple a stand alone config, you can use following config(a.k.a `conf_type: origin`)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
exp_manager:
|
||||
class: "MLflowExpManager"
|
||||
module_path: "qlib.workflow.expm"
|
||||
kwargs:
|
||||
uri: "file://<your mlflow experiment path>"
|
||||
default_exp_name: "Experiment"
|
||||
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
@@ -210,6 +249,7 @@ def auto_init(**kwargs):
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
@@ -223,8 +263,14 @@ def auto_init(**kwargs):
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
qlib_conf_path = conf.get("qlib_cfg", None)
|
||||
|
||||
# merge the arguments
|
||||
qlib_conf_update = conf.get("qlib_cfg_update", {})
|
||||
for k, v in kwargs.items():
|
||||
if k in qlib_conf_update:
|
||||
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
|
||||
qlib_conf_update.update(kwargs)
|
||||
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -34,6 +34,7 @@ class Exchange:
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
**kwargs,
|
||||
@@ -95,6 +96,7 @@ class Exchange:
|
||||
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
|
||||
distinguish `not set` and `disable trade_unit`
|
||||
:param min_cost: min cost, default 5
|
||||
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
|
||||
The limit indicates that the etf is tradable on a specific day.
|
||||
@@ -164,9 +166,12 @@ class Exchange:
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
self.open_cost = open_cost
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.impact_cost = impact_cost
|
||||
|
||||
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
|
||||
self.volume_threshold = volume_threshold
|
||||
self.extra_quote = extra_quote
|
||||
@@ -685,12 +690,14 @@ class Exchange:
|
||||
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
|
||||
)
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
@@ -699,10 +706,10 @@ class Exchange:
|
||||
max_trade_amount = 0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / self.open_cost + self.min_cost
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
if cash >= critical_price:
|
||||
# the service fee is equal to open_cost * trade_amount
|
||||
max_trade_amount = cash / (1 + self.open_cost) / trade_price
|
||||
# the service fee is equal to cost_ratio * trade_amount
|
||||
max_trade_amount = cash / (1 + cost_ratio) / trade_price
|
||||
else:
|
||||
# the service fee is equal to min_cost
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
@@ -718,6 +725,7 @@ class Exchange:
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
@@ -726,8 +734,12 @@ class Exchange:
|
||||
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
|
||||
self._clip_amount_by_volume(order, dealt_order_amount)
|
||||
|
||||
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
|
||||
trade_val = order.deal_amount * trade_price
|
||||
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
cost_ratio = self.close_cost
|
||||
cost_ratio = self.close_cost + adj_cost_ratio
|
||||
# sell
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
@@ -750,14 +762,18 @@ class Exchange:
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
|
||||
elif order.direction == Order.BUY:
|
||||
cost_ratio = self.open_cost
|
||||
cost_ratio = self.open_cost + adj_cost_ratio
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
if cash < max(trade_val * cost_ratio, self.min_cost):
|
||||
# cash cannot cover cost
|
||||
order.deal_amount = 0
|
||||
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
|
||||
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
)
|
||||
|
||||
@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
|
||||
if is_single_value(start_time, end_time, self.freq, self.region):
|
||||
# this is a very special case.
|
||||
# skip aggregating function to speed-up the query calculation
|
||||
|
||||
# FIXME:
|
||||
# it will go to the else logic when it comes to the
|
||||
# 1) the day before holiday when daily trading
|
||||
# 2) the last minute of the day when intraday trading
|
||||
try:
|
||||
return self.data[stock_id].loc[start_time, field]
|
||||
except KeyError:
|
||||
|
||||
@@ -345,15 +345,19 @@ class Position(BasePosition):
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
self.position[stock_id]["amount"] -= trade_amount
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
|
||||
if np.isclose(self.position[stock_id]["amount"], trade_amount):
|
||||
# Selling all the stocks
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
self._del_stock(stock_id)
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
self.position[stock_id]["amount"] -= trade_amount
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
if self._settle_type == self.ST_CASH:
|
||||
|
||||
102
qlib/backtest/signal.py
Normal file
102
qlib/backtest/signal.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Some trading strategy make decisions based on other prediction signals
|
||||
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
|
||||
|
||||
This interface is tries to provide unified interface for those different sources
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
"""
|
||||
Signal With pandas with based Cache
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
2008-01-03 0.120125
|
||||
2008-01-04 0.878860
|
||||
2008-01-07 0.505539
|
||||
2008-01-08 0.395004
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
return signal
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
if isinstance(pred_scores, pd.DataFrame):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
# TODO: this method is not included in the framework and could be refactor later
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
def create_signal_from(
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
|
||||
) -> Signal:
|
||||
"""
|
||||
create signal from diverse information
|
||||
This method will choose the right method to create a signal based on `obj`
|
||||
Please refer to the code below.
|
||||
"""
|
||||
if isinstance(obj, Signal):
|
||||
return obj
|
||||
elif isinstance(obj, (tuple, list)):
|
||||
return ModelSignal(*obj)
|
||||
elif isinstance(obj, (dict, str)):
|
||||
return init_instance_by_config(obj)
|
||||
elif isinstance(obj, (pd.DataFrame, pd.Series)):
|
||||
return SignalWCache(signal=obj)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of signal is not supported")
|
||||
@@ -70,7 +70,7 @@ class TradeCalendarManager:
|
||||
- If self.trade_step >= self.self.trade_len, it means the trading is finished
|
||||
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
return self.trade_step >= self.trade_len - 1
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
|
||||
0
qlib/contrib/data/utils/__init__.py
Normal file
0
qlib/contrib/data/utils/__init__.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
def align_index(df_dict, join):
|
||||
res = {}
|
||||
for k, df in df_dict.items():
|
||||
if join is not None and k != join:
|
||||
df = df.reindex(df_dict[join].index)
|
||||
res[k] = df
|
||||
return res
|
||||
|
||||
|
||||
# Mocking the pd.DataFrame class
|
||||
class SepDataFrame:
|
||||
"""
|
||||
(Sep)erate DataFrame
|
||||
We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter).
|
||||
However, they are usally be used seperately at last.
|
||||
This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive)
|
||||
|
||||
SepDataFrame tries to act like a DataFrame whose column with multiindex
|
||||
"""
|
||||
|
||||
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
|
||||
"""
|
||||
initialize the data based on the dataframe dictionary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_dict : Dict[str, pd.DataFrame]
|
||||
dataframe dictionary
|
||||
join : str
|
||||
how to join the data
|
||||
It will reindex the dataframe based on the join key.
|
||||
If join is None, the reindex step will be skipped
|
||||
|
||||
skip_align :
|
||||
for some cases, we can improve performance by skipping aligning index
|
||||
"""
|
||||
self.join = join
|
||||
|
||||
if skip_align:
|
||||
self._df_dict = df_dict
|
||||
else:
|
||||
self._df_dict = align_index(df_dict, join)
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
return SDFLoc(self, join=self.join)
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
return self._df_dict[self.join].index
|
||||
|
||||
def apply_each(self, method: str, skip_align=True, *args, **kwargs):
|
||||
"""
|
||||
Assumptions:
|
||||
- inplace methods will return None
|
||||
"""
|
||||
inplace = False
|
||||
df_dict = {}
|
||||
for k, df in self._df_dict.items():
|
||||
df_dict[k] = getattr(df, method)(*args, **kwargs)
|
||||
if df_dict[k] is None:
|
||||
inplace = True
|
||||
if not inplace:
|
||||
return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)
|
||||
|
||||
def sort_index(self, *args, **kwargs):
|
||||
return self.apply_each("sort_index", True, *args, **kwargs)
|
||||
|
||||
def copy(self, *args, **kwargs):
|
||||
return self.apply_each("copy", True, *args, **kwargs)
|
||||
|
||||
def _update_join(self):
|
||||
if self.join not in self:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._df_dict[item]
|
||||
|
||||
def __setitem__(self, item: str, df: pd.DataFrame):
|
||||
# TODO: consider the join behavior
|
||||
self._df_dict[item] = df
|
||||
|
||||
def __delitem__(self, item: str):
|
||||
del self._df_dict[item]
|
||||
self._update_join()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._df_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df_dict[self.join])
|
||||
|
||||
def droplevel(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `droplevel` method")
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
dfs = []
|
||||
for k, df in self._df_dict.items():
|
||||
df = df.head(0)
|
||||
df.columns = pd.MultiIndex.from_product([[k], df.columns])
|
||||
dfs.append(df)
|
||||
return pd.concat(dfs, axis=1).columns
|
||||
|
||||
# Useless methods
|
||||
@staticmethod
|
||||
def merge(df_dict: Dict[str, pd.DataFrame], join: str):
|
||||
all_df = df_dict[join]
|
||||
for k, df in df_dict.items():
|
||||
if k != join:
|
||||
all_df = all_df.join(df)
|
||||
return all_df
|
||||
|
||||
|
||||
class SDFLoc:
|
||||
"""Mock Class"""
|
||||
|
||||
def __init__(self, sdf: SepDataFrame, join):
|
||||
self._sdf = sdf
|
||||
self.axis = None
|
||||
self.join = join
|
||||
|
||||
def __call__(self, axis):
|
||||
self.axis = axis
|
||||
return self
|
||||
|
||||
def __getitem__(self, args):
|
||||
if self.axis == 1:
|
||||
if isinstance(args, str):
|
||||
return self._sdf[args]
|
||||
elif isinstance(args, (tuple, list)):
|
||||
new_df_dict = {k: self._sdf[k] for k in args}
|
||||
return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
elif self.axis == 0:
|
||||
return SepDataFrame(
|
||||
{k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True
|
||||
)
|
||||
else:
|
||||
df = self._sdf
|
||||
if isinstance(args, tuple):
|
||||
ax0, *ax1 = args
|
||||
if len(ax1) == 0:
|
||||
ax1 = None
|
||||
if ax1 is not None:
|
||||
df = df.loc(axis=1)[ax1]
|
||||
if ax0 is not None:
|
||||
df = df.loc(axis=0)[ax0]
|
||||
return df
|
||||
else:
|
||||
return df.loc(axis=0)[args]
|
||||
|
||||
|
||||
# Patch pandas DataFrame
|
||||
# Tricking isinstance to accept SepDataFrame as its subclass
|
||||
import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
|
||||
if isinstance(cls, Iterable):
|
||||
for c in cls:
|
||||
if c is pd.DataFrame:
|
||||
return True
|
||||
elif cls is pd.DataFrame:
|
||||
return True
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
builtins.isinstance = _isinstance
|
||||
|
||||
if __name__ == "__main__":
|
||||
sdf = SepDataFrame({}, join=None)
|
||||
print(isinstance(sdf, (pd.DataFrame,)))
|
||||
print(isinstance(sdf, pd.DataFrame))
|
||||
@@ -38,6 +38,8 @@ class CatBoostModel(Model, FeatureInt):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
|
||||
@@ -64,6 +64,8 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
|
||||
@@ -25,6 +25,8 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -83,6 +85,8 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
|
||||
@@ -82,6 +82,8 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_train["feature"], df_valid["label"]
|
||||
|
||||
@@ -51,6 +51,8 @@ class LinearModel(Model):
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
X, y = df_train["feature"].values, np.squeeze(df_train["label"].values)
|
||||
|
||||
if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:
|
||||
|
||||
@@ -224,6 +224,8 @@ class ALSTM(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -207,6 +207,8 @@ class ALSTM(Model):
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -237,6 +237,8 @@ class GATs(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -245,6 +245,8 @@ class GATs(Model):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -224,6 +224,8 @@ class GRU(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -206,6 +206,8 @@ class GRU(Model):
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -176,6 +176,8 @@ class LocalformerModel(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -153,6 +153,8 @@ class LocalformerModel(Model):
|
||||
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -219,6 +219,8 @@ class LSTM(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -201,6 +201,8 @@ class LSTM(Model):
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
@@ -374,6 +374,8 @@ class SFM(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
|
||||
@@ -169,6 +169,8 @@ class TabnetModel(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -61,8 +61,9 @@ class TCTS(Model):
|
||||
weight_lr=5e-7,
|
||||
steps=3,
|
||||
GPU=0,
|
||||
seed=0,
|
||||
target_label=0,
|
||||
mode="soft",
|
||||
seed=None,
|
||||
lowest_valid_performance=0.993,
|
||||
**kwargs
|
||||
):
|
||||
@@ -87,6 +88,7 @@ class TCTS(Model):
|
||||
self.weight_lr = weight_lr
|
||||
self.steps = steps
|
||||
self.target_label = target_label
|
||||
self.mode = mode
|
||||
self.lowest_valid_performance = lowest_valid_performance
|
||||
self._fore_optimizer = fore_optimizer
|
||||
self._weight_optimizer = weight_optimizer
|
||||
@@ -100,6 +102,8 @@ class TCTS(Model):
|
||||
"\nn_epochs : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\ntarget_label : {}"
|
||||
"\nmode : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
@@ -111,6 +115,8 @@ class TCTS(Model):
|
||||
n_epochs,
|
||||
batch_size,
|
||||
early_stop,
|
||||
target_label,
|
||||
mode,
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
@@ -120,9 +126,17 @@ class TCTS(Model):
|
||||
|
||||
def loss_fn(self, pred, label, weight):
|
||||
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
if self.mode == "hard":
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
elif self.mode == "soft":
|
||||
loss = (pred - label.transpose(0, 1)) ** 2
|
||||
return torch.mean(loss * weight.transpose(0, 1))
|
||||
|
||||
else:
|
||||
raise NotImplementedError("mode {} is not supported!".format(self.mode))
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
@@ -132,6 +146,10 @@ class TCTS(Model):
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
task_embedding = torch.zeros([self.batch_size, self.output_dim])
|
||||
task_embedding[:, self.target_label] = 1
|
||||
task_embedding = task_embedding.to(self.device)
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.init_fore_model = False
|
||||
@@ -155,12 +173,13 @@ class TCTS(Model):
|
||||
|
||||
init_pred = init_fore_model(feature)
|
||||
pred = self.fore_model(feature)
|
||||
|
||||
dis = init_pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
|
||||
weight_feature = torch.cat(
|
||||
(feature, dis.transpose(0, 1), label, init_pred.view(-1, 1), task_embedding), 1
|
||||
)
|
||||
weight = self.weight_model(weight_feature)
|
||||
|
||||
loss = self.loss_fn(pred, label, weight) # hard
|
||||
loss = self.loss_fn(pred, label, weight)
|
||||
|
||||
self.fore_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@@ -188,11 +207,11 @@ class TCTS(Model):
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
dis = pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1), task_embedding), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
loc = torch.argmax(weight, 1)
|
||||
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
|
||||
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
valid_loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
|
||||
loss = torch.mean(valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
|
||||
self.weight_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@@ -207,7 +226,6 @@ class TCTS(Model):
|
||||
|
||||
self.fore_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
@@ -237,6 +255,8 @@ class TCTS(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
@@ -277,7 +297,7 @@ class TCTS(Model):
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 2 * self.output_dim + 1,
|
||||
d_feat=360 + 3 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
@@ -303,8 +323,6 @@ class TCTS(Model):
|
||||
best_loss = np.inf
|
||||
best_epoch = 0
|
||||
stop_round = 0
|
||||
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
|
||||
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
print("Epoch:", epoch)
|
||||
|
||||
@@ -74,7 +74,7 @@ class TRAModel(Model):
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
alpha=1.0,
|
||||
seed=0,
|
||||
seed=None,
|
||||
logdir=None,
|
||||
eval_train=False,
|
||||
eval_test=False,
|
||||
@@ -99,8 +99,9 @@ class TRAModel(Model):
|
||||
if transport_method == "router" and not eval_train:
|
||||
self.logger.warning("`eval_train` will be ignored when using TRA.router")
|
||||
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.model_config = model_config
|
||||
self.tra_config = tra_config
|
||||
|
||||
@@ -175,6 +175,8 @@ class TransformerModel(Model):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -151,6 +151,9 @@ class TransformerModel(Model):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
if dl_train.empty or dl_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .model_strategy import (
|
||||
from .signal_strategy import (
|
||||
TopkDropoutStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ This strategy is not well maintained
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .model_strategy import WeightStrategyBase
|
||||
from .signal_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
|
||||
@@ -80,18 +80,22 @@ class OrderGenWInteract(OrderGenerator):
|
||||
|
||||
:rtype: list
|
||||
"""
|
||||
if target_weight_position is None:
|
||||
return []
|
||||
|
||||
# calculate current_tradable_value
|
||||
current_amount_dict = current.get_stock_amount_dict()
|
||||
|
||||
current_total_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
only_tradable=False,
|
||||
)
|
||||
current_tradable_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
only_tradable=True,
|
||||
)
|
||||
# add cash
|
||||
@@ -105,9 +109,7 @@ class OrderGenWInteract(OrderGenerator):
|
||||
# value. Then just sell all the stocks
|
||||
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
|
||||
for stock_id in list(target_amount_dict.keys()):
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
):
|
||||
if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time):
|
||||
del target_amount_dict[stock_id]
|
||||
else:
|
||||
# consider cost rate
|
||||
@@ -118,16 +120,16 @@ class OrderGenWInteract(OrderGenerator):
|
||||
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
|
||||
weight_position=target_weight_position,
|
||||
cash=current_tradable_value,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
order_list = trade_exchange.generate_order_for_target_amount_position(
|
||||
target_position=target_amount_dict,
|
||||
current_position=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecisionWO(order_list, self)
|
||||
return order_list
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
|
||||
@@ -163,8 +165,11 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
:param trade_date:
|
||||
:type trade_date: pd.Timestamp
|
||||
|
||||
:rtype: list
|
||||
:rtype: list of generated orders
|
||||
"""
|
||||
if target_weight_position is None:
|
||||
return []
|
||||
|
||||
risk_total_value = risk_degree * current.calculate_value()
|
||||
|
||||
current_stock = current.get_stock_list()
|
||||
@@ -172,13 +177,17 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
for stock_id in target_weight_position:
|
||||
# Current rule will ignore the stock that not hold and cannot be traded at predict date
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
) and trade_exchange.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value
|
||||
* target_weight_position[stock_id]
|
||||
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
|
||||
/ trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time)
|
||||
)
|
||||
# TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price.
|
||||
# Maybe a close price with forward fill will be a better solution.
|
||||
elif stock_id in current_stock:
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)
|
||||
@@ -188,7 +197,7 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
order_list = trade_exchange.generate_order_for_target_amount_position(
|
||||
target_position=amount_dict,
|
||||
current_position=current.get_stock_amount_dict(),
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecisionWO(order_list, self)
|
||||
return order_list
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import copy
|
||||
from qlib.backtest.signal import Signal, create_signal_from
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import BaseModel
|
||||
from qlib.backtest.position import Position
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
class TopkDropoutStrategy(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
topk,
|
||||
n_drop,
|
||||
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
@@ -30,6 +36,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
model=None,
|
||||
dataset=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -39,6 +47,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
the number of stocks in the portfolio.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
signal :
|
||||
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
|
||||
the decision of the strategy will base on the given signal
|
||||
method_sell : str
|
||||
dropout method_sell, random/bottom.
|
||||
method_buy : str
|
||||
@@ -64,7 +75,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
@@ -74,6 +85,13 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
# This is trying to be compatible with previous version of qlib task config
|
||||
if model is not None and dataset is not None:
|
||||
warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
|
||||
signal = model, dataset
|
||||
|
||||
self.signal: Signal = create_signal_from(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -87,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
if self.only_tradable:
|
||||
@@ -235,15 +253,15 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return TradeDecisionWO(sell_order_list + buy_order_list, self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
class WeightStrategyBase(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
@@ -251,6 +269,9 @@ class WeightStrategyBase(ModelStrategy):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
signal :
|
||||
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
|
||||
the decision of the strategy will base on the given signal
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -260,13 +281,15 @@ class WeightStrategyBase(ModelStrategy):
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
self.signal: Signal = create_signal_from(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -298,7 +321,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
self.save(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
|
||||
save_name, self.recorder.experiment_id
|
||||
@@ -57,22 +57,20 @@ class MultiSegRecord(RecordTemp):
|
||||
)
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
class SignalMseRecord(RecordTemp):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
This class inherits the ``SignalMseRecord`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
def generate(self):
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
@@ -81,9 +79,8 @@ class SignalMseRecord(SignalRecord):
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
return ["mse.pkl", "rmse.pkl"]
|
||||
|
||||
@@ -320,6 +320,7 @@ class TSDataSampler:
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
self.idx_map = self.idx_map2arr(self.idx_map)
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
@@ -328,6 +329,25 @@ class TSDataSampler:
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def idx_map2arr(idx_map):
|
||||
# pytorch data sampler will have better memory control without large dict or list
|
||||
# - https://github.com/pytorch/pytorch/issues/13243
|
||||
# - https://github.com/airctic/icevision/issues/613
|
||||
# So we convert the dict into int array.
|
||||
# The arr_map is expected to behave the same as idx_map
|
||||
|
||||
dtype = np.int32
|
||||
# set a index out of bound to indicate the none existing
|
||||
no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)
|
||||
|
||||
max_idx = max(idx_map.keys())
|
||||
arr_map = []
|
||||
for i in range(max_idx + 1):
|
||||
arr_map.append(idx_map.get(i, no_existing_idx))
|
||||
arr_map = np.array(arr_map, dtype=dtype)
|
||||
return arr_map
|
||||
|
||||
@staticmethod
|
||||
def flt_idx_map(flt_data, idx_map):
|
||||
idx = 0
|
||||
@@ -385,6 +405,10 @@ class TSDataSampler:
|
||||
idx_map[real_idx] = (i, j)
|
||||
return idx_df, idx_map
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return self.__len__() == 0
|
||||
|
||||
def _get_indices(self, row: int, col: int) -> np.array:
|
||||
"""
|
||||
get series indices of self.data_arr from the row, col indices of self.idx_df
|
||||
@@ -520,20 +544,18 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
# make sure the calendar is updated to latest when loading data from new config
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
self.cal = sorted(cal)
|
||||
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
@staticmethod
|
||||
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - self.step_len)
|
||||
pad_start = self.cal[pad_start_idx]
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - step_len)
|
||||
pad_start = cal[pad_start_idx]
|
||||
return slice(pad_start, end)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
@@ -542,13 +564,15 @@ class TSDatasetH(DatasetH):
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
|
||||
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
|
||||
data = super()._prepare_seg(ext_slice, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
@@ -82,8 +82,6 @@ class DataHandler(Serializable):
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# Setup data loader
|
||||
assert data_loader is not None # to make start_time end_time could have None default value
|
||||
@@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler):
|
||||
DK_R = "raw"
|
||||
DK_I = "infer"
|
||||
DK_L = "learn"
|
||||
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
@@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler):
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
)
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
df = getattr(self, self.ATTR_MAP[data_key])
|
||||
return df
|
||||
|
||||
def fetch(
|
||||
@@ -624,3 +623,33 @@ class DataHandlerLP(DataHandler):
|
||||
df = self._get_df_by_key(data_key).head()
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
@classmethod
|
||||
def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP":
|
||||
"""
|
||||
Motivation
|
||||
- A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic.
|
||||
- This class make it possible by casting the class to DataHandlerLP and only keep the processed data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler : DataHandlerLP
|
||||
A subclass of DataHandlerLP
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataHandlerLP:
|
||||
the converted processed data
|
||||
"""
|
||||
new_hd: DataHandlerLP = object.__new__(DataHandlerLP)
|
||||
new_hd.from_cast = True # add a mark for the casted instance
|
||||
|
||||
for key in list(DataHandlerLP.ATTR_MAP.values()) + [
|
||||
"instruments",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"fetch_orig",
|
||||
"drop_raw",
|
||||
]:
|
||||
setattr(new_hd, key, getattr(handler, key, None))
|
||||
return new_hd
|
||||
|
||||
@@ -13,7 +13,7 @@ import pandas as pd
|
||||
from typing import Union, List, Type
|
||||
from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps
|
||||
from .base import Expression, ExpressionOps, Feature
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
@@ -1485,6 +1485,7 @@ OpsList = [
|
||||
IdxMax,
|
||||
IdxMin,
|
||||
If,
|
||||
Feature,
|
||||
]
|
||||
|
||||
|
||||
@@ -1517,7 +1518,7 @@ class OpsWrapper:
|
||||
else:
|
||||
_ops_class = _operator
|
||||
|
||||
if not issubclass(_ops_class, ExpressionOps):
|
||||
if not issubclass(_ops_class, Expression):
|
||||
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))
|
||||
|
||||
if _ops_class.__name__ in self._ops:
|
||||
|
||||
@@ -70,9 +70,9 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
# bfs
|
||||
top = 0
|
||||
tail = 1
|
||||
item_quene = [config]
|
||||
item_queue = [config]
|
||||
while top < tail:
|
||||
now_item = item_quene[top]
|
||||
now_item = item_queue[top]
|
||||
top += 1
|
||||
if isinstance(now_item, list):
|
||||
item_keys = range(len(now_item))
|
||||
@@ -80,9 +80,9 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
item_keys = now_item.keys()
|
||||
for key in item_keys:
|
||||
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
|
||||
item_quene.append(now_item[key])
|
||||
item_queue.append(now_item[key])
|
||||
tail += 1
|
||||
elif now_item[key] in config_extend.keys():
|
||||
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
return config
|
||||
|
||||
@@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
task_config = fill_placeholder(task_config, placehorder_value)
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
if isinstance(records, dict): # uniform the data format to list
|
||||
records = [records]
|
||||
|
||||
for record in records:
|
||||
r = init_instance_by_config(record, recorder=rec)
|
||||
# Some recorder require the parameter `model` and `dataset`.
|
||||
# try to automatically pass in them to the initialization function
|
||||
# to make defining the tasking easier
|
||||
r = init_instance_by_config(
|
||||
record,
|
||||
recorder=rec,
|
||||
default_module="qlib.workflow.record_temp",
|
||||
try_kwargs={"model": model, "dataset": dataset},
|
||||
)
|
||||
r.generate()
|
||||
return rec
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from typing import List, Tuple, Union
|
||||
import pandas as pd
|
||||
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -16,7 +17,7 @@ from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -193,43 +194,6 @@ class BaseStrategy:
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : BaseModel
|
||||
the model used in when making predictions
|
||||
dataset : DatasetH
|
||||
provide test data for model
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, pdate model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""RL-based strategy"""
|
||||
|
||||
|
||||
@@ -44,36 +44,10 @@ class TestAutoData(unittest.TestCase):
|
||||
)
|
||||
|
||||
provider_uri_map = {"1min": cls.provider_uri_1min, "day": provider_uri_day}
|
||||
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
init(
|
||||
provider_uri=cls.provider_uri,
|
||||
provider_uri=provider_uri_map,
|
||||
region=REG_CN,
|
||||
expression_cache=None,
|
||||
dataset_cache=None,
|
||||
**client_config,
|
||||
**cls._setup_kwargs,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,10 @@ RECORD_CONFIG = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": {
|
||||
"dataset": "<DATASET>",
|
||||
"model": "<MODEL>",
|
||||
},
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
|
||||
@@ -27,7 +27,7 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, Any, Text, Optional
|
||||
from typing import Dict, Union, Tuple, Any, Text, Optional
|
||||
from types import ModuleType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
|
||||
----------
|
||||
config : [dict, str]
|
||||
similar to config
|
||||
please refer to the doc of init_instance_by_config
|
||||
|
||||
default_module : Python module or str
|
||||
It should be a python module to load the class type
|
||||
@@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
|
||||
_callable = config["class"] # the class type itself is passed in
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
# a.b.c.ClassName
|
||||
*m_path, cls = config.split(".")
|
||||
m_path = ".".join(m_path)
|
||||
module = get_module_by_module_path(default_module if m_path == "" else m_path)
|
||||
|
||||
_callable = getattr(module, config)
|
||||
_callable = getattr(module, cls)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -232,7 +236,11 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
config: Union[str, dict, object],
|
||||
default_module=None,
|
||||
accept_types: Union[type, Tuple[type]] = (),
|
||||
try_kwargs: Dict = {},
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""
|
||||
get initialized instance with config
|
||||
@@ -256,7 +264,9 @@ def init_instance_by_config(
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, config)() will be used.
|
||||
- "ClassName": getattr(module, "ClassName")() will be used.
|
||||
3) specify module path with class name
|
||||
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
default_module : Python module
|
||||
@@ -270,6 +280,10 @@ def init_instance_by_config(
|
||||
Optional. If the config is a instance of specific type, return the config directly.
|
||||
This will be passed into the second parameter of isinstance.
|
||||
|
||||
try_kwargs: Dict
|
||||
Try to pass in kwargs in `try_kwargs` when initialized the instance
|
||||
If error occurred, it will fail back to initialization without try_kwargs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object:
|
||||
@@ -286,7 +300,33 @@ def init_instance_by_config(
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
try:
|
||||
return klass(**cls_kwargs, **try_kwargs, **kwargs)
|
||||
except (TypeError,):
|
||||
# TypeError for handling errors like
|
||||
# 1: `XXX() got multiple values for keyword argument 'YYY'`
|
||||
# 2: `XXX() got an unexpected keyword argument 'YYY'
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def class_casting(obj: object, cls: type):
|
||||
"""
|
||||
Python doesn't provide the downcasting mechanism.
|
||||
We use the trick here to downcast the class
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
the object to be cast
|
||||
cls : type
|
||||
the target class type
|
||||
"""
|
||||
orig_cls = obj.__class__
|
||||
obj.__class__ = cls
|
||||
yield
|
||||
obj.__class__ = orig_cls
|
||||
|
||||
|
||||
def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
|
||||
@@ -18,3 +18,9 @@ class LoadObjectError(QlibException):
|
||||
"""Error type for Recorder when can not load object"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExpAlreadyExistError(Exception):
|
||||
"""Experiment already exists"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
def columns(self):
|
||||
return self.indices[1]
|
||||
|
||||
def __getitem__(self, args):
|
||||
# NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean
|
||||
return self.iloc[args]
|
||||
|
||||
def _align_indices(self, other: "IndexData") -> "IndexData":
|
||||
"""
|
||||
Align all indices of `other` to `self` before performing the arithmetic operations.
|
||||
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
Parameters
|
||||
----------
|
||||
other : "IndexData"
|
||||
the index in `other` is to be chagned
|
||||
the index in `other` is to be changed
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
"""
|
||||
return len(self.data)
|
||||
|
||||
def sum(self, axis=None):
|
||||
def sum(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
@@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
else:
|
||||
raise ValueError(f"axis must be None, 0 or 1")
|
||||
|
||||
def mean(self, axis=None):
|
||||
def mean(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nanmean(self.data)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
|
||||
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
return pd.concat(dfs, axis=axis).sort_index()
|
||||
else:
|
||||
return _naive_group_apply(df)
|
||||
|
||||
|
||||
class AsyncCaller:
|
||||
"""
|
||||
This AsyncCaller tries to make it easier to async call
|
||||
|
||||
Currently, it is used in MLflowRecorder to make functions like `log_params` async
|
||||
|
||||
NOTE:
|
||||
- This caller didn't consider the return value
|
||||
"""
|
||||
|
||||
STOP_MARK = "__STOP"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._q = Queue()
|
||||
self._stop = False
|
||||
self._t = Thread(target=self.run)
|
||||
self._t.start()
|
||||
|
||||
def close(self):
|
||||
self._q.put(self.STOP_MARK)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
data = self._q.get()
|
||||
if data == self.STOP_MARK:
|
||||
break
|
||||
else:
|
||||
data()
|
||||
|
||||
def __call__(self, func, *args, **kwargs):
|
||||
self._q.put(partial(func, *args, **kwargs))
|
||||
|
||||
def wait(self, close=True):
|
||||
if close:
|
||||
self.close()
|
||||
self._t.join()
|
||||
|
||||
@staticmethod
|
||||
def async_dec(ac_attr):
|
||||
def decorator_func(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if isinstance(getattr(self, ac_attr, None), Callable):
|
||||
return getattr(self, ac_attr)(func, self, *args, **kwargs)
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator_func
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional
|
||||
from .expm import MLflowExpManager
|
||||
from typing import Any, Dict, Text, Optional
|
||||
from .expm import ExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
@@ -16,7 +16,7 @@ class QlibRecorder:
|
||||
"""
|
||||
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
self.exp_manager: ExpManager = exp_manager
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
|
||||
@@ -334,6 +334,26 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
@contextmanager
|
||||
def uri_context(self, uri: Text):
|
||||
"""
|
||||
Temporarily set the exp_manager's uri to uri
|
||||
|
||||
NOTE:
|
||||
- Please refer to the NOTE in the `set_uri`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : Text
|
||||
the temporal uri
|
||||
"""
|
||||
prev_uri = self.exp_manager._current_uri
|
||||
self.exp_manager.set_uri(uri)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.exp_manager.set_uri(prev_uri)
|
||||
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
@@ -360,11 +380,11 @@ class QlibRecorder:
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
with R.start(experiment_name='test'):
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
with R.start(experiment_name='test'):
|
||||
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
|
||||
|
||||
# Case 3
|
||||
@@ -413,12 +433,18 @@ class QlibRecorder:
|
||||
"""
|
||||
self.get_exp().delete_recorder(recorder_id, recorder_name)
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]):
|
||||
"""
|
||||
Method for saving objects as artifacts in the experiment to the uri. It supports either saving
|
||||
from a local file/directory, or directly saving objects. User can use valid python's keywords arguments
|
||||
to specify the object to be saved as well as its name (name: value).
|
||||
|
||||
In summary, this API is designs for saving **objects** to **the experiments management backend path**,
|
||||
1. Qlib provide two methods to specify **objects**
|
||||
- Passing in the object directly by passing with `**kwargs` (e.g. R.save_objects(trained_model=model))
|
||||
- Passing in the local path to the object, i.e. `local_path` parameter.
|
||||
2. `artifact_path` represents the **the experiments management backend path**
|
||||
|
||||
- If `active recorder` exists: it will save the objects through the active recorder.
|
||||
- If `active recorder` not exists: the system will create a default experiment, and a new recorder and save objects under it.
|
||||
|
||||
@@ -431,13 +457,20 @@ class QlibRecorder:
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
with R.start(experiment_name='test'):
|
||||
pred = model.predict(dataset)
|
||||
R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
|
||||
rid = R.get_recorder().id
|
||||
...
|
||||
R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl") # after saving objects, you can load the previous object with this api
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
R.save_objects(local_path='results/pred.pkl')
|
||||
with R.start(experiment_name='test'):
|
||||
R.save_objects(local_path='results/pred.pkl', artifact_path="prediction")
|
||||
rid = R.get_recorder().id
|
||||
...
|
||||
R.get_recorder(recorder_id=rid).load_object("prediction/pred.pkl") # after saving objects, you can load the previous object with this api
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -445,7 +478,14 @@ class QlibRecorder:
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
artifact_path : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
**kwargs: Dict[Text, Any]
|
||||
the object to be saved.
|
||||
For example, `{"pred.pkl": pred}`
|
||||
"""
|
||||
if local_path is not None and len(kwargs) > 0:
|
||||
raise ValueError(
|
||||
"You can choose only one of `local_path`(save the files in a path) or `kwargs`(pass in the objects directly)"
|
||||
)
|
||||
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name: Text):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from urllib.parse import urlparse
|
||||
import mlflow
|
||||
from filelock import FileLock
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
|
||||
from mlflow.entities import ViewType
|
||||
import os, logging
|
||||
from pathlib import Path
|
||||
@@ -15,6 +15,7 @@ from .exp import MLflowExperiment, Experiment
|
||||
from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
from ..utils.exceptions import ExpAlreadyExistError
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -94,6 +95,10 @@ class ExpManager:
|
||||
Returns
|
||||
-------
|
||||
An experiment object.
|
||||
|
||||
Raise
|
||||
-----
|
||||
ExpAlreadyExistError
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `create_exp` method.")
|
||||
|
||||
@@ -200,7 +205,14 @@ class ExpManager:
|
||||
if pr.scheme == "file":
|
||||
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
|
||||
return self.create_exp(experiment_name), True
|
||||
return self.create_exp(experiment_name), True
|
||||
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
|
||||
try:
|
||||
return self.create_exp(experiment_name), True
|
||||
except ExpAlreadyExistError:
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
"""
|
||||
@@ -345,10 +357,15 @@ class MLflowExpManager(ExpManager):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
try:
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
except MlflowException as e:
|
||||
if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
|
||||
raise ExpAlreadyExistError()
|
||||
raise e
|
||||
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
experiment._default_name = self._default_exp_name
|
||||
|
||||
return experiment
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None):
|
||||
|
||||
@@ -21,19 +21,65 @@ Situations Description
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It makes user can train all tasks in
|
||||
the end of `routine` or `first_train`.
|
||||
Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by
|
||||
different strategies. It makes users can parallelly train all tasks at the end of
|
||||
`routine` or `first_train`. Otherwise, these functions will get stuck when each
|
||||
strategy prepare tasks.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
every routine and prepare signals for every routine.
|
||||
Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it
|
||||
is for simulation/backtesting instead of online trading
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
|
||||
Here is some pseudo code the demonstrate the workflow of each situation
|
||||
|
||||
For simplicity
|
||||
- Only one strategy is used in the strategy
|
||||
- `update_online_pred` is only called in the online mode and is ignored
|
||||
|
||||
1) `Online + Trainer`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tasks = first_train()
|
||||
models = trainer.train(tasks)
|
||||
trainer.end_train(models)
|
||||
for day in online_trading_days:
|
||||
# OnlineManager.routine
|
||||
models = trainer.train(strategy.prepare_tasks()) # for each strategy
|
||||
strategy.prepare_online_models(models) # for each strategy
|
||||
|
||||
trainer.end_train(models)
|
||||
prepare_signals() # prepare trading signals daily
|
||||
|
||||
|
||||
`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`.
|
||||
|
||||
|
||||
2) `Simulation + DelayTrainer`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# simulate
|
||||
tasks = first_train()
|
||||
models = trainer.train(tasks)
|
||||
for day in historical_calendars:
|
||||
# OnlineManager.routine
|
||||
models = trainer.train(strategy.prepare_tasks()) # for each strategy
|
||||
strategy.prepare_online_models(models) # for each strategy
|
||||
# delay_prepare()
|
||||
# FIXME: Currently the delay_prepare is not implemented in a proper way.
|
||||
trainer.end_train(<for all previous models>)
|
||||
prepare_signals()
|
||||
|
||||
|
||||
# Can we simplify current workflow?
|
||||
- Can reduce the number of state of tasks?
|
||||
- For each task, we have three phases (i.e. task, partly trained task, final trained task)
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -58,7 +104,7 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
STATUS_ONLINE = "online" # the normal status. It is used when online trading
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -87,12 +133,24 @@ class OnlineManager(Serializable):
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
|
||||
# It records the online servnig models of each strategy for each day.
|
||||
self.history = {}
|
||||
if trainer is None:
|
||||
trainer = TrainerR()
|
||||
self.trainer = trainer
|
||||
self.signals = None
|
||||
self.status = self.STATUS_NORMAL
|
||||
self.status = self.STATUS_ONLINE
|
||||
|
||||
def _postpone_action(self):
|
||||
"""
|
||||
Should the workflow to postpone the following actions to the end (in delay_prepare)
|
||||
- trainer.end_train
|
||||
- prepare_signals
|
||||
|
||||
Postpone these actions is to support simulating/backtest online strategies without time dependencies.
|
||||
All the actions can be done parallelly at the end.
|
||||
"""
|
||||
return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()
|
||||
|
||||
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
|
||||
"""
|
||||
@@ -113,12 +171,12 @@ class OnlineManager(Serializable):
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
|
||||
# FIXME: Train multiple online models at `first_train` will result in getting too much online models at the
|
||||
# start.
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
if not self._postpone_action():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
@@ -160,10 +218,10 @@ class OnlineManager(Serializable):
|
||||
|
||||
# The online model may changes in the above processes
|
||||
# So updating the predictions of online models should be the last step
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
if self.status == self.STATUS_ONLINE:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
if not self._postpone_action():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
@@ -278,13 +336,13 @@ class OnlineManager(Serializable):
|
||||
signal_kwargs=signal_kwargs,
|
||||
)
|
||||
# delay prepare the models and signals
|
||||
if self.trainer.is_delay():
|
||||
if self._postpone_action():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
|
||||
# FIXME: get logging level firstly and restore it here
|
||||
set_global_logger_level(logging.DEBUG)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
self.status = self.STATUS_NORMAL
|
||||
self.status = self.STATUS_ONLINE
|
||||
return self.get_signals()
|
||||
|
||||
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
|
||||
@@ -295,6 +353,8 @@ class OnlineManager(Serializable):
|
||||
model_kwargs: the params for `end_train`
|
||||
signal_kwargs: the params for `prepare_signals`
|
||||
"""
|
||||
# FIXME:
|
||||
# This method is not implemented in the proper way!!!
|
||||
last_models = {}
|
||||
signals_time = D.calendar()[0]
|
||||
need_prepare = False
|
||||
|
||||
@@ -9,6 +9,9 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Union, List
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -16,7 +19,7 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..utils import flatten_dict, class_casting
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
@@ -32,6 +35,7 @@ class RecordTemp:
|
||||
"""
|
||||
|
||||
artifact_path = None
|
||||
depend_cls = None # the depend class of the record; the record will depend on the results generated by `depend_cls`
|
||||
|
||||
@classmethod
|
||||
def get_path(cls, path=None):
|
||||
@@ -44,6 +48,16 @@ class RecordTemp:
|
||||
|
||||
return "/".join(names)
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""
|
||||
It behaves the same as self.recorder.save_objects.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
"""
|
||||
art_path = self.get_path()
|
||||
if art_path == "":
|
||||
art_path = None
|
||||
self.recorder.save_objects(artifact_path=art_path, **kwargs)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self._recorder = recorder
|
||||
|
||||
@@ -66,31 +80,37 @@ class RecordTemp:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def load(self, name):
|
||||
def load(self, name: str, parents: bool = True):
|
||||
"""
|
||||
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
|
||||
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
|
||||
in the future::
|
||||
|
||||
sar = SigAnaRecord(recorder)
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
parents : bool
|
||||
Each recorder has different `artifact_path`.
|
||||
So parents recursively find the path in parents
|
||||
Sub classes has higher priority
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
# try to load the saved object
|
||||
obj = self.recorder.load_object(name)
|
||||
return obj
|
||||
try:
|
||||
return self.recorder.load_object(self.get_path(name))
|
||||
except LoadObjectError:
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
return self.load(name, parents=True)
|
||||
|
||||
def list(self):
|
||||
"""
|
||||
List the supported artifacts.
|
||||
Users don't have to consider self.get_path
|
||||
|
||||
Return
|
||||
------
|
||||
@@ -98,21 +118,45 @@ class RecordTemp:
|
||||
"""
|
||||
return []
|
||||
|
||||
def check(self, cls="self"):
|
||||
def check(self, include_self: bool = False, parents: bool = True):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
It is useful in following examples
|
||||
- checking if the depended files complete before genrating new things.
|
||||
- checking if the final files is completed
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_self : bool
|
||||
is the file generated by self included
|
||||
parents : bool
|
||||
will we check parents
|
||||
|
||||
Raise
|
||||
------
|
||||
FileExistsError: whether the records are stored properly.
|
||||
FileNotFoundError
|
||||
: whether the records are stored properly.
|
||||
"""
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if cls == "self":
|
||||
cls = self
|
||||
flist = cls.list()
|
||||
for item in flist:
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
if include_self:
|
||||
|
||||
# Some mlflow backend will not list the directly recursively.
|
||||
# So we force to the directly
|
||||
artifacts = {}
|
||||
|
||||
def _get_arts(dirn):
|
||||
if dirn not in artifacts:
|
||||
artifacts[dirn] = self.recorder.list_artifacts(dirn)
|
||||
return artifacts[dirn]
|
||||
|
||||
for item in self.list():
|
||||
ps = self.get_path(item).split("/")
|
||||
dirn, fn = "/".join(ps[:-1]), ps[-1]
|
||||
if self.get_path(item) not in _get_arts(dirn):
|
||||
raise FileNotFoundError
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
@@ -127,26 +171,20 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
@staticmethod
|
||||
def generate_label(dataset):
|
||||
# NOTE:
|
||||
# Python doesn't provide the downcasting mechanism.
|
||||
# We use the trick here to downcast the class
|
||||
orig_cls = dataset.__class__
|
||||
dataset.__class__ = DatasetH
|
||||
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
dataset.__class__ = orig_cls
|
||||
with class_casting(dataset, DatasetH):
|
||||
params = dict(segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
|
||||
try:
|
||||
# Assume the backend handler is DataHandlerLP
|
||||
raw_label = dataset.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
return raw_label
|
||||
|
||||
def generate(self, **kwargs):
|
||||
@@ -154,7 +192,7 @@ class SignalRecord(RecordTemp):
|
||||
pred = self.model.predict(self.dataset)
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
self.save(**{"pred.pkl": pred})
|
||||
|
||||
logger.info(
|
||||
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
@@ -165,15 +203,11 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
raw_label = self.generate_label(self.dataset)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.save(**{"label.pkl": raw_label})
|
||||
|
||||
@staticmethod
|
||||
def list():
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
def load(self, name="pred.pkl"):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
@@ -214,19 +248,11 @@ class HFSignalRecord(SignalRecord):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
self.get_path("long_short_r.pkl"),
|
||||
self.get_path("long_avg_r.pkl"),
|
||||
]
|
||||
return paths
|
||||
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
|
||||
|
||||
|
||||
class SigAnaRecord(RecordTemp):
|
||||
@@ -235,16 +261,26 @@ class SigAnaRecord(RecordTemp):
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
pre_class = SignalRecord
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
|
||||
super().__init__(recorder=recorder)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
self.label_col = label_col
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
def generate(self, **kwargs):
|
||||
self.check(self.pre_class)
|
||||
if self.skip_existing:
|
||||
try:
|
||||
self.check(include_self=True, parents=False)
|
||||
except FileNotFoundError:
|
||||
pass # continue to generating metrics
|
||||
else:
|
||||
logger.info("The results has previously generated, generation skipped.")
|
||||
return
|
||||
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
@@ -276,13 +312,13 @@ class SigAnaRecord(RecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
if self.ana_long_short:
|
||||
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
|
||||
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
|
||||
return paths
|
||||
|
||||
|
||||
@@ -369,17 +405,11 @@ class PortAnaRecord(RecordTemp):
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -401,9 +431,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -428,9 +456,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -442,20 +468,19 @@ class PortAnaRecord(RecordTemp):
|
||||
for _freq in self.all_freq:
|
||||
list_path.extend(
|
||||
[
|
||||
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
|
||||
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
|
||||
f"report_normal_{_freq}.pkl",
|
||||
f"positions_normal_{_freq}.pkl",
|
||||
]
|
||||
)
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
return list_path
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from qlib.utils.serial import Serializable
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
@@ -8,8 +9,10 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.utils.paral import AsyncCaller
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
from ..log import TimeInspector, get_module_logger
|
||||
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -65,6 +68,8 @@ class Recorder:
|
||||
Save objects such as prediction file or model checkpoints to the artifact URI. User
|
||||
can save object through keywords arguments (name:value).
|
||||
|
||||
Please refer to the docs of qlib.workflow:R.save_objects
|
||||
|
||||
Parameters
|
||||
----------
|
||||
local_path : str
|
||||
@@ -225,6 +230,7 @@ class MLflowRecorder(Recorder):
|
||||
if mlflow_run.info.end_time is not None
|
||||
else None
|
||||
)
|
||||
self.async_log = None
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
@@ -283,6 +289,10 @@ class MLflowRecorder(Recorder):
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
|
||||
# NOTE: making logging async.
|
||||
# - This may cause delay when uploading results
|
||||
# - The logging time may not be accurate
|
||||
self.async_log = AsyncCaller()
|
||||
return run
|
||||
|
||||
def end_run(self, status: str = Recorder.STATUS_S):
|
||||
@@ -296,6 +306,9 @@ class MLflowRecorder(Recorder):
|
||||
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
with TimeInspector.logt("waiting `async_log`"):
|
||||
self.async_log.wait()
|
||||
self.async_log = None
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
@@ -331,18 +344,27 @@ class MLflowRecorder(Recorder):
|
||||
try:
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
data = pickle.load(f)
|
||||
ar = self.client._tracking_client._get_artifact_repo(self.id)
|
||||
if isinstance(ar, AzureBlobArtifactRepository):
|
||||
# for saving disk space
|
||||
# For safety, only remove redundant file for specific ArtifactRepository
|
||||
shutil.rmtree(Path(path).absolute().parent)
|
||||
return data
|
||||
except Exception as e:
|
||||
raise LoadObjectError(message=str(e))
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_params(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_param(self.id, name, data)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.set_tag(self.id, name, data)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user