mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 09:01:18 +08:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4709909782 | | |
| | a0f49fe2e7 | | |
| | 2840570dd3 | | |
| | 00ad122175 | | |
| | 3493f29e16 | | |
| | e33de44cb9 | | |
| | e843e021a2 | | |
| | 5aa5a6f356 | | |
| | f490708025 | | |
| | 41a5778684 | | |
| | ef161715f7 | | |
| | d087054a59 | | |
| | 350fbe91c9 | | |
| | 2aca74cd21 | | |
| | 92ff3d20b9 | | |
| | 0552120a2e | | |
| | 3480fd932f | | |
| | 957f9a18e9 | | |
| | 6c83632fc4 | | |
| | 125922b77a | | |
| | 5e69d089c0 | | |
| | c10c349b20 | | |
| | 7cb1f7cee0 | | |
| | d0ff5eea9d | | |
| | e99f00b445 | | |
| | e50ad4309e | | |
| | d89ae2370f | | |
24
README.md
@@ -11,9 +11,11 @@
Recent released features
| Feature | Status |
| -- | ------ |
| Release Qlib v0.8.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
| Nested Decision Framework | [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) |
| Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
@@ -67,7 +69,6 @@ Your feedbacks about the features are very important.
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
| Meta-Learning-based data selection | Initial opensource version under development |

# Framework of Qlib
|
||||
@@ -159,15 +160,17 @@ Load and prepare data by running the following code:
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
the same repository.
|
||||
Users could create the same dataset with it.
|
||||
Users could create the same dataset with it. [Description of dataset](https://github.com/microsoft/qlib/tree/main/scripts/data_collector#description-of-dataset)
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> This step is *Optional* if users only want to try their models and strategies on history data.
|
||||
>
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
>
|
||||
> For more information, please refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
@@ -397,17 +400,26 @@ Join IM discussion groups:
|
||||
||
|
||||
|
||||
# Contributing
|
||||
We appreciate all contributions and thank all the contributors!
|
||||
<a href="https://github.com/microsoft/qlib/graphs/contributors"><img src="https://contrib.rocks/image?repo=microsoft/qlib" /></a>
|
||||
|
||||
Before we released Qlib as an open-source project on GitHub in Sep 2020, Qlib was an internal project in our group. Unfortunately, the internal commit history was not kept. Many members of our group also contributed a lot to Qlib, including Ruihua Wang, Yinda Zhang, Haisu Yu, Shuyu Wang, Bochen Pang, and [Dong Zhou](https://github.com/evanzd/evanzd). Special thanks go to [Dong Zhou](https://github.com/evanzd/evanzd) for his initial version of Qlib.
|
||||
|
||||
## Guidance
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
[code standards](docs/developer/code_standard.rst) for submitting a pull request.**
|
||||
|
||||
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
|
||||
Making contributions is not hard. Solving an issue (perhaps just answering a question raised in the [issues list](https://github.com/microsoft/qlib/issues) or on [gitter](https://gitter.im/Microsoft/qlib)), fixing or reporting a bug, improving the documents, and even fixing a typo are all important contributions to Qlib.
|
||||
|
||||
For example, if you want to contribute to Qlib's document/code, you can follow the steps in the figure below.
|
||||
<p align="center">
|
||||
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
|
||||
</p>
|
||||
|
||||
|
||||
## Licence
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.8.0
|
||||
0.8.0.99
|
||||
|
||||
@@ -338,7 +338,7 @@ DataHandlerLP
|
||||
|
||||
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
|
||||
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing(e.g., parameters for zscore normalization). When new data comes in, these `trained` ``Processors`` can then process the new data and thus processing real-time data in an efficient way becomes possible. More information about ``Processors`` will be listed in the next subsection.
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some learnable ``Processors`` which can learn the parameters of data processing(e.g., parameters for zscore normalization). When new data comes in, these `trained` ``Processors`` can then process the new data and thus processing real-time data in an efficient way becomes possible. More information about ``Processors`` will be listed in the next subsection.
|
||||
|
||||
|
||||
Interface
|
||||
|
||||
@@ -112,6 +112,9 @@ A prediction sample is shown as follows.
|
||||
|
||||
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
Normally, the prediction score is the output of the models. But some models are learned from a label with a different scale, so the scale of the prediction score may differ from what you expect (e.g. the return of instruments).
|
||||
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale, because not every trading strategy cares about the scale (e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score (e.g. some portfolio-optimization-based strategies may require a meaningful scale).
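To illustrate why an order-based strategy does not need rescaled scores, here is a minimal sketch. The `top_k` helper and the sample scores are hypothetical and not part of Qlib; the point is only that a positive affine rescaling of the scores leaves a top-k selection unchanged.

```python
from typing import Dict, List


def top_k(scores: Dict[str, float], k: int) -> List[str]:
    """Return the k instruments with the highest scores (ties broken by name)."""
    return sorted(scores, key=lambda name: (-scores[name], name))[:k]


# Hypothetical prediction scores on a "return-like" scale.
raw = {"A": 0.02, "B": -0.01, "C": 0.05, "D": 0.00}
# The same scores on an arbitrary different scale (positive slope keeps the order).
rescaled = {name: 100.0 * score + 3.0 for name, score in raw.items()}

picked_raw = top_k(raw, 2)
picked_rescaled = top_k(rescaled, 2)
```

A strategy that sizes positions from the score values themselves would not be invariant in this way, which is why the rescaling is left to the strategy.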
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
@@ -283,4 +286,4 @@ The backtest results are in the following form:
|
||||
|
||||
Reference
|
||||
===================
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
# High-Frequency Dataset
|
||||
# Introduction
|
||||
This folder contains 2 examples
|
||||
- A high-frequency dataset example
|
||||
- An example of predicting the price trend in high-frequency data
|
||||
|
||||
## High-Frequency Dataset
|
||||
|
||||
This dataset is an example for RL high frequency trading.
|
||||
|
||||
## Get High-Frequency Data
|
||||
### Get High-Frequency Data
|
||||
|
||||
Get high-frequency data by running the following command:
|
||||
```bash
|
||||
python workflow.py get_data
|
||||
```
|
||||
|
||||
## Dump & Reload & Reinitialize the Dataset
|
||||
### Dump & Reload & Reinitialize the Dataset
|
||||
|
||||
|
||||
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in `workflow.py`. `DatasetH` is a subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped to or loaded from disk in `pickle` format.
|
||||
@@ -27,9 +32,9 @@ Run the example by running the following command:
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
## Benchmarks Performance (predicting the price trend in high-frequency data)
|
||||
|
||||
Here are the results of models for predicting the price trend in high-frequency data. We will keep updating benchmark models in future.
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
|
||||
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.model.trainer import TrainerR, TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class RollingTaskExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_pool=None,  # set to e.g. "rolling_task" to train with TrainerRM and a task pool; None uses TrainerR
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
@@ -43,14 +43,19 @@ class RollingTaskExample:
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
if task_pool is None:
|
||||
self.trainer = TrainerR(experiment_name=self.experiment_name)
|
||||
else:
|
||||
self.task_pool = task_pool
|
||||
self.trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -66,10 +71,10 @@ class RollingTaskExample:
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
self.trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# NOTE: this is only used for TrainerRM
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
@@ -248,7 +248,7 @@ class ModelRunner:
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be a URL on the web or a local path
|
||||
it could be a URL on the web or a local path (NOTE: the local path must be an absolute path)
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
|
||||
@@ -62,12 +62,6 @@
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copyed from setup.py
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__ = "0.8.0.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -16,6 +15,16 @@ from .log import get_module_logger
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**kwargs :
|
||||
clear_mem_cache: bool
the default value is True;
Whether to clear the memory cache.
Setting it to False is often used to improve performance when `init` is called multiple times
|
||||
"""
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
@@ -29,7 +38,9 @@ def init(default_conf="client", **kwargs):
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
H.clear()
|
||||
clear_mem_cache = kwargs.pop("clear_mem_cache", True)
|
||||
if clear_mem_cache:
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# mount nfs
|
||||
|
||||
@@ -401,9 +401,9 @@ class Exchange:
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method="sum")
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
if direction == OrderDir.SELL:
|
||||
|
||||
@@ -395,9 +395,25 @@ class NestedExecutor(BaseExecutor):
|
||||
if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
|
||||
# if force align the range limit, skip the steps outside the decision range limit
|
||||
|
||||
_inner_trade_decision: BaseTradeDecision = self.inner_strategy.generate_trade_decision(
|
||||
_inner_execute_result
|
||||
)
|
||||
res = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
|
||||
# NOTE: !!!!!
# The two lines below handle a special case in RL
# and resolve the following conflict:
# - Normally, the user creates a strategy and embeds it into Qlib's executor and simulator interaction loop
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
# - However, an RL-based framework has its own script to run the loop
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
# To make it possible to run the _nested qlib example_ and the _RL learning example_ together, the solution below is proposed
# - The entry script follows the _RL learning example_ to be compatible with all kinds of RL frameworks
# - Each step of (RL Env) moves (inner Qlib Executor) one step forward
# - (inner Qlib Strategy) is a proxy strategy; it hands program control to (RL Env) via `yield from` and waits for the action from the policy
# So the two lines below implement the yielding of control
|
||||
if isinstance(res, GeneratorType):
|
||||
res = yield from res
|
||||
|
||||
_inner_trade_decision: BaseTradeDecision = res
|
||||
|
||||
trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information
|
||||
|
||||
# NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
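The `GeneratorType` check in this hunk can be tried in isolation. The sketch below uses hypothetical names (`plain_strategy`, `proxy_strategy`, `executor_step`, `drive`) and plain strings instead of Qlib decisions; it only demonstrates how `yield from` lets an inner generator hand control to an outer driver and receive an action back.

```python
from types import GeneratorType


def plain_strategy(state):
    # An ordinary strategy returns its decision directly.
    return f"decision({state})"


def proxy_strategy(state):
    # A proxy strategy yields control and waits for an externally supplied action.
    action = yield f"need_action({state})"
    return f"decision({action})"


def executor_step(strategy, state):
    res = strategy(state)
    # Same pattern as in the hunk: delegate only when a generator was returned.
    if isinstance(res, GeneratorType):
        res = yield from res
    return res


def drive(gen, action=None):
    """Run a generator to completion, answering every yield with `action`."""
    requests = []
    try:
        request = next(gen)
        while True:
            requests.append(request)
            request = gen.send(action)
    except StopIteration as stop:
        return requests, stop.value


plain_result = drive(executor_step(plain_strategy, "s0"))
proxy_result = drive(executor_step(proxy_strategy, "s1"), action="buy")
```

In this sketch `drive` plays the role of the outer loop (the RL environment in the comment above), while `executor_step` stays agnostic about which kind of strategy it was given.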
|
||||
@@ -407,6 +423,7 @@ class NestedExecutor(BaseExecutor):
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
@@ -418,6 +435,17 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res):
|
||||
"""
|
||||
A hook for doing something after each step of the inner strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
@@ -55,9 +55,9 @@ class TradeCalendarManager:
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
_calendar = Cal.calendar(freq=freq, future=True)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq)
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
|
||||
@@ -10,6 +10,7 @@ Two modes are supported
|
||||
- server
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
@@ -18,7 +19,11 @@ import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -296,7 +301,9 @@ class QlibConfig(Config):
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_uri(self, freq: str = None) -> Path:
|
||||
def get_data_uri(self, freq: Optional[Union[str, Freq]] = None) -> Path:
|
||||
if freq is not None:
|
||||
freq = str(freq) # converting Freq to string
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
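The reason for converting to `str` before the lookup can be shown with a small stand-alone sketch. `FreqSketch`, `pick_uri`, the `DEFAULT_FREQ` value, and the paths are all hypothetical; `FreqSketch` is not Qlib's `Freq` and its string format is an assumption made for illustration.

```python
from typing import Dict, Optional, Union


class FreqSketch:
    """Hypothetical frequency object; only its string form matters here."""

    def __init__(self, count: int, base: str):
        self.count = count
        self.base = base

    def __str__(self) -> str:
        return f"{self.count}{self.base}"


DEFAULT_FREQ = "__DEFAULT_FREQ"  # assumed placeholder key


def pick_uri(provider_uri: Dict[str, str], freq: Optional[Union[str, FreqSketch]] = None) -> str:
    if freq is not None:
        freq = str(freq)  # normalise objects to the string keys of the mapping
    if freq is None or freq not in provider_uri:
        freq = DEFAULT_FREQ
    return provider_uri[freq]


uris = {DEFAULT_FREQ: "/data/day", "1min": "/data/1min"}
from_object = pick_uri(uris, FreqSketch(1, "min"))
from_string = pick_uri(uris, "1min")
fallback = pick_uri(uris, FreqSketch(1, "day"))
```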
|
||||
|
||||
@@ -90,7 +90,13 @@ class Alpha360(DataHandlerLP):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
def get_feature_config(self):
|
||||
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data ( dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization is executed (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
|
||||
@@ -267,7 +267,7 @@ class DNNModelPytorch(Model):
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCELoss(weight=w)
|
||||
loss = nn.BCEWithLogitsLoss(weight=w)
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
@@ -334,16 +334,8 @@ class Net(nn.Module):
|
||||
dnn_layers.append(seq)
|
||||
drop_input = nn.Dropout(0.05)
|
||||
dnn_layers.append(drop_input)
|
||||
if loss == "mse":
|
||||
fc = nn.Linear(hidden_units, output_dim)
|
||||
dnn_layers.append(fc)
|
||||
|
||||
elif loss == "binary":
|
||||
fc = nn.Linear(hidden_units, output_dim)
|
||||
sigmoid = nn.Sigmoid()
|
||||
dnn_layers.append(nn.Sequential(fc, sigmoid))
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
fc = nn.Linear(hidden_units, output_dim)
|
||||
dnn_layers.append(fc)
|
||||
# optimizer
|
||||
self.dnn_layers = nn.ModuleList(dnn_layers)
|
||||
self._weight_init()
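The two hunks above go together: the final `nn.Sigmoid` is removed from the network and the loss switches to `nn.BCEWithLogitsLoss`, so the loss now receives raw logits. The standard-library sketch below is not PyTorch's implementation; it only checks numerically that binary cross-entropy on `sigmoid(x)` equals the usual stable logits form `max(x, 0) - x*y + log(1 + exp(-|x|))`. Function names are illustrative.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bce_on_prob(p: float, y: float) -> float:
    """Binary cross-entropy computed from a probability."""
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def bce_with_logits(x: float, y: float) -> float:
    """Binary cross-entropy computed from a raw logit in a numerically stable form."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


# Largest disagreement between the two formulations over a few moderate logits.
max_gap = max(
    abs(bce_on_prob(sigmoid(x), y) - bce_with_logits(x, y))
    for x in (-3.0, -0.5, 0.0, 2.0)
    for y in (0.0, 1.0)
)

# The logits form stays finite where sigmoid(x) would round to exactly 1.0.
large_logit_loss = bce_with_logits(1000.0, 0.0)
```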
|
||||
|
||||
@@ -57,7 +57,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
).figure
|
||||
|
||||
t_df = t_df.loc[:, ["long-short", "long-average"]]
|
||||
_bin_size = ((t_df.max() - t_df.min()) / 20).min()
|
||||
_bin_size = float(((t_df.max() - t_df.min()) / 20).min())
|
||||
group_hist_figure = SubplotsGraph(
|
||||
t_df,
|
||||
kind_map=dict(kind="DistplotGraph", kwargs=dict(bin_size=_bin_size)),
|
||||
|
||||
@@ -46,6 +46,7 @@ class Tuner:
|
||||
space=self.space,
|
||||
algo=tpe.suggest,
|
||||
max_evals=self.max_evals,
|
||||
show_progressbar=False,
|
||||
)
|
||||
self.logger.info("Local best params: {} ".format(self.best_params))
|
||||
TimeInspector.log_cost_time(
|
||||
|
||||
@@ -27,7 +27,6 @@ from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_calendar
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
|
||||
@@ -197,6 +197,8 @@ class Fillna(Processor):
|
||||
|
||||
class MinMaxNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
# NOTE: setting `fit_start_time` and `fit_end_time` correctly is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
@@ -226,6 +228,8 @@ class ZScoreNorm(Processor):
|
||||
"""ZScore Normalization"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
# NOTE: setting `fit_start_time` and `fit_end_time` correctly is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
@@ -263,6 +267,8 @@ class RobustZScoreNorm(Processor):
|
||||
"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):
|
||||
# NOTE: setting `fit_start_time` and `fit_end_time` correctly is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
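The warning in these NOTE comments can be made concrete with a small stand-alone sketch. `ZScoreSketch` and the sample series are hypothetical and are not Qlib's `ZScoreNorm`; the sketch only shows statistics being learned on the fit window and then applied to later data, so test-period values never influence the parameters.

```python
from statistics import mean, pstdev
from typing import Dict


class ZScoreSketch:
    """Learn mean/std on [fit_start, fit_end] only, then apply them everywhere."""

    def __init__(self, fit_start: str, fit_end: str):
        self.fit_start = fit_start
        self.fit_end = fit_end

    def fit(self, series: Dict[str, float]) -> None:
        # ISO date strings compare correctly as plain strings.
        window = [v for d, v in series.items() if self.fit_start <= d <= self.fit_end]
        self.mean_ = mean(window)
        self.std_ = pstdev(window)

    def transform(self, series: Dict[str, float]) -> Dict[str, float]:
        return {d: (v - self.mean_) / self.std_ for d, v in series.items()}


series = {
    "2020-01-01": 1.0,
    "2020-01-02": 2.0,
    "2020-01-03": 3.0,
    "2020-01-04": 100.0,  # "test" period value, outside the fit window
}
norm = ZScoreSketch("2020-01-01", "2020-01-03")
norm.fit(series)
normalized = norm.transform(series)
```

Had the fit window included `2020-01-04`, the outlier would have shifted the learned mean and leaked test information into the training features.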
|
||||
@@ -302,7 +308,13 @@ class CSZScoreNorm(Processor):
|
||||
|
||||
|
||||
class CSRankNorm(Processor):
|
||||
"""Cross Sectional Rank Normalization"""
|
||||
"""
|
||||
Cross Sectional Rank Normalization.
|
||||
"Cross Sectional" is often used to describe data operations.
|
||||
The operations across different stocks are often called Cross Sectional Operation.
|
||||
|
||||
For example, CSRankNorm is an operation that groups the data by day and ranks `across` all the stocks in each day.
|
||||
"""
|
||||
|
||||
def __init__(self, fields_group=None):
|
||||
self.fields_group = fields_group
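A cross-sectional rank as described in the docstring can be sketched in plain Python. `cs_rank` and the sample rows are hypothetical; this is not the `CSRankNorm` implementation, and any further scaling that the real processor may apply is deliberately left out. Ties are not handled specially here.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def cs_rank(rows: List[Tuple[str, str, float]]) -> Dict[Tuple[str, str], float]:
    """Map (date, instrument) to its percentile rank among that day's instruments."""
    by_day = defaultdict(list)
    for date, instrument, value in rows:
        by_day[date].append((value, instrument))

    ranks = {}
    for date, items in by_day.items():
        items.sort()  # rank across instruments within one day
        n = len(items)
        for position, (_, instrument) in enumerate(items, start=1):
            ranks[(date, instrument)] = position / n
    return ranks


rows = [
    ("2020-01-02", "A", 10.0),
    ("2020-01-02", "B", 30.0),
    ("2020-01-02", "C", 20.0),
    ("2020-01-03", "A", 5.0),
    ("2020-01-03", "B", 1.0),
]
ranks = cs_rank(rows)
```

Instrument `A` gets a different rank on each day because each day forms its own cross section.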
|
||||
|
||||
@@ -11,6 +11,7 @@ import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils.resam import resam_calendar
|
||||
from qlib.config import C
|
||||
from qlib.data.cache import H
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
|
||||
@@ -33,15 +34,15 @@ class FileStorageMixin:
|
||||
if hasattr(self, _v):
|
||||
return getattr(self, _v)
|
||||
if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri:
|
||||
freq = filter(
|
||||
freq_l = filter(
|
||||
lambda _freq: not _freq.endswith("_future"),
|
||||
map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt")),
|
||||
)
|
||||
else:
|
||||
freq = self.provider_uri.keys()
|
||||
freq = list(freq)
|
||||
setattr(self, _v, freq)
|
||||
return freq
|
||||
freq_l = self.provider_uri.keys()
|
||||
freq_l = [Freq(freq) for freq in freq_l]
|
||||
setattr(self, _v, freq_l)
|
||||
return freq_l
|
||||
|
||||
@property
|
||||
def uri(self) -> Path:
|
||||
@@ -65,15 +66,28 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
|
||||
self.future = future
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.resample_freq = None
|
||||
self.enable_read_cache = True # TODO: make it configurable
|
||||
|
||||
@property
|
||||
def file_name(self) -> str:
|
||||
return f"{self.use_freq}_future.txt" if self.future else f"{self.use_freq}.txt".lower()
|
||||
return f"{self._freq_file}_future.txt" if self.future else f"{self._freq_file}.txt".lower()
|
||||
|
||||
@property
|
||||
def use_freq(self) -> str:
|
||||
return self.freq if self.resample_freq is None else self.resample_freq
|
||||
def _freq_file(self) -> str:
|
||||
"""the freq to read from file"""
|
||||
if not hasattr(self, "_freq_file_cache"):
|
||||
freq = Freq(self.freq)
|
||||
if freq not in self.support_freq:
|
||||
# NOTE: uri
|
||||
# 1. If `uri` does not exist
|
||||
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
|
||||
# - Read data from `min_uri` and resample to `freq`
|
||||
|
||||
freq = Freq.get_recent_freq(freq, self.support_freq)
|
||||
if freq is None:
|
||||
raise ValueError(f"can't find a freq from {self.support_freq} that can resample to {self.freq}!")
|
||||
self._freq_file_cache = freq
|
||||
return self._freq_file_cache
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
if not self.uri.exists():
|
||||
@@ -90,25 +104,21 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
|
||||
@property
|
||||
def uri(self) -> Path:
|
||||
freq = self.freq
|
||||
if freq not in self.support_freq:
|
||||
# NOTE: uri
|
||||
# 1. If `uri` does not exist
|
||||
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
|
||||
# - Read data from `min_uri` and resample to `freq`
|
||||
|
||||
freq = Freq.get_recent_freq(freq, self.support_freq)
|
||||
if freq is None:
|
||||
raise ValueError(f"can't find a freq from {self.support_freq} that can resample to {self.freq}!")
|
||||
self.resample_freq = freq
|
||||
return self.dpm.get_data_uri(self.use_freq).joinpath(f"{self.storage_name}s", self.file_name)
|
||||
return self.dpm.get_data_uri(self._freq_file).joinpath(f"{self.storage_name}s", self.file_name)
|
||||
|
||||
@property
|
||||
def data(self) -> List[CalVT]:
|
||||
self.check()
|
||||
_calendar = self._read_calendar()
|
||||
if self.resample_freq is not None:
|
||||
_calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), self.resample_freq, self.freq)
|
||||
# If cache is enabled, then return cache directly
|
||||
if self.enable_read_cache:
|
||||
key = "orig_file" + str(self.uri)
|
||||
if not key in H["c"]:
|
||||
H["c"][key] = self._read_calendar()
|
||||
_calendar = H["c"][key]
|
||||
else:
|
||||
_calendar = self._read_calendar()
|
||||
if Freq(self._freq_file) != Freq(self.freq):
|
||||
_calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq)
|
||||
return _calendar
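The read cache added in this hunk follows a common memoisation pattern. The sketch below uses hypothetical stand-ins (`_MEM`, `read_calendar`, `calendar_data`, and a made-up path) rather than Qlib's `H` cache or storage classes; it only shows that the underlying read happens once per URI while the cache is enabled.

```python
# Hypothetical in-memory cache with a "c" (calendar) section.
_MEM = {"c": {}}
reads = []  # records every simulated file read


def read_calendar(uri: str):
    """Pretend to read a calendar file; the returned dates are made up."""
    reads.append(uri)
    return ["2020-01-02", "2020-01-03"]


def calendar_data(uri: str, enable_read_cache: bool = True):
    if enable_read_cache:
        key = "orig_file" + str(uri)
        if key not in _MEM["c"]:
            _MEM["c"][key] = read_calendar(uri)
        return _MEM["c"][key]
    return read_calendar(uri)


first = calendar_data("/data/calendars/day.txt")
second = calendar_data("/data/calendars/day.txt")
reads_with_cache = len(reads)

# With the cache disabled, every call reads again.
calendar_data("/data/calendars/day.txt", enable_read_cache=False)
reads_total = len(reads)
```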
|
||||
|
||||
def _get_storage_freq(self) -> List[str]:
|
||||
|
||||
@@ -86,10 +86,61 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
|
||||
return R.get_recorder()
|
||||
|
||||
|
||||
def get_item_from_obj(config: dict, name_path: str) -> object:
|
||||
"""
|
||||
Follow the name_path to get values from config
|
||||
For example:
|
||||
If we follow the example in the Parameters section,
|
||||
Timestamp('2008-01-02 00:00:00') will be returned
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
e.g.
|
||||
{'dataset': {'class': 'DatasetH',
|
||||
'kwargs': {'handler': {'class': 'Alpha158',
|
||||
'kwargs': {'end_time': '2020-08-01',
|
||||
'fit_end_time': '<dataset.kwargs.segments.train.1>',
|
||||
'fit_start_time': '<dataset.kwargs.segments.train.0>',
|
||||
'instruments': 'csi100',
|
||||
'start_time': '2008-01-01'},
|
||||
'module_path': 'qlib.contrib.data.handler'},
|
||||
'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
|
||||
Timestamp('2019-04-08 00:00:00')),
|
||||
'train': (Timestamp('2008-01-02 00:00:00'),
|
||||
Timestamp('2014-12-31 00:00:00')),
|
||||
'valid': (Timestamp('2015-01-05 00:00:00'),
|
||||
Timestamp('2016-12-30 00:00:00'))}}
|
||||
}}
|
||||
name_path : str
|
||||
e.g.
|
||||
"dataset.kwargs.segments.train.1"
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
the retrieved object
|
||||
"""
|
||||
cur_cfg = config
|
||||
for k in name_path.split("."):
|
||||
if isinstance(cur_cfg, dict):
|
||||
cur_cfg = cur_cfg[k]
|
||||
elif k.isdigit():
|
||||
cur_cfg = cur_cfg[int(k)]
|
||||
else:
|
||||
raise ValueError(f"Error when getting {k} from cur_cfg")
|
||||
return cur_cfg
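The lookup can be exercised without Qlib. The block below restates the function body from this hunk with its indentation restored and applies it to a trimmed config; plain date strings stand in for the `Timestamp` values shown in the docstring, an assumption made to keep the sketch dependency-free.

```python
def get_item_from_obj(config: dict, name_path: str) -> object:
    """Follow a dotted name_path through nested dicts and sequences."""
    cur_cfg = config
    for k in name_path.split("."):
        if isinstance(cur_cfg, dict):
            cur_cfg = cur_cfg[k]  # dict levels are addressed by key
        elif k.isdigit():
            cur_cfg = cur_cfg[int(k)]  # sequence levels are addressed by index
        else:
            raise ValueError(f"Error when getting {k} from cur_cfg")
    return cur_cfg


# Trimmed, illustrative config; strings replace the Timestamp objects.
config = {
    "dataset": {
        "kwargs": {
            "segments": {
                "train": ("2008-01-02", "2014-12-31"),
                "valid": ("2015-01-05", "2016-12-30"),
            }
        }
    }
}

train_end = get_item_from_obj(config, "dataset.kwargs.segments.train.1")
valid_start = get_item_from_obj(config, "dataset.kwargs.segments.valid.0")
```

A non-numeric path component applied to a tuple or list raises `ValueError`, which is the only error the function raises itself; a missing dict key surfaces as a `KeyError`.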
|
||||
|
||||
|
||||
def fill_placeholder(config: dict, config_extend: dict):
|
||||
"""
|
||||
Detect placeholder in config and fill them with config_extend.
|
||||
Each item of the dict must be a single item (int, str, etc.), a dict, or a list. Tuples are not supported.
|
||||
There are two types of variables:
|
||||
- user-defined variables :
|
||||
e.g. when config_extend is `{"<MODEL>": model, "<DATASET>": dataset}`, "<MODEL>" and "<DATASET>" in `config` will be replaced with `model` and `dataset`, respectively
|
||||
- variables extracted from `config` :
|
||||
e.g. the variables like "<dataset.kwargs.segments.train.0>" will be replaced with the values from `config`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -122,8 +173,13 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
|
||||
item_queue.append(now_item[key])
|
||||
tail += 1
|
||||
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
elif isinstance(now_item[key], str):
|
||||
if now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
else:
|
||||
m = re.match(r"<(?P<name_path>[^<>]+)>", now_item[key])
|
||||
if m is not None:
|
||||
now_item[key] = get_item_from_obj(config, m.groupdict()["name_path"])
|
||||
return config
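The placeholder lookup added in this hunk can be illustrated with a minimal, self-contained sketch. It is not the qlib implementation: `get_item` and `resolve` are hypothetical names, the sample config is made up, and `re.fullmatch` is used here instead of the `re.match` call shown in the diff so that only strings consisting entirely of one `<...>` placeholder are resolved.

```python
import re
from typing import Any


def get_item(config: Any, name_path: str) -> Any:
    """Walk a dotted path through nested dicts and sequences (hypothetical helper)."""
    cur = config
    for key in name_path.split("."):
        if isinstance(cur, dict):
            cur = cur[key]
        elif key.isdigit():
            # Non-dict containers (list, tuple) are indexed by position.
            cur = cur[int(key)]
        else:
            raise ValueError(f"cannot resolve {key!r} in {name_path!r}")
    return cur


def resolve(value: Any, config: dict) -> Any:
    """Return the referenced value if `value` is a "<dotted.path>" placeholder."""
    if isinstance(value, str):
        m = re.fullmatch(r"<(?P<name_path>[^<>]+)>", value)
        if m is not None:
            return get_item(config, m.group("name_path"))
    return value


# Made-up config for illustration; plain strings stand in for Timestamps.
config = {
    "dataset": {
        "kwargs": {
            "handler": {"kwargs": {"fit_start_time": "<dataset.kwargs.segments.train.0>"}},
            "segments": {"train": ("2008-01-02", "2014-12-31")},
        }
    }
}
handler_kwargs = config["dataset"]["kwargs"]["handler"]["kwargs"]
handler_kwargs["fit_start_time"] = resolve(handler_kwargs["fit_start_time"], config)
print(handler_kwargs["fit_start_time"])
```

Resolving the path against the same config keeps the handler's fit window in step with the training segment, which appears to be the motivation for the new default values in `get_data_handler_config`.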
|
||||
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ RECORD_CONFIG = [
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
fit_start_time="<dataset.kwargs.segments.train.0>",
|
||||
fit_end_time="<dataset.kwargs.segments.train.1>",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
|
||||
@@ -8,7 +8,7 @@ from . import lazy_sort_index
|
||||
from .time import Freq, cal_sam_minute
|
||||
|
||||
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: Union[str, Freq], freq_sam: Union[str, Freq]) -> np.ndarray:
|
||||
"""
|
||||
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
Assumption:
|
||||
@@ -28,36 +28,36 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
np.ndarray
|
||||
The calendar with frequency freq_sam
|
||||
"""
|
||||
raw_count, freq_raw = Freq.parse(freq_raw)
|
||||
sam_count, freq_sam = Freq.parse(freq_sam)
|
||||
freq_raw = Freq(freq_raw)
|
||||
freq_sam = Freq(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
|
||||
# if freq_sam is xminute, divide each trading day into several bars evenly
|
||||
if freq_sam == Freq.NORM_FREQ_MINUTE:
|
||||
if freq_raw != Freq.NORM_FREQ_MINUTE:
|
||||
if freq_sam.base == Freq.NORM_FREQ_MINUTE:
|
||||
if freq_raw.base != Freq.NORM_FREQ_MINUTE:
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
if freq_raw.count > freq_sam.count:
|
||||
raise ValueError("raw freq must be higher than sampling freq")
|
||||
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw)))
|
||||
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, freq_sam.count), calendar_raw)))
|
||||
return _calendar_minute
|
||||
|
||||
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[::sam_count]
|
||||
if freq_sam.base == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[:: freq_sam.count]
|
||||
|
||||
elif freq_sam == Freq.NORM_FREQ_WEEK:
|
||||
elif freq_sam.base == Freq.NORM_FREQ_WEEK:
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_count]
|
||||
return _calendar_week[:: freq_sam.count]
|
||||
|
||||
elif freq_sam == Freq.NORM_FREQ_MONTH:
|
||||
elif freq_sam.base == Freq.NORM_FREQ_MONTH:
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_count]
|
||||
return _calendar_month[:: freq_sam.count]
|
||||
else:
|
||||
raise ValueError("sampling freq must be xmin, xd, xw, xm")
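The week and month branches above rely on `np.ediff1d(..., to_begin=-1) < 0` to pick the first trading day of each period. The sketch below shows that trick in isolation, assuming a made-up calendar of business days generated with `pd.bdate_range` (real trading calendars also skip holidays); the variable names are illustrative only.

```python
import numpy as np
import pandas as pd

# Hypothetical day calendar: three full Monday-to-Friday weeks.
calendar_day = pd.bdate_range("2021-11-01", "2021-11-19")

# Day of week per entry (Monday=0 ... Friday=4).
day_in_week = calendar_day.dayofweek.to_numpy()

# ediff1d yields day_in_week[i] - day_in_week[i - 1]; the difference is negative
# exactly where a new week starts. to_begin=-1 marks the first entry as a start too.
is_week_start = np.ediff1d(day_in_week, to_begin=-1) < 0
calendar_week = calendar_day[is_week_start]

# Taking every second element mimics a sampling count of 2 (a "2week" frequency).
calendar_2week = calendar_week[::2]
print(list(calendar_week.strftime("%Y-%m-%d")))
print(list(calendar_2week.strftime("%Y-%m-%d")))
```

The same pattern works for months by replacing the day-of-week array with the day-of-month array, as the month branch in the diff does.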
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Time related utils are compiled in this script
|
||||
"""
|
||||
import bisect
|
||||
from datetime import datetime, time, date
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import functools
|
||||
import re
|
||||
|
||||
@@ -69,13 +69,29 @@ class Freq:
|
||||
NORM_FREQ_MONTH = "month"
|
||||
NORM_FREQ_WEEK = "week"
|
||||
NORM_FREQ_DAY = "day"
|
||||
NORM_FREQ_MINUTE = "minute"
|
||||
NORM_FREQ_MINUTE = "min"  # use "min" instead of "minute" to align with Qlib's data filenames
|
||||
SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE, NORM_FREQ_DAY]  # FIXME: this list should come from the data
|
||||
|
||||
MIN_CAL = get_min_cal()
|
||||
|
||||
def __init__(self, freq: str) -> None:
|
||||
self.count, self.base = self.parse(freq)
|
||||
def __init__(self, freq: Union[str, "Freq"]) -> None:
|
||||
if isinstance(freq, str):
|
||||
self.count, self.base = self.parse(freq)
|
||||
elif isinstance(freq, Freq):
|
||||
self.count, self.base = freq.count, freq.base
|
||||
else:
|
||||
raise NotImplementedError("This type of input is not supported")
|
||||
|
||||
def __eq__(self, freq):
|
||||
freq = Freq(freq)
|
||||
return freq.count == self.count and freq.base == self.base
|
||||
|
||||
def __str__(self):
|
||||
# trying to align to the filename of Qlib: day, 30min, 5min, 1min...
|
||||
return f"{self.count if self.count != 1 or self.base != 'day' else ''}{self.base}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({str(self)})"
|
||||
|
||||
@staticmethod
|
||||
def parse(freq: str) -> Tuple[int, str]:
|
||||
@@ -159,14 +175,14 @@ class Freq:
|
||||
Freq.NORM_FREQ_WEEK: 7 * 60 * 24,
|
||||
Freq.NORM_FREQ_MONTH: 30 * 7 * 60 * 24,
|
||||
}
|
||||
left_freq = Freq.parse(left_frq)
|
||||
left_minutes = left_freq[0] * minutes_map[left_freq[1]]
|
||||
right_freq = Freq.parse(right_freq)
|
||||
right_minutes = right_freq[0] * minutes_map[right_freq[1]]
|
||||
left_freq = Freq(left_frq)
|
||||
left_minutes = left_freq.count * minutes_map[left_freq.base]
|
||||
right_freq = Freq(right_freq)
|
||||
right_minutes = right_freq.count * minutes_map[right_freq.base]
|
||||
return left_minutes - right_minutes
|
||||
|
||||
@staticmethod
|
||||
def get_recent_freq(base_freq: str, freq_list: List[str]) -> str:
|
||||
def get_recent_freq(base_freq: Union[str, "Freq"], freq_list: List[Union[str, "Freq"]]) -> Optional["Freq"]:
|
||||
"""Get the closest freq to base_freq from freq_list
|
||||
|
||||
Parameters
|
||||
@@ -176,17 +192,22 @@ class Freq:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
if the recent frequency is found
|
||||
Freq
|
||||
else:
|
||||
None
|
||||
"""
|
||||
base_freq = Freq(base_freq)
|
||||
# use the nearest freq greater than 0
|
||||
_freq_minutes = []
|
||||
min_freq = None
|
||||
for _freq in freq_list:
|
||||
freq = Freq(_freq)
|
||||
_min_delta = Freq.get_min_delta(base_freq, _freq)
|
||||
if _min_delta < 0:
|
||||
continue
|
||||
if min_freq is None:
|
||||
min_freq = (_min_delta, _freq)
|
||||
min_freq = (_min_delta, str(_freq))
|
||||
continue
|
||||
min_freq = min_freq if min_freq[0] <= _min_delta else (_min_delta, _freq)
|
||||
return min_freq[1] if min_freq else None
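To make the behaviour of the reworked `Freq` class easier to follow, here is a simplified stand-in. `SimpleFreq` and `nearest_finer` are hypothetical names, the parser only accepts the spellings `min`, `day`, `week` and `month`, and a month is assumed to be 30 days; none of this is taken from the qlib source beyond the general idea shown in the diff.

```python
import re
from typing import Iterable, Optional, Union


class SimpleFreq:
    """Hypothetical, simplified stand-in for a count + base frequency object."""

    # Assumed minutes per base unit (a month is taken as 30 days in this sketch).
    MINUTES = {"min": 1, "day": 60 * 24, "week": 7 * 60 * 24, "month": 30 * 60 * 24}

    def __init__(self, freq: Union[str, "SimpleFreq"]) -> None:
        if isinstance(freq, SimpleFreq):
            # Copy constructor, so callers may pass either a string or an instance.
            self.count, self.base = freq.count, freq.base
            return
        m = re.fullmatch(r"(\d*)(min|day|week|month)", freq)
        if m is None:
            raise ValueError(f"unsupported freq: {freq!r}")
        self.count = int(m.group(1)) if m.group(1) else 1
        self.base = m.group(2)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, (str, SimpleFreq)):
            return NotImplemented
        other = SimpleFreq(other)
        return (self.count, self.base) == (other.count, other.base)

    def __hash__(self) -> int:
        return hash((self.count, self.base))

    def __str__(self) -> str:
        # Mirror the filename convention: "day", "30min", "5min", "1min", ...
        prefix = "" if (self.count, self.base) == (1, "day") else str(self.count)
        return f"{prefix}{self.base}"

    def minutes(self) -> int:
        return self.count * self.MINUTES[self.base]


def nearest_finer(base: Union[str, SimpleFreq], candidates: Iterable[str]) -> Optional[SimpleFreq]:
    """Return the candidate closest to `base` that is not coarser than it, else None."""
    base = SimpleFreq(base)
    best = None
    for cand in map(SimpleFreq, candidates):
        delta = base.minutes() - cand.minutes()
        if delta < 0:
            continue  # candidate is coarser than the requested frequency
        if best is None or delta < best[0]:
            best = (delta, cand)
    return best[1] if best else None


print(nearest_finer("30min", ["day", "1min", "5min"]))
```

Accepting both strings and instances in the constructor is what lets comparison and lookup helpers normalise their arguments with a single call.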
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# 1min data (Optional for running non-high-frequency strategies)
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
```
|
||||
|
||||
### Download US Data
|
||||
|
||||
60
scripts/data_collector/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Data Collector
|
||||
|
||||
## Introduction
|
||||
|
||||
Scripts for data collection
|
||||
|
||||
- yahoo: get *US/CN* stock data from *Yahoo Finance*
|
||||
- fund: get fund data from *http://fund.eastmoney.com*
|
||||
- cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100*
|
||||
- us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400*
|
||||
- contrib: scripts for some auxiliary functions
|
||||
|
||||
|
||||
## Custom Data Collection
|
||||
|
||||
> Specific implementation reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo
|
||||
|
||||
1. Create a dataset code directory in the current directory
|
||||
2. Add `collector.py`
|
||||
- add collector class:
|
||||
```python
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
class UserCollector(BaseCollector):
|
||||
...
|
||||
```
|
||||
- add normalize class:
|
||||
```python
|
||||
class UserNormalize(BaseNormalize):
|
||||
...
|
||||
```
|
||||
- add `CLI` class:
|
||||
```python
|
||||
class Run(BaseRun):
|
||||
...
|
||||
```
|
||||
3. Add `README.md`
4. Add `requirements.txt`
|
||||
|
||||
|
||||
## Description of dataset
|
||||
|
||||
| | Basic data |
|
||||
|------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
|
||||
| Features | **Price/Volume**: <br> - $close/$open/$low/$high/$volume/$change/$factor |
|
||||
| Calendar | **\<freq>.txt**: <br> - day.txt<br> - 1min.txt |
|
||||
| Instruments | **\<market>.txt**: <br> - required: **all.txt**; <br> - csi300.txt/csi500.txt/sp500.txt |
|
||||
|
||||
- `Features`: the data must be **numeric**
|
||||
- if not **adjusted**, **factor=1**
|
||||
|
||||
### Data-dependent component
|
||||
|
||||
> For a component to run correctly, the data it depends on must be available
|
||||
|
||||
| Component | required data |
|
||||
|---------------------------------------------------|--------------------------------|
|
||||
| Data retrieval | Features, Calendar, Instruments |
|
||||
| Backtest | **Features[Price/Volume]**, Calendar, Instruments |
|
||||
@@ -6,13 +6,12 @@ import abc
|
||||
import sys
|
||||
import importlib
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
@@ -22,12 +21,10 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
NEW_COMPANIES_URL = "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
|
||||
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
# 2020-11-27 Announcement title change
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
|
||||
INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement"
|
||||
|
||||
REQ_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
|
||||
@@ -55,7 +52,11 @@ class CSIIndex(IndexBase):
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
return get_calendar_list(bench_code=self.index_name.upper())
|
||||
_calendar = getattr(self, "_calendar_list", None)
|
||||
if not _calendar:
|
||||
_calendar = get_calendar_list(bench_code=self.index_name.upper())
|
||||
setattr(self, "_calendar_list", _calendar)
|
||||
return _calendar
|
||||
|
||||
@property
|
||||
def new_companies_url(self) -> str:
|
||||
@@ -135,7 +136,8 @@ class CSIIndex(IndexBase):
|
||||
res = []
|
||||
for _url in self._get_change_notices_url():
|
||||
_df = self._read_change_from_url(_url)
|
||||
res.append(_df)
|
||||
if not _df.empty:
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res, sort=False)
|
||||
|
||||
@@ -155,6 +157,56 @@ class CSIIndex(IndexBase):
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
|
||||
content = retry_request(excel_url, exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
return df
|
||||
|
||||
def _parse_table(self, content: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
|
||||
df = pd.DataFrame()
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
return df
|
||||
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
|
||||
@@ -174,75 +226,60 @@ class CSIIndex(IndexBase):
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = retry_request(url)
|
||||
_text = resp.text
|
||||
resp = retry_request(url).json()["data"]
|
||||
title = resp["title"]
|
||||
if not title.startswith("关于"):
|
||||
return pd.DataFrame()
|
||||
if "沪深300" not in title:
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}")
|
||||
_text = resp["content"]
|
||||
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
|
||||
if "盘后" in _text or "市后" in _text:
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)
|
||||
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
content = retry_request(f"http://www.csindex.com.cn{excel_url}", exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
except Exception as e:
|
||||
df = None
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(resp.content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
excel_url = None
|
||||
if resp.get("enclosureList", []):
|
||||
excel_url = resp["enclosureList"][0]["fileUrl"]
|
||||
else:
|
||||
excel_url_list = re.findall('.*href="(.*?xls.*?)".*', _text)
|
||||
if excel_url_list:
|
||||
excel_url = excel_url_list[0]
|
||||
if not excel_url.startswith("http"):
|
||||
excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url
|
||||
excel_url = f"http://www.csindex.com.cn{excel_url}"
|
||||
if excel_url:
|
||||
logger.info(f"get {add_date} changes from excel, title={title}, excel_url={excel_url}")
|
||||
try:
|
||||
df = self._parse_excel(excel_url, add_date, remove_date)
|
||||
except ValueError:
|
||||
logger.warning(f"error downloading file: {excel_url}, will parse the table from the content")
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
else:
|
||||
logger.info(f"get {add_date} changes from url content, title={title}")
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
return df
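The date handling in `_read_change_from_url` depends on a regular expression that pulls year, month and day out of Chinese-language announcement text. The following sketch applies the same pattern to a made-up sentence; the sample text is hypothetical, and the standard-library `datetime` is used here in place of `pd.Timestamp` to keep the example dependency-free.

```python
import re
from datetime import datetime

# Same pattern as in the diff: a 4-digit year before 年, then month before 月, day before 日.
DATE_PATTERN = r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日"

# Hypothetical announcement text containing two dates.
text = "2021年11月26日公告，调整将于2021年12月13日生效"

date_list = re.findall(DATE_PATTERN, text)
dates = [datetime.strptime("-".join(parts), "%Y-%m-%d") for parts in date_list]
print(date_list)
print(dates[0].date())
```

When at least two dates are found, the code in the diff takes the first one as the effective date and otherwise falls back to a year-and-month pattern.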
|
||||
|
||||
def _get_change_notices_url(self) -> List[str]:
|
||||
def _get_change_notices_url(self) -> Iterable[str]:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
[url1, url2]
|
||||
"""
|
||||
resp = retry_request(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
page_num = 1
|
||||
page_size = 5
|
||||
data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()
|
||||
data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json()
|
||||
for item in data["data"]:
|
||||
yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}"
|
||||
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -270,7 +307,7 @@ class CSIIndex(IndexBase):
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
@@ -287,7 +324,7 @@ class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
class CSI100(CSIIndex):
|
||||
@@ -301,7 +338,7 @@ class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 1
|
||||
return 2
|
||||
|
||||
|
||||
def get_instruments(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
|
||||
23
setup.py
@@ -6,6 +6,21 @@ import numpy
|
||||
|
||||
from setuptools import find_packages, setup, Extension
|
||||
|
||||
|
||||
def read(rel_path: str) -> str:
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
with open(os.path.join(here, rel_path), encoding="utf-8") as fp:
|
||||
return fp.read()
|
||||
|
||||
|
||||
def get_version(rel_path: str) -> str:
|
||||
for line in read(rel_path).splitlines():
|
||||
if line.startswith("__version__"):
|
||||
delim = '"' if '"' in line else "'"
|
||||
return line.split(delim)[1]
|
||||
raise RuntimeError("Unable to find version string.")
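The version lookup above can be exercised without touching a real package. This is a minimal sketch under stated assumptions: it takes a `Path` instead of a path relative to `setup.py`, and it writes a throwaway `__init__.py` with a made-up version string into a temporary directory.

```python
import tempfile
from pathlib import Path


def get_version(path: Path) -> str:
    """Return the quoted value on the first line that starts with __version__."""
    for line in path.read_text(encoding="utf-8").splitlines():
        if line.startswith("__version__"):
            # Support both double- and single-quoted version strings.
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")


with tempfile.TemporaryDirectory() as tmp:
    init_file = Path(tmp) / "__init__.py"
    init_file.write_text('"""Demo package."""\n__version__ = "0.8.0"\n', encoding="utf-8")
    version = get_version(init_file)

print(version)
```

Reading the string from the source file avoids importing the package during installation, which could fail before its dependencies are installed.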
|
||||
|
||||
|
||||
# Package meta-data.
|
||||
NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
@@ -14,11 +29,7 @@ REQUIRES_PYTHON = ">=3.5.0"
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
CURRENT_DIR = Path(__file__).absolute().parent
|
||||
_version_src = CURRENT_DIR / "VERSION.txt"
|
||||
_version_dst = CURRENT_DIR / "qlib" / "VERSION.txt"
|
||||
copyfile(_version_src, _version_dst)
|
||||
VERSION = _version_dst.read_text(encoding="utf-8").strip()
|
||||
VERSION = get_version("qlib/__init__.py")
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
@@ -47,7 +58,7 @@ REQUIRED = [
|
||||
"python-redis-lock>=3.3.1",
|
||||
"schedule>=0.6.0",
|
||||
"cvxpy>=1.0.21",
|
||||
"hyperopt==0.1.1",
|
||||
"hyperopt==0.1.2",
|
||||
"fire>=0.3.1",
|
||||
"statsmodels",
|
||||
"xlrd>=1.0.0",
|
||||
|
||||