mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 09:01:18 +08:00
Compare commits
47 Commits
v0.9.2
...
xuyang1/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2df211c320 | ||
|
|
effed382e9 | ||
|
|
86ffd1799d | ||
|
|
aef11536e3 | ||
|
|
8b0fdf1623 | ||
|
|
9a36f8da20 | ||
|
|
b7757d5008 | ||
|
|
ee5e5cfdd8 | ||
|
|
6cb87ecfd1 | ||
|
|
9119bcdd3c | ||
|
|
4fccf8112d | ||
|
|
73bd79ca1a | ||
|
|
7e84f3aae2 | ||
|
|
1326ac614d | ||
|
|
f12184cc0f | ||
|
|
a70386ad52 | ||
|
|
74619ed8d8 | ||
|
|
1a523df007 | ||
|
|
f9cc8a5aaa | ||
|
|
7762c5a1fd | ||
|
|
fa7ef29281 | ||
|
|
429c9a7c66 | ||
|
|
80fbc00792 | ||
|
|
01accec24c | ||
|
|
1d88830b0d | ||
|
|
ad7498e287 | ||
|
|
73d51f05b4 | ||
|
|
3b56b8e6c0 | ||
|
|
40e0c329ba | ||
|
|
e376648860 | ||
|
|
5f37f32184 | ||
|
|
d46b4c1ebf | ||
|
|
0515524b51 | ||
|
|
cda32d5703 | ||
|
|
e2332a004b | ||
|
|
08d9dbccc9 | ||
|
|
e7cd93a36d | ||
|
|
3919678028 | ||
|
|
421b1403b2 | ||
|
|
94102fb742 | ||
|
|
74a5d7c8af | ||
|
|
ce39b4b6f8 | ||
|
|
2af35d9c89 | ||
|
|
f37643550b | ||
|
|
55611aa43e | ||
|
|
f24253efd2 | ||
|
|
7c4f3b8a7d |
5
.github/release-drafter.yml
vendored
5
.github/release-drafter.yml
vendored
@@ -14,9 +14,6 @@ categories:
|
||||
label:
|
||||
- 'doc'
|
||||
- 'documentation'
|
||||
- title: '🧹 Maintenance'
|
||||
label:
|
||||
- 'maintenance'
|
||||
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
|
||||
change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks.
|
||||
version-resolver:
|
||||
@@ -33,4 +30,4 @@ version-resolver:
|
||||
template: |
|
||||
## Changes
|
||||
|
||||
$CHANGES
|
||||
$CHANGES
|
||||
21
.github/workflows/test_qlib_from_source.yml
vendored
21
.github/workflows/test_qlib_from_source.yml
vendored
@@ -20,28 +20,18 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Update pip to the latest version
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install pip==23.0.1
|
||||
|
||||
- name: Installing pytorch for macos
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
@@ -139,7 +129,8 @@ jobs:
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Install Lightgbm for MacOS
|
||||
if: ${{ matrix.os == 'macos-11' || matrix.os == 'macos-latest' }}
|
||||
|
||||
18
.github/workflows/test_qlib_from_source_slow.yml
vendored
18
.github/workflows/test_qlib_from_source_slow.yml
vendored
@@ -20,28 +20,18 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Test qlib from source slow
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error".
|
||||
# So we make the version number of python 3.7 for MacOS more specific.
|
||||
# refs: https://github.com/actions/setup-python/issues/682
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-11' && matrix.python-version == '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.7.16"
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-11' || matrix.python-version != '3.7')
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Set up Python tools
|
||||
# pip release version 23.1 on Apr.15 2023, CI failed to run, Please refer to #1495 ofr detailed logs.
|
||||
# The pip version has been temporarily fixed to 23.0
|
||||
# The pip version has been temporarily fixed to 23.0.1
|
||||
run: |
|
||||
python -m pip install pip==23.0
|
||||
python -m pip install pip==23.0.1
|
||||
pip install --upgrade cython numpy
|
||||
pip install -e .[dev]
|
||||
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -22,6 +22,7 @@ dist/
|
||||
qlib/VERSION.txt
|
||||
qlib/data/_libs/expanding.cpp
|
||||
qlib/data/_libs/rolling.cpp
|
||||
qlib/finco/prompt_cache.json
|
||||
examples/estimator/estimator_example/
|
||||
examples/rl/data/
|
||||
examples/rl/checkpoints/
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| KRNN and Sandwich models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1414/) on May 26, 2023 |
|
||||
| Release Qlib v0.9.0 | :octocat: [Released](https://github.com/microsoft/qlib/releases/tag/v0.9.0) on Dec 9, 2022 |
|
||||
| RL Learning Framework | :hammer: :chart_with_upwards_trend: Released on Nov 10, 2022. [#1332](https://github.com/microsoft/qlib/pull/1332), [#1322](https://github.com/microsoft/qlib/pull/1322), [#1316](https://github.com/microsoft/qlib/pull/1316),[#1299](https://github.com/microsoft/qlib/pull/1299),[#1263](https://github.com/microsoft/qlib/pull/1263), [#1244](https://github.com/microsoft/qlib/pull/1244), [#1169](https://github.com/microsoft/qlib/pull/1169), [#1125](https://github.com/microsoft/qlib/pull/1125), [#1076](https://github.com/microsoft/qlib/pull/1076)|
|
||||
| HIST and IGMTF models | :chart_with_upwards_trend: [Released](https://github.com/microsoft/qlib/pull/1040) on Apr 10, 2022 |
|
||||
@@ -354,8 +353,6 @@ Here is a list of models built on `Qlib`.
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [IGMTF based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/IGMTF/)
|
||||
- [HIST based on pytorch (Wentao Xu, et al.2021)](examples/benchmarks/HIST/)
|
||||
- [KRNN based on pytorch](examples/benchmarks/KRNN/)
|
||||
- [Sandwich based on pytorch](examples/benchmarks/Sandwich/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ Here are some example:
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py download_data --file_name csv_data_cn.zip --target_dir ~/.qlib/csv_data/cn_data
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# KRNN
|
||||
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py](https://github.com/microsoft/FOST/blob/main/fostool/model/krnn.py)
|
||||
|
||||
|
||||
# Introductions about the settings/configs.
|
||||
* Torch_geometric is used in the original model in FOST, but we didn't use it.
|
||||
* make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.__version__==1.12.1
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
numpy==1.23.4
|
||||
pandas==1.5.2
|
||||
@@ -1,91 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: KRNN
|
||||
module_path: qlib.contrib.model.pytorch_krnn
|
||||
kwargs:
|
||||
fea_dim: 6
|
||||
cnn_dim: 8
|
||||
cnn_kernel_size: 3
|
||||
rnn_dim: 8
|
||||
rnn_dups: 2
|
||||
rnn_layers: 2
|
||||
n_epochs: 200
|
||||
lr: 0.001
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -26,7 +26,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0279±0.00 | 0.2181±0.01 | 0.0421±0.00 | 0.3429±0.01 | 0.0262±0.02 | 0.4133±0.25 | -0.1090±0.03 |
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
|
||||
@@ -68,8 +68,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
| IGMTF(Wentao Xu, et al.) | Alpha360 | 0.0480±0.00 | 0.3589±0.02 | 0.0606±0.00 | 0.4773±0.01 | 0.0946±0.02 | 1.3509±0.25 | -0.0716±0.02 |
|
||||
| HIST(Wentao Xu, et al.) | Alpha360 | 0.0522±0.00 | 0.3530±0.01 | 0.0667±0.00 | 0.4576±0.01 | 0.0987±0.02 | 1.3726±0.27 | -0.0681±0.01 |
|
||||
| KRNN | Alpha360 | 0.0173±0.01 | 0.1210±0.06 | 0.0270±0.01 | 0.2018±0.04 | -0.0465±0.05 | -0.5415±0.62 | -0.2919±0.13 |
|
||||
| Sandwich | Alpha360 | 0.0258±0.00 | 0.1924±0.04 | 0.0337±0.00 | 0.2624±0.03 | 0.0005±0.03 | 0.0001±0.33 | -0.1752±0.05 |
|
||||
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
@@ -136,7 +134,7 @@ If you want to contribute your new models, you can follow the steps below.
|
||||
- `README.md`: a brief introduction to your models
|
||||
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
|
||||
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
|
||||
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
|
||||
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
|
||||
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
|
||||
|
||||
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# Sandwich
|
||||
* Code: [https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py](https://github.com/microsoft/FOST/blob/main/fostool/model/sandwich.py)
|
||||
|
||||
|
||||
# Introductions about the settings/configs.
|
||||
* Torch_geometric is used in the original model in FOST, but we didn't use it.
|
||||
make use your CUDA version matches the torch version to allow the usage of GPU, we use CUDA==10.2 and torch.version==1.12.1
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
numpy==1.23.4
|
||||
pandas==1.5.2
|
||||
@@ -1,93 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: Sandwich
|
||||
module_path: qlib.contrib.model.pytorch_sandwich
|
||||
kwargs:
|
||||
fea_dim: 6
|
||||
cnn_dim_1: 16
|
||||
cnn_dim_2: 16
|
||||
cnn_kernel_size: 3
|
||||
rnn_dim_1: 8
|
||||
rnn_dim_2: 8
|
||||
rnn_dups: 2
|
||||
rnn_layers: 2
|
||||
n_epochs: 200
|
||||
lr: 0.001
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
.PHONY: clean
|
||||
|
||||
clean:
|
||||
-rm -r *.pkl mlruns || true
|
||||
@@ -34,14 +34,14 @@ class DDGDA:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sim_task_model: Literal["linear", "gbdt"] = "gbdt",
|
||||
sim_task_model: Literal["linear", "gbdt"] = "linear",
|
||||
forecast_model: Literal["linear", "gbdt"] = "linear",
|
||||
h_path: Optional[str] = None,
|
||||
test_end: Optional[str] = None,
|
||||
train_start: Optional[str] = None,
|
||||
meta_1st_train_end: Optional[str] = None,
|
||||
task_ext_conf: Optional[dict] = None,
|
||||
alpha: float = 0.01,
|
||||
alpha: float = 0.0,
|
||||
proxy_hd: str = "handler_proxy.pkl",
|
||||
):
|
||||
"""
|
||||
@@ -116,9 +116,7 @@ class DDGDA:
|
||||
|
||||
feature_selected = feature_df.loc[:, col_selected.index]
|
||||
|
||||
feature_selected = feature_selected.groupby("datetime", group_keys=False).apply(
|
||||
lambda df: (df - df.mean()).div(df.std())
|
||||
)
|
||||
feature_selected = feature_selected.groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
|
||||
feature_selected = feature_selected.fillna(0.0)
|
||||
|
||||
df_all = {
|
||||
@@ -170,8 +168,7 @@ class DDGDA:
|
||||
# - Only the dataset part is important, in current version of meta model will integrate the
|
||||
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
|
||||
sim_task = rb.basic_task()
|
||||
# the train_start for training meta model does not necessarily align with final rolling
|
||||
train_start = "2008-01-01" if self.rb_kwargs.get("train_start") is None else self.rb_kwargs.get("train_start")
|
||||
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
|
||||
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
|
||||
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
proxy_forecast_model_task = {
|
||||
@@ -215,7 +212,7 @@ class DDGDA:
|
||||
with R.start(experiment_name=self.meta_exp_name):
|
||||
R.log_params(**kwargs)
|
||||
mm = MetaModelDS(
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=30, seed=43, alpha=self.alpha
|
||||
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
|
||||
)
|
||||
mm.fit(md)
|
||||
R.save_objects(model=mm)
|
||||
|
||||
@@ -8,17 +8,15 @@ The table below shows the performances of different solutions on different forec
|
||||
Here is the [crowd sourced version of qlib data](data_collector/crowd_source/README.md): https://github.com/chenditc/investment_data/releases
|
||||
```bash
|
||||
wget https://github.com/chenditc/investment_data/releases/download/20220720/qlib_bin.tar.gz
|
||||
mkdir -p ~/.qlib/qlib_data/cn_data
|
||||
tar -zxvf qlib_bin.tar.gz -C ~/.qlib/qlib_data/cn_data --strip-components=2
|
||||
rm -f qlib_bin.tar.gz
|
||||
```
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------|---------|------|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.0945|0.5989|0.1069 |0.6495 |0.0857 |1.3682 |-0.0986 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.0983|0.6157|0.1108 |0.6646 |0.0764 |1.1904 |-0.0769 |
|
||||
| RR[LightGBM] |Alpha158 |0.0816|0.5887|0.0912 |0.6263 |0.0771 |1.3196 |-0.0909 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.0878|0.6185|0.0975 |0.6524 |0.1261 |2.0096 |-0.0744 |
|
||||
|------------------|---------|----|------|---------|-----------|-------------------|-------------------|--------------|
|
||||
| RR[Linear] |Alpha158 |0.089|0.577|0.102 |0.627 |0.093 |1.458 |-0.073 |
|
||||
| DDG-DA[Linear] |Alpha158 |0.096|0.636|0.107 |0.677 |0.067 |0.996 |-0.091 |
|
||||
| RR[LightGBM] |Alpha158 |0.082|0.589|0.091 |0.626 |0.077 |1.320 |-0.091 |
|
||||
| DDG-DA[LightGBM] |Alpha158 |0.085|0.658|0.094 |0.686 |0.115 |1.792 |-0.068 |
|
||||
|
||||
- The label horizon of the `Alpha158` dataset is set to 20.
|
||||
- The rolling time intervals are set to 20 trading days.
|
||||
|
||||
@@ -67,12 +67,11 @@ class RollingBenchmark:
|
||||
def basic_task(self):
|
||||
"""For fast training rolling"""
|
||||
if self.model_type == "gbdt":
|
||||
conf_path = DIRNAME / "workflow_config_lightgbm_Alpha158.yaml"
|
||||
conf_path = DIRNAME.parent.parent / "benchmarks" / "LightGBM" / "workflow_config_lightgbm_Alpha158.yaml"
|
||||
# dump the processed data on to disk for later loading to speed up the processing
|
||||
h_path = DIRNAME / "lightgbm_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
elif self.model_type == "linear":
|
||||
# We use ridge regression to stabilize the performance
|
||||
conf_path = DIRNAME / "workflow_config_linear_Alpha158.yaml"
|
||||
conf_path = DIRNAME.parent.parent / "benchmarks" / "Linear" / "workflow_config_linear_Alpha158.yaml"
|
||||
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
|
||||
else:
|
||||
raise AssertionError("Model type is not supported!")
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ridge
|
||||
alpha: 0.05
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.9.2"
|
||||
__version__ = "0.9.1.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
111
qlib/contrib/analyzer.py
Normal file
111
qlib/contrib/analyzer.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
logger = get_module_logger("analysis", logging.INFO)
|
||||
|
||||
|
||||
class AnalyzerTemp:
|
||||
def __init__(self, recorder, output_dir=None, **kwargs):
|
||||
self.recorder = recorder
|
||||
self.output_dir = Path(output_dir) if output_dir else "./"
|
||||
|
||||
def load(self, name: str):
|
||||
"""
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
return self.recorder.load_object(name)
|
||||
|
||||
def analyse(self, **kwargs):
|
||||
"""
|
||||
Analyse data index, distribution .etc
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
|
||||
Return
|
||||
------
|
||||
The handled data.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class HFAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "HFAnalyzerTable.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self):
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Average Return": long_short_r.mean(),
|
||||
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
|
||||
}
|
||||
)
|
||||
|
||||
table = [[k, v] for (k, v) in metrics.items()]
|
||||
plt.table(cellText=table, loc="center")
|
||||
plt.axis("off")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzerTable.jpeg"))
|
||||
plt.clf()
|
||||
|
||||
plt.scatter(np.arange(0, len(pred)), pred.iloc[:, 0])
|
||||
plt.scatter(np.arange(0, len(label)), label.iloc[:, 0])
|
||||
plt.title("HFAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("HFAnalyzer.jpeg"))
|
||||
return "HFAnalyzer.jpeg"
|
||||
|
||||
|
||||
class SignalAnalyzer(AnalyzerTemp):
|
||||
"""
|
||||
This is the Signal Analysis class that generates the analysis results such as IC and IR.
|
||||
|
||||
default output image filename is "signalAnalysis.jpeg"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def analyse(self, dataset=None, **kwargs):
|
||||
label = self.load("label.pkl")
|
||||
|
||||
plt.hist(label)
|
||||
plt.title("SignalAnalyzer")
|
||||
plt.savefig(self.output_dir.joinpath("signalAnalysis.jpeg"))
|
||||
|
||||
return "signalAnalysis.jpeg"
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Optional
|
||||
from qlib.utils.data import update_config
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
@@ -57,12 +59,13 @@ class Alpha360(DataHandlerLP):
|
||||
fit_end_time=None,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -74,12 +77,14 @@ class Alpha360(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
**kwargs
|
||||
@@ -153,12 +158,13 @@ class Alpha158(DataHandlerLP):
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
filter_pipe=None,
|
||||
inst_processors=None,
|
||||
data_loader: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
_data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
@@ -170,11 +176,13 @@ class Alpha158(DataHandlerLP):
|
||||
"inst_processors": inst_processors,
|
||||
},
|
||||
}
|
||||
if data_loader is not None:
|
||||
update_config(_data_loader, data_loader)
|
||||
super().__init__(
|
||||
instruments=instruments,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
data_loader=_data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
|
||||
@@ -1,511 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
########################################################################
|
||||
########################################################################
|
||||
########################################################################
|
||||
|
||||
|
||||
class CNNEncoderBase(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, kernel_size, device):
|
||||
"""Build a basic CNN encoder
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
The input dimension
|
||||
output_dim : int
|
||||
The output dimension
|
||||
kernel_size : int
|
||||
The size of convolutional kernels
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.kernel_size = kernel_size
|
||||
self.device = device
|
||||
|
||||
# set padding to ensure the same length
|
||||
# it is correct only when kernel_size is odd, dilation is 1, stride is 1
|
||||
self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=(kernel_size - 1) // 2)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
input data
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
|
||||
# input shape: [batch_size, seq_len*input_dim]
|
||||
# output shape: [batch_size, seq_len, input_dim]
|
||||
x = x.view(x.shape[0], -1, self.input_dim).permute(0, 2, 1).to(self.device)
|
||||
y = self.conv(x) # [batch_size, output_dim, conved_seq_len]
|
||||
y = y.permute(0, 2, 1) # [batch_size, conved_seq_len, output_dim]
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class KRNNEncoderBase(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, dup_num, rnn_layers, dropout, device):
|
||||
"""Build K parallel RNNs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
The input dimension
|
||||
output_dim : int
|
||||
The output dimension
|
||||
dup_num : int
|
||||
The number of parallel RNNs
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.dup_num = dup_num
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.device = device
|
||||
|
||||
self.rnn_modules = nn.ModuleList()
|
||||
for _ in range(dup_num):
|
||||
self.rnn_modules.append(nn.GRU(input_dim, output_dim, num_layers=self.rnn_layers, dropout=dropout))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input data
|
||||
n_id : torch.Tensor
|
||||
Node indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
|
||||
# input shape: [batch_size, seq_len, input_dim]
|
||||
# output shape: [batch_size, seq_len, output_dim]
|
||||
# [seq_len, batch_size, input_dim]
|
||||
batch_size, seq_len, input_dim = x.shape
|
||||
x = x.permute(1, 0, 2).to(self.device)
|
||||
|
||||
hids = []
|
||||
for rnn in self.rnn_modules:
|
||||
h, _ = rnn(x) # [seq_len, batch_size, output_dim]
|
||||
hids.append(h)
|
||||
# [seq_len, batch_size, output_dim, num_dups]
|
||||
hids = torch.stack(hids, dim=-1)
|
||||
hids = hids.view(seq_len, batch_size, self.output_dim, self.dup_num)
|
||||
hids = hids.mean(dim=3)
|
||||
hids = hids.permute(1, 0, 2)
|
||||
|
||||
return hids
|
||||
|
||||
|
||||
class CNNKRNNEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, cnn_input_dim, cnn_output_dim, cnn_kernel_size, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device
|
||||
):
|
||||
"""Build an encoder composed of CNN and KRNN
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cnn_input_dim : int
|
||||
The input dimension of CNN
|
||||
cnn_output_dim : int
|
||||
The output dimension of CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_output_dim : int
|
||||
The output dimension of KRNN
|
||||
rnn_dup_num : int
|
||||
The number of parallel duplicates for KRNN
|
||||
rnn_layers : int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.cnn_encoder = CNNEncoderBase(cnn_input_dim, cnn_output_dim, cnn_kernel_size, device)
|
||||
self.krnn_encoder = KRNNEncoderBase(cnn_output_dim, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input data
|
||||
n_id : torch.Tensor
|
||||
Node indices
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Updated representations
|
||||
"""
|
||||
cnn_out = self.cnn_encoder(x)
|
||||
krnn_out = self.krnn_encoder(cnn_out)
|
||||
|
||||
return krnn_out
|
||||
|
||||
|
||||
class KRNNModel(nn.Module):
|
||||
def __init__(self, fea_dim, cnn_dim, cnn_kernel_size, rnn_dim, rnn_dups, rnn_layers, dropout, device, **params):
|
||||
"""Build a KRNN model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fea_dim : int
|
||||
The feature dimension
|
||||
cnn_dim : int
|
||||
The hidden dimension of CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_dim : int
|
||||
The hidden dimension of KRNN
|
||||
rnn_dups : int
|
||||
The number of parallel duplicates
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=fea_dim,
|
||||
cnn_output_dim=cnn_dim,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.out_fc = nn.Linear(rnn_dim, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
# x: [batch_size, node_num, seq_len, input_dim]
|
||||
encode = self.encoder(x)
|
||||
out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class KRNN(Model):
|
||||
"""KRNN Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim=6,
|
||||
cnn_dim=64,
|
||||
cnn_kernel_size=3,
|
||||
rnn_dim=64,
|
||||
rnn_dups=3,
|
||||
rnn_layers=2,
|
||||
dropout=0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("KRNN")
|
||||
self.logger.info("KRNN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.fea_dim = fea_dim
|
||||
self.cnn_dim = cnn_dim
|
||||
self.cnn_kernel_size = cnn_kernel_size
|
||||
self.rnn_dim = rnn_dim
|
||||
self.rnn_dups = rnn_dups
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"KRNN parameters setting:"
|
||||
"\nfea_dim : {}"
|
||||
"\ncnn_dim : {}"
|
||||
"\ncnn_kernel_size : {}"
|
||||
"\nrnn_dim : {}"
|
||||
"\nrnn_dups : {}"
|
||||
"\nrnn_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size: {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
fea_dim,
|
||||
cnn_dim,
|
||||
cnn_kernel_size,
|
||||
rnn_dim,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.krnn_model = KRNNModel(
|
||||
fea_dim=self.fea_dim,
|
||||
cnn_dim=self.cnn_dim,
|
||||
cnn_kernel_size=self.cnn_kernel_size,
|
||||
rnn_dim=self.rnn_dim,
|
||||
rnn_dups=self.rnn_dups,
|
||||
rnn_layers=self.rnn_layers,
|
||||
dropout=self.dropout,
|
||||
device=self.device,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.krnn_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.krnn_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.krnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.krnn_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.krnn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.krnn_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.krnn_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.krnn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.krnn_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.krnn_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.krnn_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.krnn_model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
@@ -1,381 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .pytorch_krnn import CNNKRNNEncoder
|
||||
|
||||
|
||||
class SandwichModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim,
|
||||
cnn_dim_1,
|
||||
cnn_dim_2,
|
||||
cnn_kernel_size,
|
||||
rnn_dim_1,
|
||||
rnn_dim_2,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
device,
|
||||
**params
|
||||
):
|
||||
"""Build a Sandwich model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fea_dim : int
|
||||
The feature dimension
|
||||
cnn_dim_1 : int
|
||||
The hidden dimension of the first CNN
|
||||
cnn_dim_2 : int
|
||||
The hidden dimension of the second CNN
|
||||
cnn_kernel_size : int
|
||||
The size of convolutional kernels
|
||||
rnn_dim_1 : int
|
||||
The hidden dimension of the first KRNN
|
||||
rnn_dim_2 : int
|
||||
The hidden dimension of the second KRNN
|
||||
rnn_dups : int
|
||||
The number of parallel duplicates
|
||||
rnn_layers: int
|
||||
The number of RNN layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.first_encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=fea_dim,
|
||||
cnn_output_dim=cnn_dim_1,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim_1,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.second_encoder = CNNKRNNEncoder(
|
||||
cnn_input_dim=rnn_dim_1,
|
||||
cnn_output_dim=cnn_dim_2,
|
||||
cnn_kernel_size=cnn_kernel_size,
|
||||
rnn_output_dim=rnn_dim_2,
|
||||
rnn_dup_num=rnn_dups,
|
||||
rnn_layers=rnn_layers,
|
||||
dropout=dropout,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.out_fc = nn.Linear(rnn_dim_2, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
# x: [batch_size, node_num, seq_len, input_dim]
|
||||
encode = self.first_encoder(x)
|
||||
encode = self.second_encoder(encode)
|
||||
out = self.out_fc(encode[:, -1, :]).squeeze().to(self.device)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Sandwich(Model):
|
||||
"""Sandwich Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fea_dim=6,
|
||||
cnn_dim_1=64,
|
||||
cnn_dim_2=32,
|
||||
cnn_kernel_size=3,
|
||||
rnn_dim_1=16,
|
||||
rnn_dim_2=8,
|
||||
rnn_dups=3,
|
||||
rnn_layers=2,
|
||||
dropout=0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Sandwich")
|
||||
self.logger.info("Sandwich pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.fea_dim = fea_dim
|
||||
self.cnn_dim_1 = cnn_dim_1
|
||||
self.cnn_dim_2 = cnn_dim_2
|
||||
self.cnn_kernel_size = cnn_kernel_size
|
||||
self.rnn_dim_1 = rnn_dim_1
|
||||
self.rnn_dim_2 = rnn_dim_2
|
||||
self.rnn_dups = rnn_dups
|
||||
self.rnn_layers = rnn_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"Sandwich parameters setting:"
|
||||
"\nfea_dim : {}"
|
||||
"\ncnn_dim_1 : {}"
|
||||
"\ncnn_dim_2 : {}"
|
||||
"\ncnn_kernel_size : {}"
|
||||
"\nrnn_dim_1 : {}"
|
||||
"\nrnn_dim_2 : {}"
|
||||
"\nrnn_dups : {}"
|
||||
"\nrnn_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size: {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
fea_dim,
|
||||
cnn_dim_1,
|
||||
cnn_dim_2,
|
||||
cnn_kernel_size,
|
||||
rnn_dim_1,
|
||||
rnn_dim_2,
|
||||
rnn_dups,
|
||||
rnn_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.sandwich_model = SandwichModel(
|
||||
fea_dim=self.fea_dim,
|
||||
cnn_dim_1=self.cnn_dim_1,
|
||||
cnn_dim_2=self.cnn_dim_2,
|
||||
cnn_kernel_size=self.cnn_kernel_size,
|
||||
rnn_dim_1=self.rnn_dim_1,
|
||||
rnn_dim_2=self.rnn_dim_2,
|
||||
rnn_dups=self.rnn_dups,
|
||||
rnn_layers=self.rnn_layers,
|
||||
dropout=self.dropout,
|
||||
device=self.device,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sandwich_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.sandwich_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.sandwich_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.sandwich_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sandwich_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.sandwich_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.sandwich_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sandwich_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.sandwich_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sandwich_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sandwich_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
with torch.no_grad():
|
||||
pred = self.sandwich_model(x_batch).detach().cpu().numpy()
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
@@ -168,8 +168,7 @@ class TCN(Model):
|
||||
self.TCN_model.train()
|
||||
|
||||
for data in data_loader:
|
||||
data = torch.transpose(data, 1, 2)
|
||||
feature = data[:, 0:-1, :].to(self.device)
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.TCN_model(feature.float())
|
||||
@@ -188,8 +187,8 @@ class TCN(Model):
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
data = torch.transpose(data, 1, 2)
|
||||
feature = data[:, 0:-1, :].to(self.device)
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
|
||||
18
qlib/finco/.env.example
Normal file
18
qlib/finco/.env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
OPENAI_API_KEY=your_api_key
|
||||
|
||||
# USE_AZURE=True
|
||||
# AZURE_API_BASE=your_api_base
|
||||
# AZURE_API_VERSION=your_api_version
|
||||
|
||||
# use gpt-4 means more token but more wait time
|
||||
# MODEL=gpt-4
|
||||
# MAX_TOKENS=1600
|
||||
# MAX_RETRY=1000
|
||||
|
||||
|
||||
MAX_TOKENS=1600
|
||||
MAX_RETRY=120
|
||||
|
||||
CONTINOUS_MODE=True
|
||||
DEBUG_MODE=True
|
||||
22
qlib/finco/README.md
Normal file
22
qlib/finco/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# This is an experimental branch of "`FI`nancial `CO`pilot of `Qlib`"
|
||||
|
||||
## Installation
|
||||
|
||||
- To run this module, you need to first install Qlib following the instruction in [install-from-source](/README.md#install-from-source) or follow:
|
||||
|
||||
```python
|
||||
python -m pip install git+https://github.com/microsoft/qlib.git@finco
|
||||
```
|
||||
|
||||
- then you need to install other dependencies of finco:
|
||||
```python
|
||||
python -m pip install pydantic openai python-dotenv
|
||||
```
|
||||
|
||||
## Quick run
|
||||
|
||||
To run this module, you can start the workflow easily with one command:
|
||||
|
||||
```sh
|
||||
cd qlib/finco; python cli.py "your prompt"
|
||||
```
|
||||
13
qlib/finco/__init__.py
Normal file
13
qlib/finco/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_finco_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
15
qlib/finco/cli.py
Normal file
15
qlib/finco/cli.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
wm = WorkflowManager()
|
||||
wm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
15
qlib/finco/cli_learn.py
Normal file
15
qlib/finco/cli_learn.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import fire
|
||||
from qlib.finco.workflow import LearnManager
|
||||
from dotenv import load_dotenv
|
||||
from qlib import auto_init
|
||||
|
||||
|
||||
def main(prompt=None):
|
||||
load_dotenv(verbose=True, override=True)
|
||||
lm = LearnManager()
|
||||
lm.run(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
auto_init()
|
||||
fire.Fire(main)
|
||||
32
qlib/finco/conf.py
Normal file
32
qlib/finco/conf.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# TODO: use pydantic for other modules in Qlib
|
||||
from pydantic import BaseSettings
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class Config(SingletonBaseClass):
|
||||
"""
|
||||
This config is for fast demo purpose.
|
||||
Please use BaseSettings insetead in the future
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.temperature = 0.5 if os.getenv("TEMPERATURE") is None else float(os.getenv("TEMPERATURE"))
|
||||
self.max_tokens = 800 if os.getenv("MAX_TOKENS") is None else int(os.getenv("MAX_TOKENS"))
|
||||
|
||||
self.openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.use_azure = os.getenv("USE_AZURE") == "True"
|
||||
self.azure_api_base = os.getenv("AZURE_API_BASE")
|
||||
self.azure_api_version = os.getenv("AZURE_API_VERSION")
|
||||
self.model = os.getenv("MODEL") or ("gpt-35-turbo" if self.use_azure else "gpt-3.5-turbo")
|
||||
|
||||
self.max_retry = int(os.getenv("MAX_RETRY")) if os.getenv("MAX_RETRY") is not None else None
|
||||
|
||||
self.continuous_mode = (
|
||||
os.getenv("CONTINOUS_MODE") == "True" if os.getenv("CONTINOUS_MODE") is not None else False
|
||||
)
|
||||
self.debug_mode = os.getenv("DEBUG_MODE") == "True" if os.getenv("DEBUG_MODE") is not None else False
|
||||
self.workspace = os.getenv("WORKSPACE") if os.getenv("WORKSPACE") is not None else "./finco_workspace"
|
||||
self.max_past_message_include = int(os.getenv("MAX_PAST_MESSAGE_INCLUDE") or 6) // 2 * 2
|
||||
156
qlib/finco/knowledge.py
Normal file
156
qlib/finco/knowledge.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
from typing import List
|
||||
|
||||
from qlib.workflow import R
|
||||
from qlib.finco.log import FinCoLog
|
||||
from qlib.finco.llm import APIBackend
|
||||
|
||||
|
||||
class Knowledge:
|
||||
"""
|
||||
Use to handle knowledge in finCo such as experiment and outside domain information
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def load(self, **kwargs):
|
||||
"""
|
||||
Load knowledge in memory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
def brief(self, **kwargs):
|
||||
"""
|
||||
Return a brief summary of knowledge
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Return
|
||||
------
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `load` method.")
|
||||
|
||||
|
||||
class KnowledgeExperiment(Knowledge):
|
||||
"""
|
||||
Handle knowledge from experiments
|
||||
"""
|
||||
|
||||
def __init__(self, exp_name, rec_id=None):
|
||||
super().__init__()
|
||||
self.exp_name = exp_name
|
||||
self.exp = None
|
||||
self.recs = []
|
||||
|
||||
self.load(exp_name=exp_name, rec_id=rec_id)
|
||||
|
||||
def load(self, exp_name, rec_id=None):
|
||||
recs = []
|
||||
self.exp = R.get_exp(experiment_name=exp_name)
|
||||
for r in self.exp.list_recorders(rtype=self.exp.RT_L):
|
||||
if rec_id is not None and r.id != rec_id:
|
||||
continue
|
||||
recs.append(r)
|
||||
self.recs.extend(recs)
|
||||
|
||||
def brief(self):
|
||||
docs = []
|
||||
for recorder in self.recs:
|
||||
docs.append({"exp_name": self.exp.name, "record_info": recorder.info,
|
||||
"config": recorder.load_object("config"),
|
||||
"context_summary": recorder.load_object("context_summary")})
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class Topic:
|
||||
|
||||
def __init__(self, name: str, describe: Template):
|
||||
self.name = name
|
||||
self.describe = describe
|
||||
self.docs = []
|
||||
self.knowledge = None
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def summarize(self, docs: list):
|
||||
self.logger.info(f"Summarize topic: \nname: {self.name}\ndescribe: {self.describe.module}")
|
||||
prompt_workflow_selection = self.describe.render(docs=docs)
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection
|
||||
)
|
||||
|
||||
self.knowledge = response
|
||||
self.docs = docs
|
||||
|
||||
|
||||
class KnowledgeBase:
|
||||
"""
|
||||
Load knowledge, offer brief information of knowledge and common handle interfaces
|
||||
"""
|
||||
|
||||
def __init__(self, init_path=None, topics: List[Topic] = None):
|
||||
self.logger = FinCoLog()
|
||||
init_path = init_path if init_path else Path.cwd()
|
||||
|
||||
if not init_path.exists():
|
||||
self.logger.warning(f"{init_path} not exist, create empty directory.")
|
||||
Path.mkdir(init_path)
|
||||
|
||||
self.knowledge = self.load(path=init_path)
|
||||
|
||||
# todo: replace list with persistent storage strategy such as ES/pinecone to enable
|
||||
# literal search/semantic search
|
||||
self.docs = self.brief(knowledge=self.knowledge)
|
||||
|
||||
self.topics = topics if topics else []
|
||||
|
||||
def load(self, path) -> List:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
knowledge = []
|
||||
path = path if path.name == "mlruns" else path.joinpath("mlruns")
|
||||
R.set_uri(path.as_uri())
|
||||
for exp_name in R.list_experiments():
|
||||
knowledge.append(KnowledgeExperiment(exp_name=exp_name))
|
||||
|
||||
self.logger.plain_info(f"Load knowledge from: {path} finished.")
|
||||
return knowledge
|
||||
|
||||
def update(self, path):
|
||||
# note: only update new knowledge in future
|
||||
knowledge = self.load(path)
|
||||
self.knowledge = knowledge
|
||||
self.docs = self.brief(self.knowledge)
|
||||
self.logger.plain_info(f"Update knowledge finished.")
|
||||
|
||||
def brief(self, knowledge: List[Knowledge]) -> List:
|
||||
docs = []
|
||||
for k in knowledge:
|
||||
docs.extend(k.brief())
|
||||
|
||||
self.logger.plain_info(f"Generate brief knowledge summary finished.")
|
||||
return docs
|
||||
|
||||
def query(self, content: str = None):
|
||||
# todo: query by DSL
|
||||
return self.docs
|
||||
|
||||
def query_topics(self):
|
||||
knowledge_of_topics = []
|
||||
for topic in self.topics:
|
||||
knowledge_of_topics.append({topic.name: topic.knowledge})
|
||||
return knowledge_of_topics
|
||||
|
||||
def summarize_by_topic(self):
|
||||
for topic in self.topics:
|
||||
topic.summarize(self.docs)
|
||||
111
qlib/finco/llm.py
Normal file
111
qlib/finco/llm.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import os
|
||||
import time
|
||||
import openai
|
||||
import json
|
||||
from typing import Optional
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco.log import FinCoLog
|
||||
|
||||
|
||||
class APIBackend(SingletonBaseClass):
|
||||
def __init__(self):
|
||||
self.cfg = Config()
|
||||
openai.api_key = self.cfg.openai_api_key
|
||||
if self.cfg.use_azure:
|
||||
openai.api_type = "azure"
|
||||
openai.api_base = self.cfg.azure_api_base
|
||||
openai.api_version = self.cfg.azure_api_version
|
||||
self.use_azure = self.cfg.use_azure
|
||||
|
||||
self.debug_mode = False
|
||||
if self.cfg.debug_mode:
|
||||
self.debug_mode = True
|
||||
cwd = os.getcwd()
|
||||
self.cache_file_location = os.path.join(cwd, "prompt_cache.json")
|
||||
self.cache = (
|
||||
json.load(open(self.cache_file_location, "r")) if os.path.exists(self.cache_file_location) else {}
|
||||
)
|
||||
|
||||
def build_messages_and_create_chat_completion(self, user_prompt, system_prompt=None, former_messages=[], **kwargs):
|
||||
"""build the messages to avoid implementing several redundant lines of code"""
|
||||
cfg = Config()
|
||||
# TODO: system prompt should always be provided. In development stage we can use default value
|
||||
if system_prompt is None:
|
||||
try:
|
||||
system_prompt = cfg.system_prompt
|
||||
except AttributeError:
|
||||
FinCoLog().warning("system_prompt is not set, using default value.")
|
||||
system_prompt = "You are an AI assistant who helps to answer user's questions about finance."
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
]
|
||||
messages.extend(former_messages[-1*cfg.max_past_message_include:])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
)
|
||||
fcl = FinCoLog()
|
||||
response = self.try_create_chat_completion(messages=messages, **kwargs)
|
||||
fcl.log_message(messages)
|
||||
fcl.log_response(response)
|
||||
return response
|
||||
|
||||
def try_create_chat_completion(self, max_retry=10, **kwargs):
|
||||
max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry
|
||||
for i in range(max_retry):
|
||||
try:
|
||||
response = self.create_chat_completion(**kwargs)
|
||||
return response
|
||||
except (openai.error.RateLimitError, openai.error.Timeout, openai.error.APIError) as e:
|
||||
print(e)
|
||||
print(f"Retrying {i+1}th time...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
except openai.InvalidRequestError as e:
|
||||
print("Invalid request, will try to reduce the messages length and retry...")
|
||||
if len(kwargs["messages"]) > 2:
|
||||
kwargs["messages"] = kwargs["messages"][[0]] + kwargs["messages"][3:]
|
||||
continue
|
||||
raise e
|
||||
raise Exception(f"Failed to create chat completion after {max_retry} retries.")
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages,
|
||||
model=None,
|
||||
temperature: float = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
|
||||
if self.debug_mode:
|
||||
key = json.dumps(messages)
|
||||
if key in self.cache:
|
||||
return self.cache[key]
|
||||
|
||||
if temperature is None:
|
||||
temperature = self.cfg.temperature
|
||||
if max_tokens is None:
|
||||
max_tokens = self.cfg.max_tokens
|
||||
|
||||
if self.cfg.use_azure:
|
||||
response = openai.ChatCompletion.create(
|
||||
engine=self.cfg.model,
|
||||
messages=messages,
|
||||
max_tokens=self.cfg.max_tokens,
|
||||
)
|
||||
else:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.cfg.model,
|
||||
messages=messages,
|
||||
)
|
||||
resp = response.choices[0].message["content"]
|
||||
if self.debug_mode:
|
||||
self.cache[key] = resp
|
||||
json.dump(self.cache, open(self.cache_file_location, "w"))
|
||||
return resp
|
||||
131
qlib/finco/log.py
Normal file
131
qlib/finco/log.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
This module will base on Qlib's logger module and provides some interactive functions.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from typing import Dict, List
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class LogColors:
|
||||
"""
|
||||
ANSI color codes for use in console output.
|
||||
"""
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
BLUE = "\033[94m"
|
||||
MAGENTA = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
WHITE = "\033[97m"
|
||||
GRAY = "\033[90m"
|
||||
BLACK = "\033[30m"
|
||||
|
||||
BOLD = "\033[1m"
|
||||
ITALIC = "\033[3m"
|
||||
|
||||
END = "\033[0m"
|
||||
|
||||
@classmethod
|
||||
def get_all_colors(cls):
|
||||
names = dir(cls)
|
||||
names = [name for name in names if not name.startswith("__") and not callable(getattr(cls, name))]
|
||||
var_values = [getattr(cls, name) for name in names]
|
||||
return var_values
|
||||
|
||||
def render(self, text: str, color: str = "", style: str = ""):
|
||||
"""
|
||||
render text by input color and style. It's not recommend that input text is already rendered.
|
||||
"""
|
||||
# This method is called too frequently, which is not good.
|
||||
colors = self.get_all_colors()
|
||||
# Perhaps color and font should be distinguished here.
|
||||
if color:
|
||||
assert color in colors, f"color should be in: {colors} but now is: {color}"
|
||||
if style:
|
||||
assert style in colors, f"style should be in: {colors} but now is: {style}"
|
||||
|
||||
text = f"{color}{text}{self.END}"
|
||||
text = f"{style}{text}{self.END}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
@contextmanager
|
||||
def formatting_log(logger, title="Info"):
|
||||
"""
|
||||
a context manager, print liens before and after a function
|
||||
"""
|
||||
length = {"Start": 120, "Task": 120, "Info": 60, "Interact": 60, "End": 120}.get(title, 60)
|
||||
color, bold = (LogColors.YELLOW, LogColors.BOLD) \
|
||||
if title in ["Start", "Task", "Info", "Interact", "End"] else (LogColors.CYAN, "")
|
||||
logger.info("")
|
||||
logger.info(f"{color}{bold}{'-'} {title} {'-' * (length - len(title))}{LogColors.END}")
|
||||
yield
|
||||
logger.info("")
|
||||
|
||||
|
||||
class FinCoLog(SingletonBaseClass):
|
||||
# TODO:
|
||||
# - config to file logger and save it into workspace
|
||||
def __init__(self) -> None:
|
||||
self.logger = logging.Logger("interactive")
|
||||
# TODO: merge these with Qlib's default logger.
|
||||
# We can do the same thing by changing the default log dict of Qlib.
|
||||
# Reference: https://github.com/microsoft/qlib/blob/main/qlib/config.py#L155
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
def log_message(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
messages is some info like this [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
},
|
||||
]
|
||||
"""
|
||||
with formatting_log(self.logger, "GPT Messages"):
|
||||
for m in messages:
|
||||
self.logger.info(
|
||||
f"{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
|
||||
+ f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
|
||||
f"{LogColors.CYAN}{m['content']}{LogColors.END}\n")
|
||||
|
||||
def log_response(self, response: str):
|
||||
with formatting_log(self.logger, "GPT Response"):
|
||||
self.logger.info(
|
||||
f"{LogColors.CYAN}{response}{LogColors.END}\n")
|
||||
|
||||
# TODO:
|
||||
# It looks wierd if we only have logger
|
||||
def info(self, *args, plain=False, title="Info"):
|
||||
if plain:
|
||||
return self.plain_info(*args)
|
||||
with formatting_log(self.logger, title):
|
||||
for arg in args:
|
||||
self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
|
||||
|
||||
def plain_info(self, *args):
|
||||
for arg in args:
|
||||
self.logger.info(
|
||||
f"{LogColors.YELLOW}{LogColors.BOLD}Info:{LogColors.END}{LogColors.WHITE}{arg}{LogColors.END}")
|
||||
|
||||
def warning(self, *args):
|
||||
for arg in args:
|
||||
self.logger.warning(
|
||||
f"{LogColors.BLUE}{LogColors.BOLD}Warning:{LogColors.END}{arg}")
|
||||
|
||||
def error(self, *args):
|
||||
for arg in args:
|
||||
self.logger.error(
|
||||
f"{LogColors.RED}{LogColors.BOLD}Error:{LogColors.END}{arg}")
|
||||
32
qlib/finco/prompt_template.py
Normal file
32
qlib/finco/prompt_template.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from jinja2 import Template
|
||||
import yaml
|
||||
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
from qlib.finco import get_finco_path
|
||||
|
||||
|
||||
class PromptTemplate(SingletonBaseClass):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
_template = yaml.load(open(Path.joinpath(get_finco_path(), "prompt_template.yaml"), "r"),
|
||||
Loader=yaml.FullLoader)
|
||||
for k, v in _template.items():
|
||||
if k == "mods":
|
||||
continue
|
||||
self.__setattr__(k, Template(v))
|
||||
|
||||
def get(self, key: str):
|
||||
return self.__dict__.get(key, Template(""))
|
||||
|
||||
def update(self, key: str, value):
|
||||
self.__setattr__(key, value)
|
||||
|
||||
def save(self, file_path: Union[str, Path]):
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
Path.mkdir(file_path.parent, exist_ok=True)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
yaml.dump(self.__dict__, f)
|
||||
1012
qlib/finco/prompt_template.yaml
Normal file
1012
qlib/finco/prompt_template.yaml
Normal file
File diff suppressed because it is too large
Load Diff
1110
qlib/finco/task.py
Normal file
1110
qlib/finco/task.py
Normal file
File diff suppressed because it is too large
Load Diff
12
qlib/finco/tpl/README.md
Normal file
12
qlib/finco/tpl/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
This is a set of templates that should be copied for a new project.
|
||||
|
||||
Here are the explanations for the templates folder.
|
||||
|
||||
| folder | explanations |
|
||||
|--------|------------------------------------------------------------------|
|
||||
| sl | Default configuration for supervised learning |
|
||||
| sl-cfg | Like configuration in sl. But the dataset is highly configurable |
|
||||
|
||||
|
||||
# TODO
|
||||
- [ ] [Copier](https://copier.readthedocs.io/en/stable/#quick-start) may be useful if the generation process becomes complicated
|
||||
13
qlib/finco/tpl/__init__.py
Normal file
13
qlib/finco/tpl/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
def get_tpl_path() -> Path:
|
||||
"""
|
||||
return the template path
|
||||
Because the template path is located in the folder. We don't know where it is located. So __file__ for this module will be used.
|
||||
"""
|
||||
return DIRNAME
|
||||
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
83
qlib/finco/tpl/sl-cfg/workflow_config.yaml
Normal file
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
experiment_name: finCo
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
@@ -14,7 +15,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
38
qlib/finco/utils.py
Normal file
38
qlib/finco/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import json
|
||||
|
||||
from fuzzywuzzy import fuzz
|
||||
|
||||
|
||||
class SingletonMeta(type):
|
||||
_instance = None
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
|
||||
class SingletonBaseClass(metaclass=SingletonMeta):
|
||||
"""
|
||||
Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)`
|
||||
This class becomes necessary
|
||||
|
||||
"""
|
||||
# TODO: Add move this class to Qlib's general utils.
|
||||
|
||||
|
||||
def parse_json(response):
|
||||
try:
|
||||
return json.loads(response)
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.")
|
||||
|
||||
|
||||
def similarity(text1, text2):
|
||||
text1 = text1 if isinstance(text1, str) else ""
|
||||
text2 = text2 if isinstance(text2, str) else ""
|
||||
|
||||
# Maybe we can use other similarity algorithm such as tfidf
|
||||
return fuzz.ratio(text1, text2)
|
||||
223
qlib/finco/workflow.py
Normal file
223
qlib/finco/workflow.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import sys
|
||||
import copy
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from qlib.finco.task import HighLevelPlanTask, SummarizeTask, TrainTask
|
||||
from qlib.finco.prompt_template import PromptTemplate, Template
|
||||
from qlib.finco.log import FinCoLog, LogColors
|
||||
from qlib.finco.utils import similarity
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.conf import Config
|
||||
from qlib.finco.knowledge import KnowledgeBase, Topic
|
||||
|
||||
|
||||
class WorkflowContextManager:
|
||||
"""Context Manager stores the context of the workflow"""
|
||||
|
||||
"""All context are key value pairs which saves the input, output and status of the whole workflow"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.context = {}
|
||||
self.logger = FinCoLog()
|
||||
|
||||
def set_context(self, key, value):
|
||||
if key in self.context:
|
||||
self.logger.warning("The key already exists in the context, the value will be overwritten")
|
||||
self.context[key] = value
|
||||
|
||||
def get_context(self, key):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
return None
|
||||
return self.context[key]
|
||||
|
||||
def update_context(self, key, new_value):
|
||||
# NOTE: if the key doesn't exist, return None. In the future, we may raise an error to detect abnormal behavior
|
||||
if key not in self.context:
|
||||
self.logger.warning("The key doesn't exist in the context")
|
||||
self.context.update({key: new_value})
|
||||
|
||||
def get_all_context(self):
|
||||
"""return a deep copy of the context"""
|
||||
"""TODO: do we need to return a deep copy?"""
|
||||
return copy.deepcopy(self.context)
|
||||
|
||||
def retrieve(self, query: str) -> dict:
|
||||
if query in self.context.keys():
|
||||
return {query: self.context.get(query)}
|
||||
|
||||
# Note: retrieve information from context by string similarity maybe abandon in future
|
||||
scores = {}
|
||||
for k, v in self.context.items():
|
||||
scores.update({k: max(similarity(query, k), similarity(query, v))})
|
||||
max_score_key = max(scores, key=scores.get)
|
||||
return {max_score_key: self.context.get(max_score_key)}
|
||||
|
||||
def clear(self, reserve: list = None):
|
||||
if reserve is None:
|
||||
reserve = []
|
||||
|
||||
_context = {k: self.get_context(k) for k in reserve}
|
||||
self.context = _context
|
||||
|
||||
|
||||
class WorkflowManager:
|
||||
"""This manage the whole task automation workflow including tasks and actions"""
|
||||
|
||||
def __init__(self, workspace=None) -> None:
|
||||
self.logger = FinCoLog()
|
||||
|
||||
if workspace is None:
|
||||
self._workspace = Path.cwd() / "finco_workspace"
|
||||
else:
|
||||
self._workspace = Path(workspace)
|
||||
self.conf = Config()
|
||||
self._confirm_and_rm()
|
||||
|
||||
self.prompt_template = PromptTemplate()
|
||||
self.context = WorkflowContextManager()
|
||||
self.context.set_context("workspace", self._workspace)
|
||||
self.default_user_prompt = "Please help me build a low turnover strategy that focus more on longterm return in China A csi300. Please help to use lightgbm model."
|
||||
|
||||
def _confirm_and_rm(self):
|
||||
# if workspace exists, please confirm and remove it. Otherwise exit.
|
||||
if self._workspace.exists() and not self.conf.continuous_mode:
|
||||
self.logger.info(title="Interact")
|
||||
flag = input(
|
||||
LogColors().render(
|
||||
f"Will be deleted: \n\t{self._workspace}\n"
|
||||
f"If you do not need to delete {self._workspace},"
|
||||
f" please change the workspace dir or rename existing files\n"
|
||||
f"Are you sure you want to delete, yes(Y/y), no (N/n):",
|
||||
color=LogColors.WHITE)
|
||||
)
|
||||
if str(flag) not in ["Y", "y"]:
|
||||
sys.exit()
|
||||
else:
|
||||
# remove self._workspace
|
||||
shutil.rmtree(self._workspace)
|
||||
elif self._workspace.exists() and self.conf.continuous_mode:
|
||||
shutil.rmtree(self._workspace)
|
||||
|
||||
def set_context(self, key, value):
|
||||
"""Direct call set_context method of the context manager"""
|
||||
self.context.set_context(key, value)
|
||||
|
||||
def get_context(self) -> WorkflowContextManager:
|
||||
return self.context
|
||||
|
||||
def run(self, prompt: str) -> Path:
|
||||
"""
|
||||
The workflow manager is supposed to generate a codebase based on the prompt
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prompt: str
|
||||
the prompt user gives
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
The workflow manager is expected to produce output that includes a codebase containing generated code, results, and reports in a designated location.
|
||||
The path is returned
|
||||
|
||||
The output path should follow a specific format:
|
||||
- TODO: design
|
||||
There is a summarized report where user can start from.
|
||||
"""
|
||||
|
||||
# NOTE: The following items are not designed to make the workflow very flexible.
|
||||
# - The generated tasks can't be changed after geting new information from the execution retuls.
|
||||
# - But it is required in some cases, if we want to build a external dataset, it maybe have to plan like autogpt...
|
||||
|
||||
# NOTE: default user prompt might be changed in the future and exposed to the user
|
||||
if prompt is None:
|
||||
self.set_context("user_prompt", self.default_user_prompt)
|
||||
else:
|
||||
self.set_context("user_prompt", prompt)
|
||||
self.logger.info(f"user_prompt: {self.get_context().get_context('user_prompt')}", title="Start")
|
||||
|
||||
# NOTE: list may not be enough for general task list
|
||||
task_list = [HighLevelPlanTask(), SummarizeTask()]
|
||||
task_finished = []
|
||||
while len(task_list):
|
||||
task_list_info = [str(task) for task in task_list]
|
||||
|
||||
# task list is not long, so sort it is not a big problem
|
||||
# TODO: sort the task list based on the priority of the task
|
||||
# task_list = sorted(task_list, key=lambda x: x.task_type)
|
||||
t = task_list.pop(0)
|
||||
self.logger.info(f"Task finished: {[str(task) for task in task_finished]}",
|
||||
f"Task in queue: {task_list_info}",
|
||||
f"Executing task: {str(t)}",
|
||||
title="Task")
|
||||
|
||||
t.assign_context_manager(self.context)
|
||||
res = t.execute()
|
||||
t.summarize()
|
||||
task_finished.append(t)
|
||||
self.context.set_context("task_finished", task_finished)
|
||||
self.logger.plain_info(f"{str(t)} finished.\n\n\n")
|
||||
|
||||
task_list = res + task_list
|
||||
|
||||
return self._workspace
|
||||
|
||||
|
||||
class LearnManager:
|
||||
__DEFAULT_TOPICS = ["IC", "MaxDropDown"]
|
||||
|
||||
def __init__(self):
|
||||
self.epoch = 0
|
||||
self.wm = WorkflowManager()
|
||||
|
||||
topics = [Topic(name=topic, describe=self.wm.prompt_template.get(f"Topic_{topic}")) for topic in
|
||||
self.__DEFAULT_TOPICS]
|
||||
self.knowledge_base = KnowledgeBase(init_path=Path.cwd().joinpath('knowledge'), topics=topics)
|
||||
|
||||
def run(self, prompt):
|
||||
# todo: add early stop condition
|
||||
for i in range(10):
|
||||
self.wm.run(prompt)
|
||||
self.knowledge_base.update(self.wm._workspace)
|
||||
self.knowledge_base.summarize_by_topic()
|
||||
self.learn()
|
||||
self.epoch += 1
|
||||
|
||||
def learn(self):
|
||||
workspace = self.wm.context.get_context("workspace")
|
||||
|
||||
def _drop_duplicate_task(_task: List):
|
||||
unique_task = {}
|
||||
for obj in _task:
|
||||
task_name = obj.__class__.__name__
|
||||
if task_name not in unique_task:
|
||||
unique_task[task_name] = obj
|
||||
return list(unique_task.values())
|
||||
|
||||
# one task maybe run several times in workflow
|
||||
task_finished = _drop_duplicate_task(self.wm.context.get_context("task_finished"))
|
||||
|
||||
user_prompt = self.wm.context.get_context("user_prompt")
|
||||
summary = self.wm.context.get_context("summary")
|
||||
|
||||
for task in task_finished:
|
||||
prompt_workflow_selection = self.wm.prompt_template.get(f"{self.__class__.__name__}_user").render(
|
||||
summary=summary, brief=self.knowledge_base.query_topics(),
|
||||
task_finished=[str(t) for t in task_finished],
|
||||
task=task.__class__.__name__, system=task.system.render(), user_prompt=user_prompt
|
||||
)
|
||||
|
||||
response = APIBackend().build_messages_and_create_chat_completion(
|
||||
user_prompt=prompt_workflow_selection,
|
||||
system_prompt=self.wm.prompt_template.get(f"{self.__class__.__name__}_system").render()
|
||||
)
|
||||
|
||||
# todo: response assertion
|
||||
task.prompt_template.update(key=f"{task.__class__.__name__}_system", value=Template(response))
|
||||
|
||||
self.wm.prompt_template.save(Path.joinpath(workspace, f"prompts/checkpoint_{self.epoch}.yml"))
|
||||
self.wm.context.clear(reserve=["workspace"])
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import qlib
|
||||
@@ -12,15 +11,13 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from cryptography.fernet import Fernet
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v2"
|
||||
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
|
||||
# "?" is not included in the token.
|
||||
TOKEN = "gAAAAABkmDhojHc0VSCDdNK1MqmRzNLeDFXe5hy8obHpa6SDQh4de6nW5gtzuD-fa6O_WZb0yyqYOL7ndOfJX_751W3xN5YB4-n-P22jK-t6ucoZqhT70KPD0Lf0_P328QPJVZ1gDnjIdjhi2YLOcP4BFTHLNYO0mvzszR8TKm9iT5AKRvuysWnpi8bbYwGU9zAcJK3x9EPL43hOGtxliFHcPNGMBoJW4g_ercdhi0-Qgv5_JLsV-29_MV-_AhuaYvJuN2dEywBy"
|
||||
KEY = "EYcA8cgorA8X9OhyMwVfuFxn_1W3jGk6jCbs3L2oPoA="
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
@@ -32,44 +29,24 @@ class GetData:
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def merge_remote_url(self, file_name: str):
|
||||
fernet = Fernet(self.KEY)
|
||||
token = fernet.decrypt(self.TOKEN).decode()
|
||||
return f"{self.REMOTE_URL}/{file_name}?{token}"
|
||||
def normalize_dataset_version(self, dataset_version: str = None):
|
||||
if dataset_version is None:
|
||||
dataset_version = self.DATASET_VERSION
|
||||
return dataset_version
|
||||
|
||||
def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True):
|
||||
"""
|
||||
Download the specified file to the target folder.
|
||||
def merge_remote_url(self, file_name: str, dataset_version: str = None):
|
||||
return f"{self.REMOTE_URL}/{self.normalize_dataset_version(dataset_version)}/{file_name}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
file_name: str
|
||||
dataset name, needs to endwith .zip, value from [rl_data.zip, csv_data_cn.zip, ...]
|
||||
may contain folder names, for example: v2/qlib_data_simple_cn_1d_latest.zip
|
||||
delete_old: bool
|
||||
delete an existing directory, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get rl data
|
||||
python get_data.py download_data --file_name rl_data.zip --target_dir ~/.qlib/qlib_data/rl_data
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/rl_data.zip?{token}
|
||||
|
||||
# get cn csv data
|
||||
python get_data.py download_data --file_name csv_data_cn.zip --target_dir ~/.qlib/csv_data/cn_data
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/csv_data_cn.zip?{token}
|
||||
-------
|
||||
|
||||
"""
|
||||
def _download_data(
|
||||
self, file_name: str, target_dir: [Path, str], delete_old: bool = True, dataset_version: str = None
|
||||
):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
# saved file name
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + os.path.basename(file_name)
|
||||
_target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + file_name
|
||||
target_path = target_dir.joinpath(_target_file_name)
|
||||
|
||||
url = self.merge_remote_url(file_name)
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
resp.raise_for_status()
|
||||
if resp.status_code != 200:
|
||||
@@ -79,7 +56,7 @@ class GetData:
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{os.path.basename(file_name)} downloading......")
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
@@ -90,8 +67,8 @@ class GetData:
|
||||
if self.delete_zip_file:
|
||||
target_path.unlink()
|
||||
|
||||
def check_dataset(self, file_name: str):
|
||||
url = self.merge_remote_url(file_name)
|
||||
def check_dataset(self, file_name: str, dataset_version: str = None):
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
resp = requests.get(url, stream=True, timeout=60)
|
||||
status = True
|
||||
if resp.status_code == 404:
|
||||
@@ -163,11 +140,9 @@ class GetData:
|
||||
---------
|
||||
# get 1d data
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1d_latest.zip?{token}
|
||||
|
||||
# get 1min data
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --interval 1min --region cn
|
||||
When this command is run, the data will be downloaded from this link: https://qlibpublic.blob.core.windows.net/data/default/stock_data/v2/qlib_data_cn_1min_latest.zip?{token}
|
||||
-------
|
||||
|
||||
"""
|
||||
@@ -180,12 +155,29 @@ class GetData:
|
||||
|
||||
qlib_version = ".".join(re.findall(r"(\d+)\.+", qlib.__version__))
|
||||
|
||||
def _get_file_name_with_version(qlib_version, dataset_version):
|
||||
dataset_version = "v2" if dataset_version is None else dataset_version
|
||||
file_name_with_version = f"{dataset_version}/{name}_{region.lower()}_{interval.lower()}_{qlib_version}.zip"
|
||||
return file_name_with_version
|
||||
def _get_file_name(v):
|
||||
return self.QLIB_DATA_NAME.format(
|
||||
dataset_name=name, region=region.lower(), interval=interval.lower(), qlib_version=v
|
||||
)
|
||||
|
||||
file_name = _get_file_name_with_version(qlib_version, dataset_version=version)
|
||||
if not self.check_dataset(file_name):
|
||||
file_name = _get_file_name_with_version("latest", dataset_version=version)
|
||||
self.download_data(file_name.lower(), target_dir, delete_old)
|
||||
file_name = _get_file_name(qlib_version)
|
||||
if not self.check_dataset(file_name, version):
|
||||
file_name = _get_file_name("latest")
|
||||
self._download_data(file_name.lower(), target_dir, delete_old, dataset_version=version)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..utils import fill_placeholder, flatten_dict, class_casting, get_date_by_s
|
||||
from ..utils.time import Freq
|
||||
from ..utils.data import deepcopy_basic_type
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
from qlib.contrib.analyzer import HFAnalyzer, SignalAnalyzer
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -156,6 +156,9 @@ class RecordTemp:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
def analyse(self):
|
||||
raise NotImplementedError(f"Please implement the `analysis` method.")
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
"""
|
||||
|
||||
15
scripts/finco/README.md
Normal file
15
scripts/finco/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
|
||||
# Requirements
|
||||
|
||||
|
||||
Use following install command to complete the project.
|
||||
```
|
||||
pip install -e '.[finco]'
|
||||
```
|
||||
|
||||
|
||||
# TODOs
|
||||
|
||||
- [ ] Select the appropriate LLM API
|
||||
- Which API is more suitable for meeting our requirements - the original API or an alternative like LangChain?
|
||||
15
scripts/finco/cmd.sh
Normal file
15
scripts/finco/cmd.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
set -x # show command
|
||||
set -e # Error on exception
|
||||
|
||||
DIR="$(
|
||||
cd "$(dirname "$(readlink -f "$0")")" || exit
|
||||
pwd -P
|
||||
)"
|
||||
# --load the cridentials
|
||||
if [ -e $DIR/cridential.sh ]; then
|
||||
source $DIR/cridential.sh
|
||||
fi
|
||||
|
||||
# run the command
|
||||
python -m qlib.finco.cli "please help me build a low turnover strategy that focus more on longterm return"
|
||||
3
scripts/finco/cridential.sh.example
Normal file
3
scripts/finco/cridential.sh.example
Normal file
@@ -0,0 +1,3 @@
|
||||
export OPENAI_API_TYPE=azure # This only necessary for Azure OpenAI
|
||||
export OPENAI_API_KEY=
|
||||
export OPENAI_API_BASE=
|
||||
9
setup.py
9
setup.py
@@ -80,7 +80,6 @@ REQUIRED = [
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
"cryptography",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
@@ -174,6 +173,14 @@ setup(
|
||||
"tianshou<=0.4.10",
|
||||
"torch",
|
||||
],
|
||||
"finco": [
|
||||
# finco is not necessary for all Qlib users; So a single require section is used for it.
|
||||
"openapi",
|
||||
"pydantic", # Please add it to basic requirements after the design of pydantic is state.
|
||||
"python-dotenv", # I don't think this is necessary if we use pydantic.
|
||||
"fuzzywuzzy",
|
||||
"python-Levenshtein" # not necessary but accelerate fuzzywuzzy calculation
|
||||
],
|
||||
},
|
||||
include_package_data=True,
|
||||
classifiers=[
|
||||
|
||||
71
tests/finco/test_cfg.py
Normal file
71
tests/finco/test_cfg.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import unittest
|
||||
import shutil
|
||||
import difflib
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
import ruamel.yaml as yaml
|
||||
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
from pathlib import Path
|
||||
from qlib.finco.tpl import get_tpl_path
|
||||
from qlib.finco.task import YamlEditTask
|
||||
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
|
||||
class FincoTpl(TestAutoData):
|
||||
def test_tpl_consistence(self):
|
||||
"""Motivation: make sure the configuable template is consistent with the default config"""
|
||||
tpl_p = get_tpl_path()
|
||||
with (tpl_p / "sl" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
# init_data_handler
|
||||
hd: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
# NOTE: The config in workflow_config.yaml is generated by the following code:
|
||||
# dump in yaml format to file without auto linebreak
|
||||
# print(yaml.dump(hd.data_loader.fields, width=10000, stream=open("_tmp", "w")))
|
||||
|
||||
with (tpl_p / "sl-cfg" / "workflow_config.yaml").open("rb") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
hd_ds: DataHandlerLP = init_instance_by_config(config["task"]["dataset"]["kwargs"]["handler"])
|
||||
self.assertEqual(hd_ds.data_loader.fields, hd.data_loader.fields)
|
||||
|
||||
check = hd_ds.fetch().fillna(0.0) == hd.fetch().fillna(0.0)
|
||||
self.assertTrue(check.all().all())
|
||||
|
||||
def test_update_yaml(self):
|
||||
p = get_tpl_path() / "sl" / "workflow_config.yaml"
|
||||
p_new = DIRNAME / "_test_config.yaml"
|
||||
shutil.copy(p, p_new)
|
||||
updated_content = """
|
||||
class: LGBModelTest
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 1.8879
|
||||
learning_rate: 0.3
|
||||
subsample: 0.8790
|
||||
lambda_l1: 205.7000
|
||||
lambda_l2: 580.9769
|
||||
max_depth: 9
|
||||
num_leaves: 211
|
||||
num_threads: 21
|
||||
"""
|
||||
t = YamlEditTask(p_new, "task.model", updated_content)
|
||||
t.execute()
|
||||
# NOTE: the formmat is changed by ruamel.yaml, so it can't be compared by text directly..
|
||||
# print the diff between p and p_new with difflib
|
||||
# with p.open("r") as fp:
|
||||
# content = fp.read()
|
||||
# with p_new.open("r") as fp:
|
||||
# content_new = fp.read()
|
||||
# for line in difflib.unified_diff(content, content_new, fromfile="original", tofile="new", lineterm=""):
|
||||
# print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
66
tests/finco/test_sumarize.py
Normal file
66
tests/finco/test_sumarize.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import unittest
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from dotenv import load_dotenv
|
||||
# pydantic support load_dotenv, so load_dotenv will be deprecated in the future.
|
||||
|
||||
from qlib.finco.task import SummarizeTask
|
||||
from qlib.finco.workflow import WorkflowContextManager
|
||||
from qlib.finco.llm import APIBackend
|
||||
from qlib.finco.workflow import WorkflowManager
|
||||
|
||||
load_dotenv(verbose=True, override=True)
|
||||
|
||||
|
||||
class TestSummarize(unittest.TestCase):
|
||||
|
||||
def test_chat(self):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Your are a professional financial assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "How to write a perfect quant strategy.",
|
||||
},
|
||||
]
|
||||
response = APIBackend().try_create_chat_completion(messages=messages)
|
||||
print(response)
|
||||
|
||||
def test_execution(self):
|
||||
task = SummarizeTask()
|
||||
context = WorkflowContextManager()
|
||||
context.set_context("workspace", "../../examples/benchmarks/Linear")
|
||||
context.set_context("user_prompt", "My main focus is on the performance of the strategy's return."
|
||||
"Please summarize the information and give me some advice.")
|
||||
task.assign_context_manager(context)
|
||||
resp = task.execute()
|
||||
print(resp)
|
||||
|
||||
def test_generate_batch_result(self):
|
||||
wm = WorkflowManager()
|
||||
|
||||
prompt = wm.default_user_prompt
|
||||
# prompt = ""
|
||||
|
||||
workdir = os.path.dirname(wm.get_context().get_context("workspace"))
|
||||
summaries_path = os.path.join(workdir, "summaries")
|
||||
|
||||
if not os.path.exists(summaries_path):
|
||||
os.makedirs(summaries_path)
|
||||
|
||||
for i in range(10):
|
||||
wm.run(prompt)
|
||||
if os.path.exists(f"{workdir}/finCoReport.md"):
|
||||
shutil.move(f"{workdir}/finCoReport.md", f"{workdir}/summaries/finCoReport{i}.md")
|
||||
|
||||
def test_parse2txt(self):
|
||||
task = SummarizeTask()
|
||||
resp = task.get_info_from_file("")
|
||||
print(resp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
23
tests/finco/test_utils.py
Normal file
23
tests/finco/test_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import unittest
|
||||
from qlib.finco.utils import SingletonBaseClass
|
||||
|
||||
|
||||
class TimeUtils(unittest.TestCase):
|
||||
|
||||
def test_singleton(self):
|
||||
# self.assertEqual(self.to_str(data.tail()), self.to_str(res))
|
||||
closure_checker = []
|
||||
|
||||
class A(SingletonBaseClass):
|
||||
|
||||
def __init__(self) -> None:
|
||||
closure_checker.append(0)
|
||||
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
A()
|
||||
self.assertEqual(len(closure_checker), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -35,7 +35,7 @@ class TestDumpData(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestGetData(unittest.TestCase):
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
def test_1_csv_data(self):
|
||||
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
stock_name = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
self.assertEqual(len(stock_name), 85, "get csv data failed")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user