mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 10:01:19 +08:00
Merge remote-tracking branch 'origin/main' into main
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
|
||||
- name: Test data downloads and examples
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
# cd examples
|
||||
# estimator -c estimator/estimator_config.yaml
|
||||
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
|
||||
@@ -91,7 +91,7 @@ Also, users can install ``Qlib`` by the source code according to the following s
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
```bash
|
||||
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
```
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
|
||||
@@ -34,7 +34,7 @@ Qlib Format Dataset
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
After running the above command, users can find China-stock data in Qlib format in the ``~/.qlib/qlib_data/cn_data`` directory.
|
||||
|
||||
@@ -59,7 +59,7 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ Load and prepare data by running the following code:
|
||||
|
||||
.. code-block::
|
||||
|
||||
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
This dataset is created from public data collected by the crawler scripts in ``scripts/data_collector/``, which have been released in the same repository. Users can create the same dataset with them.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ Please follow the steps below to initialize ``Qlib``.
|
||||
- Download and prepare the Data: execute the following command to download stock data. Please note that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and might not be perfect. We recommend that users prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized datasets.
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
@@ -219,4 +219,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(target_dir=provider_uri)
|
||||
GetData().qlib_data(target_dir=provider_uri)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
|
||||
@@ -31,13 +31,13 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
@@ -335,4 +335,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
61
scripts/README.md
Normal file
61
scripts/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
|
||||
- [Download Qlib Data](#Download-Qlib-Data)
|
||||
- [Download CN Data](#Download-CN-Data)
|
||||
- [Download US Data](#Download-US-Data)
|
||||
- [Download CN Simple Data](#Download-CN-Simple-Data)
|
||||
- [Help](#Help)
|
||||
- [Using in Qlib](#Using-in-Qlib)
|
||||
- [US data](#US-data)
|
||||
- [CN data](#CN-data)
|
||||
|
||||
|
||||
## Download Qlib Data
|
||||
|
||||
|
||||
### Download CN Data
|
||||
|
||||
```bash
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
```
|
||||
|
||||
### Download US Data
|
||||
|
||||
> The US stock codes include 'PRN', which is a reserved device name on Windows, so the corresponding directory cannot be created on Windows systems: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
|
||||
```bash
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
```
|
||||
|
||||
### Download CN Simple Data
|
||||
|
||||
```bash
|
||||
python get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
```
|
||||
|
||||
### Help
|
||||
|
||||
```bash
|
||||
python get_data.py qlib_data --help
|
||||
```
|
||||
|
||||
## Using in Qlib
|
||||
> For more information: https://qlib.readthedocs.io/en/latest/start/initialization.html
|
||||
|
||||
|
||||
### US data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.config import REG_US
|
||||
provider_uri = "~/.qlib/qlib_data/us_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_US)
|
||||
```
|
||||
|
||||
### CN data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
```
|
||||
144
scripts/check_dump_bin.py
Normal file
144
scripts/check_dump_bin.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
import fire
|
||||
import datacompy
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CheckBin:
    """Check that the bin files produced by ``dump_bin.py`` match the original csv data.

    Each csv file is compared against the features stored in ``qlib_dir``;
    the result for every file is one of the ``COMPARE_*`` / ``NOT_IN_FEATURES``
    status strings below.
    """

    NOT_IN_FEATURES = "not in features"
    COMPARE_FALSE = "compare False"
    COMPARE_TRUE = "compare True"
    COMPARE_ERROR = "compare error"

    def __init__(
        self,
        qlib_dir: str,
        csv_path: str,
        check_fields: str = None,
        freq: str = "day",
        symbol_field_name: str = "symbol",
        date_field_name: str = "date",
        file_suffix: str = ".csv",
        max_workers: int = 16,
    ):
        """

        Parameters
        ----------
        qlib_dir : str
            qlib dir
        csv_path : str
            origin csv path; either a directory of csv files or a single csv file
        check_fields : str, optional
            comma-separated fields to check, by default None, which uses the
            fields found in qlib_dir/features/<first_dir>/*.bin
        freq : str, optional
            freq, value from ["day", "1m"]
        symbol_field_name: str, optional
            symbol field name, by default "symbol"
        date_field_name: str, optional
            date field name, by default "date"
        file_suffix: str, optional
            csv file suffix, by default ".csv"
        max_workers: int, optional
            max workers, by default 16
        """
        self.qlib_dir = Path(qlib_dir).expanduser()
        bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
        self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
        qlib.init(
            provider_uri=str(self.qlib_dir.resolve()),
            mount_path=str(self.qlib_dir.resolve()),
            auto_mount=False,
            redis_port=-1,
        )
        csv_path = Path(csv_path).expanduser()
        self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])

        if check_fields is None:
            # glob() yields Path objects, so the field name has to be taken
            # from ``Path.name`` ("close.day.bin" -> "close").
            check_fields = [_bin.name.split(".")[0] for _bin in bin_path_list[0].glob("*.bin")]
        else:
            check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
        self.check_fields = list(map(lambda x: x.strip(), check_fields))
        self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
        self.max_workers = max_workers
        self.symbol_field_name = symbol_field_name
        self.date_field_name = date_field_name
        self.freq = freq
        self.file_suffix = file_suffix

    def _get_symbol(self, file_path: Path) -> str:
        """Return the symbol encoded in a csv file name, i.e. the name without ``file_suffix``.

        ``str.strip`` must not be used here: it removes a *set of characters*
        from both ends, so "sh600000.csv".strip(".csv") would yield "h600000".
        """
        name = file_path.name
        suffix = self.file_suffix
        if suffix and name.endswith(suffix):
            return name[: -len(suffix)]
        return name

    def _compare(self, file_path: Path):
        """Compare one csv file with the qlib features of the same symbol.

        Returns
        -------
            one of NOT_IN_FEATURES, COMPARE_TRUE, COMPARE_FALSE, COMPARE_ERROR
        """
        symbol = self._get_symbol(file_path)
        if symbol.lower() not in self.qlib_symbols:
            return self.NOT_IN_FEATURES
        # qlib data
        qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
        qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
        # csv data
        origin_df = pd.read_csv(file_path)
        origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
        if self.symbol_field_name not in origin_df.columns:
            origin_df[self.symbol_field_name] = symbol
        origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
        # align index names so that both frames can be joined on their index
        origin_df.index.names = qlib_df.index.names
        try:
            compare = datacompy.Compare(
                origin_df,
                qlib_df,
                on_index=True,
                abs_tol=1e-08,  # Optional, defaults to 0
                rel_tol=1e-05,  # Optional, defaults to 0
                df1_name="Original",  # Optional, defaults to 'df1'
                df2_name="New",  # Optional, defaults to 'df2'
            )
            _r = compare.matches(ignore_extra_columns=True)
            return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
        except Exception as e:
            # best effort: a failure on one symbol is reported, not fatal
            logger.warning(f"{symbol} compare error: {e}")
            return self.COMPARE_ERROR

    def check(self):
        """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
        logger.info("start check......")

        error_list = []
        not_in_features = []
        compare_false = []
        with tqdm(total=len(self.csv_files)) as p_bar:
            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
                    symbol = self._get_symbol(file_path)
                    if _check_res == self.NOT_IN_FEATURES:
                        not_in_features.append(symbol)
                    elif _check_res == self.COMPARE_ERROR:
                        error_list.append(symbol)
                    elif _check_res == self.COMPARE_FALSE:
                        compare_false.append(symbol)
                    p_bar.update()

        logger.info("end of check......")
        if error_list:
            logger.warning(f"compare error: {error_list}")
        if not_in_features:
            logger.warning(f"not in features: {not_in_features}")
        if compare_false:
            logger.warning(f"compare False: {compare_false}")
        logger.info(
            f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(CheckBin)
|
||||
22
scripts/data_collector/cn_index/README.md
Normal file
22
scripts/data_collector/cn_index/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# CSI300/CSI100 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: CSI300, CSI100
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import bisect
|
||||
import importlib
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
@@ -16,7 +17,9 @@ from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
@@ -24,64 +27,48 @@ NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
|
||||
class CSIIndex:
|
||||
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
|
||||
def __init__(self, qlib_dir=None):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
"""
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._calendar_list = None
|
||||
|
||||
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
@property
|
||||
def calendar_list(self) -> list:
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
return get_calendar_list(bench_code=self.index_name.upper())
|
||||
|
||||
@property
|
||||
def new_companies_url(self):
|
||||
def new_companies_url(self) -> str:
|
||||
return NEW_COMPANIES_URL.format(index_code=self.index_code)
|
||||
|
||||
@property
|
||||
def changes_url(self):
|
||||
def changes_url(self) -> str:
|
||||
return INDEX_CHANGES_URL
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
raise NotImplementedError()
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_code(self):
|
||||
raise NotImplementedError()
|
||||
def index_code(self) -> str:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_name(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def html_table_index(self):
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
CSI300: 0
|
||||
@@ -90,33 +77,19 @@ class CSIIndex:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
left_index = bisect.bisect_left(self.calendar_list, trading_date)
|
||||
try:
|
||||
res = self.calendar_list[left_index + shift]
|
||||
except IndexError:
|
||||
res = trading_date
|
||||
return res
|
||||
|
||||
def _get_changes(self) -> pd.DataFrame:
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
logger.info("get companies changes......")
|
||||
res = []
|
||||
@@ -124,10 +97,21 @@ class CSIIndex:
|
||||
_df = self._read_change_from_url(_url)
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res)
|
||||
return pd.concat(res, sort=False)
|
||||
|
||||
@staticmethod
|
||||
def normalize_symbol(symbol):
|
||||
def normalize_symbol(symbol: str) -> str:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
|
||||
Returns
|
||||
-------
|
||||
symbol
|
||||
"""
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
@@ -141,7 +125,14 @@ class CSIIndex:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = requests.get(url)
|
||||
_text = resp.text
|
||||
@@ -151,8 +142,8 @@ class CSIIndex:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = self._get_trading_date_by_shift(_date, shift=0)
|
||||
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
|
||||
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
@@ -168,12 +159,12 @@ class CSIIndex:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = ["symbol"]
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df["date"] = _date
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
df = None
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(resp.content):
|
||||
@@ -188,9 +179,9 @@ class CSIIndex:
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df["date"] = _date
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
@@ -203,20 +194,33 @@ class CSIIndex:
|
||||
break
|
||||
return df
|
||||
|
||||
def _get_change_notices_url(self) -> list:
|
||||
def _get_change_notices_url(self) -> List[str]:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
[url1, url2]
|
||||
"""
|
||||
resp = requests.get(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
def _get_new_companies(self):
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
logger.info("get new companies")
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
context = requests.get(self.new_companies_url).content
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
|
||||
@@ -225,51 +229,19 @@ class CSIIndex:
|
||||
_io = BytesIO(context)
|
||||
df = pd.read_excel(_io)
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = ["end_date", "symbol"]
|
||||
df["symbol"] = df["symbol"].map(self.normalize_symbol)
|
||||
df["end_date"] = pd.to_datetime(df["end_date"])
|
||||
df["start_date"] = self.bench_start_date
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
|
||||
"""parse csi300.txt
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = ["symbol", "start_date", "end_date"]
|
||||
changers_df = self._get_changes()
|
||||
new_df = self._get_new_companies()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
|
||||
if _row.type == self.ADD:
|
||||
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
|
||||
new_df.loc[
|
||||
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame(
|
||||
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
|
||||
)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
|
||||
|
||||
class CSI300(CSIIndex):
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000300"
|
||||
|
||||
@property
|
||||
def index_name(self):
|
||||
return "csi300"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2005-01-01")
|
||||
@@ -284,10 +256,6 @@ class CSI100(CSIIndex):
|
||||
def index_code(self):
|
||||
return "000903"
|
||||
|
||||
@property
|
||||
def index_name(self):
|
||||
return "csi100"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2006-05-29")
|
||||
@@ -297,19 +265,39 @@ class CSI100(CSIIndex):
|
||||
return 1
|
||||
|
||||
|
||||
def parse_instruments(qlib_dir: str):
|
||||
def get_instruments(
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
index_name: str
|
||||
index name, value from ["csi100", "csi300"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
qlib_dir.mkdir(exist_ok=True, parents=True)
|
||||
CSI300(qlib_dir).parse_instruments()
|
||||
CSI100(qlib_dir).parse_instruments()
|
||||
_cur_module = importlib.import_module("collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(parse_instruments)
|
||||
fire.Fire(get_instruments)
|
||||
@@ -1,14 +0,0 @@
|
||||
# CSI300 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
|
||||
202
scripts/data_collector/index.py
Normal file
202
scripts/data_collector/index.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import sys
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent))
|
||||
|
||||
|
||||
from data_collector.utils import get_trading_date_by_shift
|
||||
|
||||
|
||||
class IndexBase:
|
||||
DEFAULT_END_DATE = pd.Timestamp("2099-12-31")
|
||||
SYMBOL_FIELD_NAME = "symbol"
|
||||
DATE_FIELD_NAME = "date"
|
||||
START_DATE_FIELD = "start_date"
|
||||
END_DATE_FIELD = "end_ate"
|
||||
CHANGE_TYPE_FIELD = "type"
|
||||
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index_name: str
|
||||
index name
|
||||
qlib_dir: str
|
||||
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
"""
|
||||
self.index_name = index_name
|
||||
if qlib_dir is None:
|
||||
qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._request_retry = request_retry
|
||||
self._retry_sleep = retry_sleep
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
raise NotImplementedError("rewrite calendar_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_new_companies")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
def save_new_companies(self):
|
||||
"""save new companies
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
|
||||
def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:
|
||||
"""get changes with history companies
|
||||
|
||||
Parameters
|
||||
----------
|
||||
history_companies : pd.DataFrame
|
||||
symbol date
|
||||
SH600000 2020-11-11
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
|
||||
Return
|
||||
--------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
|
||||
"""
|
||||
logger.info("parse changes from history companies......")
|
||||
last_code = []
|
||||
result_df_list = []
|
||||
_columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]
|
||||
for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):
|
||||
_currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][
|
||||
self.SYMBOL_FIELD_NAME
|
||||
].tolist()
|
||||
if last_code:
|
||||
add_code = list(set(last_code) - set(_currenet_code))
|
||||
remote_code = list(set(_currenet_code) - set(last_code))
|
||||
for _code in add_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
for _code in remote_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
last_code = _currenet_code
|
||||
df = pd.concat(result_df_list)
|
||||
logger.info("end of parse changes from history companies.")
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
    """parse instruments, eg: csi300.txt

    Starts from the current constituents (``get_new_companies``) and replays
    the change records (``get_changes``) from the newest to the oldest to
    rebuild each symbol's (start_date, end_date) intervals, then writes them
    as a tab-separated, header-less file ``<index_name>.txt`` in
    ``self.instruments_dir``.

    Examples
    -------
        $ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
    """
    logger.info(f"start parse {self.index_name.lower()} companies.....")
    instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
    changers_df = self.get_changes()
    new_df = self.get_new_companies().copy()
    logger.info("parse history companies by changes......")
    for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
        if _row.type == self.ADD:
            # An "add" fixes the start date of the symbol's interval with the
            # smallest end date seen so far.
            min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()
            new_df.loc[
                (new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol),
                self.START_DATE_FIELD,
            ] = _row.date
        else:
            # Any other record opens a new interval ending at the change date;
            # its start defaults to the benchmark start date.
            _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
            # DataFrame.append was removed in pandas 2.0; pd.concat is the
            # equivalent that works on both old and new pandas.
            new_df = pd.concat([new_df, _tmp_df], sort=False)

    new_df.loc[:, instruments_columns].to_csv(
        self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
    )
    logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
22
scripts/data_collector/us_index/README.md
Normal file
22
scripts/data_collector/us_index/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# NASDAQ100/SP500/SP400/DJIA History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, used in qlib/instruments.
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: SP500, NASDAQ100, DJIA, SP400
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
278
scripts/data_collector/us_index/collector.py
Normal file
278
scripts/data_collector/us_index/collector.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
||||
|
||||
|
||||
WIKI_URL = "https://en.wikipedia.org/wiki"
|
||||
|
||||
WIKI_INDEX_NAME_MAP = {
|
||||
"NASDAQ100": "NASDAQ-100",
|
||||
"SP500": "List_of_S%26P_500_companies",
|
||||
"SP400": "List_of_S%26P_400_companies",
|
||||
"DJIA": "Dow_Jones_Industrial_Average",
|
||||
}
|
||||
|
||||
|
||||
class WIKIIndex(IndexBase):
    """Index collector that reads the current constituents from a Wikipedia page.

    The page is chosen through ``WIKI_INDEX_NAME_MAP``; subclasses supply
    ``bench_start_date``, ``get_changes`` and ``filter_df``.
    """

    def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
        super().__init__(
            index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
        )

        # An index name without an entry in WIKI_INDEX_NAME_MAP raises KeyError here.
        wiki_page = WIKI_INDEX_NAME_MAP[self.index_name.upper()]
        self._target_url = f"{WIKI_URL}/{wiki_page}"

    @property
    @abc.abstractmethod
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        raise NotImplementedError("rewrite bench_start_date")

    @abc.abstractmethod
    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Returns
        -------
            pd.DataFrame:
                symbol      date        type
                SH600000  2019-11-11    add
                SH600000  2020-11-10    remove
                dtypes:
                    symbol: str
                    date: pd.Timestamp
                    type: str, value from ["add", "remove"]
        """
        raise NotImplementedError("rewrite get_changes")

    @property
    def calendar_list(self) -> List[pd.Timestamp]:
        """get history trading date

        The "US_ALL" calendar is fetched once, restricted to dates on or after
        ``bench_start_date`` and cached on the instance.

        Returns
        -------
            calendar list
        """
        if getattr(self, "_calendar_list", None) is None:
            self._calendar_list = [_date for _date in get_calendar_list("US_ALL") if _date >= self.bench_start_date]
        return self._calendar_list

    def _request_new_companies(self) -> requests.Response:
        """Fetch the index's Wikipedia page; any status other than 200 raises ValueError."""
        response = requests.get(self._target_url)
        if response.status_code == 200:
            return response
        raise ValueError(f"request error: {self._target_url}")

    def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with stripped symbols and the default start/end dates."""
        result = df.copy()
        result[self.SYMBOL_FIELD_NAME] = result[self.SYMBOL_FIELD_NAME].str.strip()
        result[self.START_DATE_FIELD] = self.bench_start_date
        result[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
        return result.loc[:, self.INSTRUMENTS_COLUMNS]

    def get_new_companies(self):
        """Return the current constituents parsed from the Wikipedia page.

        Every HTML table on the page is passed to ``filter_df``; the first
        table for which it yields a non-empty frame is used. Nothing is
        returned (``None``) when no table qualifies.
        """
        logger.info(f"get new companies {self.index_name} ......")
        fetch = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)
        page = fetch()
        for table in pd.read_html(page.text):
            candidate = self.filter_df(table)
            if candidate is None or candidate.empty:
                continue
            candidate.columns = [self.SYMBOL_FIELD_NAME]
            companies = self.set_default_date_range(candidate)
            logger.info(f"end of get new companies {self.index_name} ......")
            return companies

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pick the symbol column out of one candidate table; must be overridden."""
        raise NotImplementedError("rewrite filter_df")
|
||||
|
||||
|
||||
class NASDAQ100Index(WIKIIndex):
    """NASDAQ-100 collector: history comes from per-date snapshots of the nasdaqomx weighting endpoint."""

    HISTORY_COMPANIES_URL = (
        "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
    )
    # Number of threads used to request the per-date snapshots.
    MAX_WORKERS = 16

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the "Ticker" column when the table has no column besides "Company"/"Ticker"."""
        if set(df.columns).issubset({"Company", "Ticker"}):
            return df.loc[:, ["Ticker"]].copy()
        return None

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date considered for this index."""
        return pd.Timestamp("2003-01-02")

    @deco_retry
    def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
        """Return the constituent snapshot of one trading date.

        A pickle in ``self.cache_dir`` is reused when present and ``use_cache``
        is true; otherwise the snapshot is requested and, if non-empty, cached.
        A non-200 response raises ValueError.
        """
        date_str = trade_date.strftime("%Y-%m-%d")
        cache_path = self.cache_dir.joinpath(f"{date_str}_history_companies.pkl")
        if cache_path.exists() and use_cache:
            return pd.read_pickle(cache_path)

        url = self.HISTORY_COMPANIES_URL.format(trade_date=date_str)
        resp = requests.post(url)
        if resp.status_code != 200:
            raise ValueError(f"request error: {url}")
        snapshot = pd.DataFrame(resp.json()["aaData"])
        # The date is stored as the formatted string, not as a Timestamp.
        snapshot[self.DATE_FIELD_NAME] = date_str
        snapshot.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
        if not snapshot.empty:
            snapshot.to_pickle(cache_path)
        return snapshot

    def get_history_companies(self):
        """Request the snapshot of every calendar date and concatenate the non-empty ones."""
        logger.info("start get history companies......")
        snapshots = []
        failed_dates = []
        with tqdm(total=len(self.calendar_list)) as p_bar, ThreadPoolExecutor(
            max_workers=self.MAX_WORKERS
        ) as executor:
            fetched = executor.map(self._request_history_companies, self.calendar_list)
            for _trading_date, _df in zip(self.calendar_list, fetched):
                # Dates whose snapshot came back empty are only reported, not retried here.
                target = failed_dates if _df.empty else snapshots
                target.append(_trading_date if _df.empty else _df)
                p_bar.update()

        if failed_dates:
            logger.warning(f"get error: {failed_dates}")
        logger.info(f"total {len(self.calendar_list)}, error {len(failed_dates)}")
        logger.info("end of get history companies.")
        return pd.concat(snapshots, sort=False)

    def get_changes(self):
        """Derive add/remove records from the downloaded snapshots."""
        history = self.get_history_companies()
        return self.get_changes_with_history_companies(history)
|
||||
|
||||
|
||||
class DJIAIndex(WIKIIndex):
    """Dow Jones Industrial Average collector; only the current constituents are available."""

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date considered for this index."""
        return pd.Timestamp("2000-01-01")

    def get_changes(self) -> pd.DataFrame:
        """No change history is implemented for this index; returns None."""
        return None

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the "Symbol" column, keeping only the text after the last ":" of each entry."""
        if "Symbol" not in df.columns:
            return None
        symbols = df.loc[:, ["Symbol"]].copy()
        symbols["Symbol"] = symbols["Symbol"].map(lambda value: value.split(":")[-1])
        return symbols

    def parse_instruments(self):
        """Only emits a warning: history cannot be rebuilt for this index."""
        logger.warning("No suitable data source has been found!")
|
||||
|
||||
|
||||
class SP500Index(WIKIIndex):
    """S&P 500 collector; the change history is read from the Wikipedia constituents page."""

    WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date considered for this index."""
        return pd.Timestamp("1999-01-01")

    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Reads the last HTML table of ``WIKISP500_CHANGES_URL`` and uses its
        columns 0, 1 and 3 as the change date, the added ticker and the removed
        ticker.

        Returns
        -------
            pd.DataFrame with the columns date, type ("add"/"remove") and
            symbol; "add" rows are dated with calendar shift 0 and "remove"
            rows with calendar shift -1 relative to the listed date.
        """
        logger.info("get sp500 history changes......")
        # NOTE: may update the index of the table
        changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
        # Copy the slice so the column assignments below act on an independent
        # frame (avoids pandas' SettingWithCopyWarning).
        changes_df = changes_df.iloc[:, [0, 1, 3]].copy()
        changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
        changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
        _result = []
        # (change type, calendar shift applied to the listed date)
        for _type, _shift in ((self.ADD, 0), (self.REMOVE, -1)):
            _df = changes_df.copy()
            _df[self.CHANGE_TYPE_FIELD] = _type
            _df[self.SYMBOL_FIELD_NAME] = _df[_type]
            # Rows that list no ticker for this change type are dropped.
            _df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
            _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
                lambda x, _shift=_shift: get_trading_date_by_shift(self.calendar_list, x, _shift)
            )
            _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
        logger.info("end of get sp500 history changes.")
        return pd.concat(_result, sort=False)

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the "Symbol" column of a table that has one, otherwise None."""
        if "Symbol" in df.columns:
            return df.loc[:, ["Symbol"]].copy()
        return None
|
||||
|
||||
|
||||
class SP400Index(WIKIIndex):
    """S&P 400 collector; only the current constituents are available."""

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """First date considered for this index."""
        return pd.Timestamp("2000-01-01")

    def get_changes(self) -> pd.DataFrame:
        """No change history is implemented for this index; returns None."""
        return None

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the "Ticker symbol" column of a table that has one, otherwise None."""
        wanted = "Ticker symbol"
        if wanted not in df.columns:
            return None
        return df.loc[:, [wanted]].copy()

    def parse_instruments(self):
        """Only emits a warning: history cannot be rebuilt for this index."""
        logger.warning("No suitable data source has been found!")
|
||||
|
||||
|
||||
def get_instruments(
    qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
    """Instantiate the ``<INDEX_NAME>Index`` class of this module and run one of its methods.

    Parameters
    ----------
    qlib_dir: str
        qlib data dir, forwarded to the index class
    index_name: str
        index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
    method: str
        method, value from ["parse_instruments", "save_new_companies"]
    request_retry: int
        request retry, by default 5
    retry_sleep: int
        request sleep, by default 3

    Raises
    ------
    AttributeError
        if no ``<INDEX_NAME>Index`` class exists in this module, or the class
        has no attribute named ``method``

    Examples
    -------
        # parse instruments
        $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments

        # parse new companies
        $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies

    """
    # Look the class up in the module this function lives in. Importing
    # "collector" by name would execute the script a second time when run as
    # __main__ and would fail (or pick another module) when this function is
    # imported from a package.
    _cur_module = sys.modules[__name__]
    obj = getattr(_cur_module, f"{index_name.upper()}Index")(
        qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
    )
    getattr(obj, method)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(get_instruments)
|
||||
6
scripts/data_collector/us_index/requirements.txt
Normal file
6
scripts/data_collector/us_index/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
tqdm
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
@@ -3,56 +3,69 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
|
||||
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
|
||||
|
||||
CALENDAR_BENCH_URL_MAP = {
|
||||
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
|
||||
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
|
||||
# NOTE: Use the time series of SH600000 as the sequence of all stocks
|
||||
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
}
|
||||
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_hs_calendar_list(bench_code="CSI300") -> list:
|
||||
def get_calendar_list(bench_code="CSI300") -> list:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL"]
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
if bench_code.startswith("US_"):
|
||||
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
_CALENDAR_MAP[bench_code] = calendar
|
||||
logger.info(f"end of get calendar list: {bench_code}.")
|
||||
return calendar
|
||||
|
||||
|
||||
@@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list:
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
return _res
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
@@ -99,6 +113,84 @@ def get_hs_stock_symbols() -> list:
|
||||
return _HS_SYMBOLS
|
||||
|
||||
|
||||
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get US stock symbols

    The result is computed once per process and kept in the module-level
    ``_US_SYMBOLS``. Symbols are gathered from eastmoney, the nasdaqtrader
    symbol directory and the NYSE quote filter API; when ``qlib_data_path`` is
    given, the symbols of its ``instruments/nasdaq100.txt`` and
    ``instruments/sp500.txt`` files are added as well.

    Returns
    -------
        stock symbols
    """
    global _US_SYMBOLS

    @deco_retry
    def _get_eastmoney():
        url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
        resp = requests.get(url)
        if resp.status_code != 200:
            raise ValueError("request error")
        try:
            codes = [item["f12"].replace("_", "-P") for item in resp.json()["data"]["diff"].values()]
        except Exception as e:
            logger.warning(f"request error: {e}")
            raise
        # Fewer than 8000 symbols is treated as a failed request (and retried).
        if len(codes) < 8000:
            raise ValueError("request error")
        return codes

    @deco_retry
    def _get_nasdaq():
        # Applied in this order; the generic "." rule must come last.
        replacements = (("$", "-P"), (".W", "-WT"), (".U", "-UN"), (".R", "-RI"), (".", "-"))
        collected = []
        for _name in ["otherlisted", "nasdaqtraded"]:
            url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt"
            table = pd.read_csv(url, sep="|").rename(columns={"ACT Symbol": "Symbol"})
            series = table["Symbol"].dropna()
            for old, new in replacements:
                series = series.str.replace(old, new, regex=False)
            collected += series.unique().tolist()
        return collected

    @deco_retry
    def _get_nyse():
        url = "https://www.nyse.com/api/quotes/filter"
        _parms = {
            "instrumentType": "EQUITY",
            "pageNumber": 1,
            "sortColumn": "NORMALIZED_TICKER",
            "sortOrder": "ASC",
            "maxResultsPerPage": 10000,
            "filterToken": "",
        }
        resp = requests.post(url, json=_parms)
        if resp.status_code != 200:
            raise ValueError("request error")
        try:
            return [item["symbolTicker"].replace("-", "-P") for item in resp.json()]
        except Exception as e:
            # A payload that cannot be parsed yields no symbols instead of an error.
            logger.warning(f"request error: {e}")
            return []

    if _US_SYMBOLS is None:
        _all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()
        if qlib_data_path is not None:
            instruments_root = Path(qlib_data_path)
            for _index in ["nasdaq100", "sp500"]:
                ins_df = pd.read_csv(
                    instruments_root.joinpath(f"instruments/{_index}.txt"),
                    sep="\t",
                    names=["symbol", "start_date", "end_date"],
                )
                _all_symbols += ins_df["symbol"].unique().tolist()
        # Keep symbols shorter than 8 characters that do not end with "WS",
        # normalise "." to "-", then de-duplicate and sort.
        kept = {sym.replace(".", "-") for sym in _all_symbols if len(sym) < 8 and not sym.endswith("WS")}
        _US_SYMBOLS = sorted(kept)

    return _US_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
@@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
def deco_retry(retry: int = 5, retry_sleep: int = 3):
    """Retry decorator, usable both as ``@deco_retry`` and ``@deco_retry(retry=..., retry_sleep=...)``.

    The wrapped callable is attempted up to ``retry`` times (5 when used
    without parentheses). Each failure is logged; the last one is re-raised,
    earlier ones are followed by ``time.sleep(retry_sleep)``.
    """
    # Used bare (@deco_retry), `retry` is the decorated function itself.
    used_bare = callable(retry)
    max_attempts = 5 if used_bare else retry

    def deco_func(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    logger.warning(f"{func.__name__}: {attempt} :{e}")
                    if attempt == max_attempts:
                        raise
                    time.sleep(retry_sleep)
            # Reached only when max_attempts < 1: the function is never called.
            return None

        return wrapper

    if used_bare:
        return deco_func(retry)
    return deco_func
|
||||
|
||||
|
||||
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
    """get trading date by shift

    Parameters
    ----------
    trading_list: list
        trading calendar list, sorted in ascending order (it is searched with bisect)
    shift : int
        shift, default is 1

    trading_date : pd.Timestamp
        trading date
    Returns
    -------
        The calendar entry ``shift`` positions away from the first entry that
        is >= ``trading_date``. When that position falls outside the calendar
        (before the first or after the last entry), ``trading_date`` itself is
        returned as a ``pd.Timestamp``.
    """
    trading_date = pd.Timestamp(trading_date)
    target_index = bisect.bisect_left(trading_list, trading_date) + shift
    # Explicit bounds check: a negative index must not wrap around to the end
    # of the calendar, which plain list indexing would silently do.
    if 0 <= target_index < len(trading_list):
        return trading_list[target_index]
    return trading_date
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -18,31 +18,29 @@ pip install -r requirements.txt
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Download data -> Normalize data -> Dump data
|
||||
### Download data and Normalize data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
```
|
||||
|
||||
### Download Data From Yahoo Finance
|
||||
### Download Data
|
||||
|
||||
```bash
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
```
|
||||
|
||||
### Normalize Yahoo Finance Data
|
||||
### Normalize Data
|
||||
|
||||
```bash
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
|
||||
```
|
||||
|
||||
### Manual Ajust Yahoo Finance Data
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
python collector.py collector_data --help
|
||||
```
|
||||
|
||||
### Dump Yahoo Finance Data
|
||||
## Parameters
|
||||
|
||||
```bash
|
||||
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
- interval: 1m or 1d
|
||||
- region: CN or US
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
@@ -13,33 +17,103 @@ import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from dump_bin import DumpData
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
MIN_NUMBERS_TRADING = 252 / 4
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0):
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=5,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 5
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1m, 1d], default 1m
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self._stock_list = None
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
self.max_workers = max_workers
|
||||
self._asynchronous = asynchronous
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
|
||||
if self._interval == "1m":
|
||||
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == "1d":
|
||||
self._start_datetime = max(self._start_datetime, self.START_DATETIME)
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
self._start_datetime = self.convert_datetime(self._start_datetime)
|
||||
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
|
||||
|
||||
@property
|
||||
def stock_list(self):
|
||||
if self._stock_list is None:
|
||||
self._stock_list = get_hs_stock_symbols()
|
||||
return self._stock_list
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewirte min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewirte get_stock_list")
|
||||
|
||||
@property
|
||||
@abc.abstractclassmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def convert_datetime(self, dt: pd.Timestamp):
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
@@ -57,63 +131,95 @@ class YahooCollector:
|
||||
if df.empty:
|
||||
raise ValueError("df is empty")
|
||||
|
||||
symbol_s = symbol.split(".")
|
||||
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=None)
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
|
||||
def _temp_save_small_data(self, symbol, df):
|
||||
if len(df) <= MIN_NUMBERS_TRADING:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!")
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return None
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return symbol
|
||||
|
||||
def _get_from_remote(self, symbol):
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
else:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
|
||||
|
||||
_result = None
|
||||
if self._interval == "1d":
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
elif self._interval == "1m":
|
||||
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
|
||||
_end_date = self._end_datetime.date()
|
||||
if _start_date >= _end_date:
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(_start_date, _end_date, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
def _get_data(self, symbol):
|
||||
_result = None
|
||||
df = self._get_from_remote(symbol)
|
||||
if isinstance(df, pd.DataFrame):
|
||||
if not df.empty:
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(symbol, df) is not None:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
else:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
return _result
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
futures = {}
|
||||
p_bar = tqdm(total=len(stock_list))
|
||||
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
|
||||
self._sleep()
|
||||
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
|
||||
period="max"
|
||||
)
|
||||
if isinstance(resp, dict):
|
||||
for symbol, df in resp.items():
|
||||
if isinstance(df, pd.DataFrame):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[
|
||||
worker.submit(
|
||||
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
|
||||
)
|
||||
] = symbol
|
||||
else:
|
||||
error_symbol.append(symbol)
|
||||
else:
|
||||
for symbol, df in resp.reset_index().groupby("symbol"):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[worker.submit(self.save_stock, symbol, df)] = symbol
|
||||
p_bar.update(self.max_workers)
|
||||
p_bar.close()
|
||||
|
||||
with tqdm(total=len(futures.values())) as p_bar:
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return error_symbol
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
@@ -126,81 +232,140 @@ class YahooCollector:
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, max(_df_list, key=len))
|
||||
|
||||
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
def download_index_data(self):
    """Download index data; every concrete collector must override this hook."""
    raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
    """Normalize ``symbol``; every concrete collector must override this hook."""
    raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector):
|
||||
@property
def min_numbers_trading(self):
    """Record-count threshold for the configured interval.

    Returns ``60 * 4 * 5`` for "1m", ``252 / 4`` for "1d" and None for any
    other interval.
    """
    # presumably minutes x trading hours x days, and a quarter of a trading year -- TODO confirm
    thresholds = {"1m": 60 * 4 * 5, "1d": 252 / 4}
    return thresholds.get(self._interval)
|
||||
|
||||
def get_stock_list(self):
    """Return the symbols produced by ``get_hs_stock_symbols``, logging the count."""
    logger.info("get HS stock symbos......")
    hs_symbols = get_hs_stock_symbols()
    logger.info(f"get {len(hs_symbols)} symbols.")
    return hs_symbols
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
|
||||
)
|
||||
)
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.astype(float, errors="ignore")
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
|
||||
# FIXME: 1m
|
||||
if self._interval == "1d":
|
||||
_format = "%Y%m%d"
|
||||
_begin = self._start_datetime.strftime(_format)
|
||||
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
|
||||
"data"
|
||||
]["klines"],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"get {_index_name} error: {e}")
|
||||
continue
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.astype(float, errors="ignore")
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
|
||||
else:
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: downlaod_index_data")
|
||||
|
||||
def normalize_symbol(self, symbol):
    """Turn a dotted symbol into a prefixed code.

    The text before the first dot is kept; it is prefixed with "sh" when the
    text after the last dot is "ss" and with "sz" otherwise.
    """
    parts = symbol.split(".")
    prefix = "sh" if parts[-1] == "ss" else "sz"
    return f"{prefix}{parts[0]}"
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
|
||||
class YahooCollectorUS(YahooCollector):
    """Collector specialization selected for the US region."""

    @property
    def min_numbers_trading(self):
        """Record-count threshold for the configured interval.

        ``60 * 6.5 * 5`` for "1m", ``252 / 4`` for "1d", None otherwise.
        """
        # presumably minutes x trading hours x days, and a quarter of a trading year -- TODO confirm
        thresholds = {"1m": 60 * 6.5 * 5, "1d": 252 / 4}
        return thresholds.get(self._interval)

    def get_stock_list(self):
        """Return ``get_us_stock_symbols()`` extended with three ``^``-prefixed index tickers."""
        logger.info("get US stock symbols......")
        index_tickers = [
            "^GSPC",
            "^NDX",
            "^DJI",
        ]
        symbols = get_us_stock_symbols() + index_tickers
        logger.info(f"get {len(symbols)} symbols.")
        return symbols

    def download_index_data(self):
        """Do nothing: the index tickers are already part of ``get_stock_list``."""
        pass

    def normalize_symbol(self, symbol):
        """Return ``symbol`` upper-cased."""
        return symbol.upper()

    @property
    def _timezone(self):
        """Time zone name associated with this collector."""
        return "America/New_York"
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
|
||||
def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
qlib_dir: str
|
||||
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 16
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
self.qlib_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_workers = max_workers
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._max_workers = max_workers
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
def normalize_data(self):
|
||||
"""normalize data
|
||||
logger.info("normalize data......")
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
|
||||
def _normalize(file_path: Path):
|
||||
columns = ["open", "close", "high", "low", "volume"]
|
||||
df = pd.read_csv(file_path)
|
||||
def _normalize(source_path: Path):
|
||||
columns = copy.deepcopy(self.COLUMNS)
|
||||
df = pd.read_csv(source_path)
|
||||
df.set_index("date", inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
|
||||
# using China stock market data calendar
|
||||
df = df.reindex(pd.Index(get_calendar_list("ALL")))
|
||||
if self._calendar_list is not None:
|
||||
df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
|
||||
df["factor"] = df["adjclose"] / df["close"]
|
||||
for _col in columns:
|
||||
@@ -213,22 +378,17 @@ class Run:
|
||||
columns += ["change", "factor"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
df.index.names = ["date"]
|
||||
df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name))
|
||||
df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.source_dir.glob("*.csv"))
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_normalize, file_list):
|
||||
p_bar.update()
|
||||
|
||||
def manual_adj_data(self):
|
||||
"""manual adjust data
|
||||
|
||||
Examples
|
||||
--------
|
||||
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
"""adjust data"""
|
||||
logger.info("manual adjust data......")
|
||||
|
||||
def _adj(file_path: Path):
|
||||
df = pd.read_csv(file_path)
|
||||
@@ -244,59 +404,166 @@ class Run:
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
|
||||
df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.normalize_dir.glob("*.csv"))
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._target_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_adj, file_list):
|
||||
p_bar.update()
|
||||
|
||||
def dump_data(self):
|
||||
"""dump yahoo data
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
|
||||
"""
|
||||
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
|
||||
include_fields="close,open,high,low,volume,change,factor"
|
||||
)
|
||||
|
||||
def download_data(self, asynchronous=False, max_collector_count=5, delay=0):
|
||||
"""download data from Internet
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
|
||||
"""
|
||||
YahooCollector(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
asynchronous=asynchronous,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
).collector_data()
|
||||
|
||||
def download_index_data(self):
|
||||
YahooCollector(self.source_dir).download_index_data()
|
||||
|
||||
def download_bench_data(self):
|
||||
"""download bench stock data(SH000300)"""
|
||||
|
||||
def collector_data(self):
|
||||
"""download -> normalize -> dump data
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
"""
|
||||
self.download_data()
|
||||
def normalize(self):
|
||||
self.normalize_data()
|
||||
self.manual_adj_data()
|
||||
self.dump_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class YahooNormalizeUS(YahooNormalize):
    """Normalizer whose calendar is ``get_calendar_list("US_ALL")``."""

    def _get_calendar_list(self):
        """Return the "US_ALL" calendar."""
        # TODO: from MSN
        return get_calendar_list("US_ALL")
|
||||
|
||||
|
||||
class YahooNormalizeCN(YahooNormalize):
    """Normalizer whose calendar is ``get_calendar_list("ALL")``."""

    def _get_calendar_list(self):
        """Return the "ALL" calendar."""
        # TODO: from MSN
        return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Run:
    """Entry points for downloading and normalizing data for one region."""

    def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
        """

        Parameters
        ----------
        source_dir: str
            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
        normalize_dir: str
            Directory for normalize data, default "Path(__file__).parent/normalize"
        max_workers: int
            Concurrent number, default is 4
        region: str
            region, value from ["CN", "US"], default "CN"
        """
        # both directories are created eagerly so later steps can write into them
        self.source_dir = self._prepare_dir(source_dir, "source")
        self.normalize_dir = self._prepare_dir(normalize_dir, "normalize")

        # region-specific classes are looked up by name on this module
        self._cur_module = importlib.import_module("collector")
        self.max_workers = max_workers
        self.region = region

    @staticmethod
    def _prepare_dir(directory, default_name):
        """Resolve ``directory`` (or ``CUR_DIR/default_name`` when None) and create it."""
        if directory is None:
            directory = CUR_DIR.joinpath(default_name)
        resolved = Path(directory).expanduser().resolve()
        resolved.mkdir(parents=True, exist_ok=True)
        return resolved

    def _region_class(self, prefix):
        """Return the module attribute named ``prefix`` + upper-cased region."""
        return getattr(self._cur_module, f"{prefix}{self.region.upper()}")

    def download_data(
        self,
        max_collector_count=5,
        delay=0,
        start=None,
        end=None,
        interval="1d",
        check_data_length=False,
        limit_nums=None,
    ):
        """download data from Internet

        Parameters
        ----------
        max_collector_count: int
            default 5
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1m, 1d], default 1d
        start: str
            start datetime; None (the default) leaves the choice to the collector,
            documented as "2000-01-01"
        end: str
            end datetime; None (the default) leaves the choice to the collector,
            documented as ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None

        Examples
        ---------
            # get daily data
            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
            # get 1m data
            $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
        """
        collector_cls = self._region_class("YahooCollector")
        collector = collector_cls(
            self.source_dir,
            max_workers=self.max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )
        collector.collector_data()

    def normalize_data(self):
        """normalize data

        Examples
        ---------
            $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
        """
        normalizer_cls = self._region_class("YahooNormalize")
        normalizer = normalizer_cls(self.source_dir, self.normalize_dir, self.max_workers)
        normalizer.normalize()

    def collector_data(
        self,
        max_collector_count=5,
        delay=0,
        start=None,
        end=None,
        interval="1d",
        check_data_length=False,
        limit_nums=None,
    ):
        """download -> normalize

        Parameters
        ----------
        max_collector_count: int
            default 5
        delay: float
            time.sleep(delay), default 0
        interval: str
            freq, value from [1m, 1d], default 1d
        start: str
            start datetime; None (the default) leaves the choice to the collector,
            documented as "2000-01-01"
        end: str
            end datetime; None (the default) leaves the choice to the collector,
            documented as ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
        check_data_length: bool
            check data length, by default False
        limit_nums: int
            using for debug, by default None

        Examples
        -------
            python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
        """
        # every option is forwarded unchanged to the download step
        self.download_data(
            max_collector_count=max_collector_count,
            delay=delay,
            start=start,
            end=end,
            interval=interval,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )
        self.normalize_data()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import shutil
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
from functools import partial
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
@@ -13,8 +16,20 @@ from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DumpData(object):
|
||||
FILE_SUFFIX = ".csv"
|
||||
class DumpDataBase:
|
||||
INSTRUMENTS_START_FIELD = "start_datetime"
|
||||
INSTRUMENTS_END_FIELD = "end_datetime"
|
||||
CALENDARS_DIR_NAME = "calendars"
|
||||
FEATURES_DIR_NAME = "features"
|
||||
INSTRUMENTS_DIR_NAME = "instruments"
|
||||
DUMP_FILE_SUFFIX = ".bin"
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
INSTRUMENTS_SEP = "\t"
|
||||
INSTRUMENTS_FILE_NAME = "all.txt"
|
||||
|
||||
UPDATE_MODE = "update"
|
||||
ALL_MODE = "all"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -22,8 +37,13 @@ class DumpData(object):
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
works: int = None,
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
file_suffix: str = ".csv",
|
||||
symbol_field_name: str = "symbol",
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -37,80 +57,101 @@ class DumpData(object):
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "day"
|
||||
transaction frequency
|
||||
works: int, default None
|
||||
max_workers: int, default None
|
||||
number of threads
|
||||
date_field_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
file_suffix: str, default ".csv"
|
||||
file suffix
|
||||
symbol_field_name: str, default "symbol"
|
||||
symbol field name
|
||||
include_fields: tuple
|
||||
dump fields
|
||||
exclude_fields: tuple
|
||||
fields not dumped
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
csv_path = Path(csv_path).expanduser()
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
|
||||
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
|
||||
self.file_suffix = file_suffix
|
||||
self.symbol_field_name = symbol_field_name
|
||||
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
if limit_nums is not None:
|
||||
self.csv_files = self.csv_files[: int(limit_nums)]
|
||||
self.qlib_dir = Path(qlib_dir).expanduser()
|
||||
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
|
||||
if backup_dir is not None:
|
||||
self._backup_qlib_dir(Path(backup_dir).expanduser())
|
||||
|
||||
self.freq = freq
|
||||
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
|
||||
self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT
|
||||
|
||||
self.works = works
|
||||
self.works = max_workers
|
||||
self.date_field_name = date_field_name
|
||||
|
||||
self._calendars_dir = self.qlib_dir.joinpath("calendars")
|
||||
self._features_dir = self.qlib_dir.joinpath("features")
|
||||
self._instruments_dir = self.qlib_dir.joinpath("instruments")
|
||||
self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)
|
||||
self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)
|
||||
self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)
|
||||
|
||||
self._calendars_list = []
|
||||
self._include_fields = ()
|
||||
self._exclude_fields = ()
|
||||
|
||||
self._mode = self.ALL_MODE
|
||||
self._kwargs = {}
|
||||
|
||||
def _backup_qlib_dir(self, target_dir: Path):
|
||||
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
|
||||
|
||||
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
return []
|
||||
if is_begin_end:
|
||||
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
|
||||
return df[self.date_field_name].tolist()
|
||||
def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
|
||||
datetime_d = pd.Timestamp(datetime_d)
|
||||
return datetime_d.strftime(self.calendar_format)
|
||||
|
||||
def _get_source_data(self, file_path: Path):
|
||||
df = pd.read_csv(str(file_path.resolve()))
|
||||
def _get_date(
|
||||
self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False
|
||||
) -> Iterable[pd.Timestamp]:
|
||||
if not isinstance(file_or_df, pd.DataFrame):
|
||||
df = self._get_source_data(file_or_df)
|
||||
else:
|
||||
df = file_or_df
|
||||
if df.empty or self.date_field_name not in df.columns.tolist():
|
||||
_calendars = pd.Series()
|
||||
else:
|
||||
_calendars = df[self.date_field_name]
|
||||
|
||||
if is_begin_end and as_set:
|
||||
return (_calendars.min(), _calendars.max()), set(_calendars)
|
||||
elif is_begin_end:
|
||||
return _calendars.min(), _calendars.max()
|
||||
elif as_set:
|
||||
return set(_calendars)
|
||||
else:
|
||||
return _calendars.tolist()
|
||||
|
||||
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
|
||||
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
|
||||
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
|
||||
# df.drop_duplicates([self.date_field_name], inplace=True)
|
||||
return df
|
||||
|
||||
def _file_to_bin(self, file_path: Path = None):
|
||||
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
|
||||
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
|
||||
# read csv file
|
||||
df = self._get_source_data(file_path)
|
||||
cal_df = calendars_df[
|
||||
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
|
||||
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
|
||||
]
|
||||
cal_df.set_index(self.date_field_name, inplace=True)
|
||||
df.set_index(self.date_field_name, inplace=True)
|
||||
r_df = df.reindex(cal_df.index)
|
||||
date_index = self._calendars_list.index(r_df.index.min())
|
||||
for field in (
|
||||
def get_symbol_from_file(self, file_path: Path) -> str:
    """Derive the lower-cased symbol from a data file name.

    The last ``len(self.file_suffix)`` characters of the file name are dropped,
    then the remainder is stripped and lower-cased.
    """
    name = file_path.name
    # ``name[:-0]`` is the empty string, so only slice when there is a suffix
    if self.file_suffix:
        name = name[: -len(self.file_suffix)]
    return name.strip().lower()
|
||||
|
||||
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
|
||||
return (
|
||||
self._include_fields
|
||||
if self._include_fields
|
||||
else set(r_df.columns) - set(self._exclude_fields)
|
||||
else set(df_columns) - set(self._exclude_fields)
|
||||
if self._exclude_fields
|
||||
else r_df.columns
|
||||
):
|
||||
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
|
||||
if field not in r_df.columns:
|
||||
continue
|
||||
r = np.hstack([date_index, r_df[field]]).astype("<f")
|
||||
r.tofile(str(bin_path.resolve()))
|
||||
else df_columns
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _read_calendar(calendar_path: Path):
|
||||
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
|
||||
return sorted(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
@@ -118,133 +159,305 @@ class DumpData(object):
|
||||
)
|
||||
)
|
||||
|
||||
def dump_features(
|
||||
self,
|
||||
calendar_path: str = None,
|
||||
include_fields: tuple = None,
|
||||
exclude_fields: tuple = None,
|
||||
):
|
||||
"""dump features
|
||||
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
|
||||
return pd.read_csv(
|
||||
instrument_path,
|
||||
sep=self.INSTRUMENTS_SEP,
|
||||
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
|
||||
)
|
||||
|
||||
Parameters
|
||||
---------
|
||||
calendar_path: str
|
||||
calendar path
|
||||
def save_calendars(self, calendars_data: list):
    """Write the calendar to ``<calendars dir>/<freq>.txt``, one formatted entry per line.

    Parameters
    ----------
    calendars_data: list
        datetimes accepted by ``_format_datetime``
    """
    self._calendars_dir.mkdir(parents=True, exist_ok=True)
    target = self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()
    formatted = [self._format_datetime(item) for item in calendars_data]
    np.savetxt(str(target), formatted, fmt="%s", encoding="utf-8")
|
||||
|
||||
include_fields: str
|
||||
dump fields
|
||||
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
    """Write the instruments file ``INSTRUMENTS_FILE_NAME`` in the instruments dir.

    Parameters
    ----------
    instruments_data: list or pd.DataFrame
        a DataFrame is written as index, start and end columns separated by
        ``INSTRUMENTS_SEP`` without a header; anything else is written line
        by line with ``np.savetxt``
    """
    self._instruments_dir.mkdir(parents=True, exist_ok=True)
    target = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
    if not isinstance(instruments_data, pd.DataFrame):
        np.savetxt(target, instruments_data, fmt="%s", encoding="utf-8")
        return
    # only the start/end columns are persisted; the index supplies the first field
    date_columns = [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
    instruments_data.loc[:, date_columns].to_csv(target, header=False, sep=self.INSTRUMENTS_SEP)
|
||||
|
||||
exclude_fields: str
|
||||
fields not dumped
|
||||
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
    """Align ``df`` to the calendar entries between its first and last date.

    Parameters
    ----------
    df: pd.DataFrame
        data with a ``self.date_field_name`` column; NOTE: it is re-indexed
        by that column in place, as callers have always observed
    calendars_list: List[pd.Timestamp]
        the full calendar

    Returns
    -------
    pd.DataFrame
        ``df`` reindexed by every calendar entry in ``[df date min, df date max]``;
        calendar entries missing from ``df`` become NaN rows
    """
    field = self.date_field_name
    # calendars
    calendars_df = pd.DataFrame(data=calendars_list, columns=[field])
    # pandas 2.x rejects casting to the unit-less ``np.datetime64``; name the unit.
    calendars_df[field] = calendars_df[field].astype("datetime64[ns]")
    first_date, last_date = df[field].min(), df[field].max()
    in_range = (calendars_df[field] >= first_date) & (calendars_df[field] <= last_date)
    # build the index from a fresh frame rather than mutating a filtered slice in place
    cal_index = calendars_df.loc[in_range].set_index(field).index
    # align index
    df.set_index(field, inplace=True)
    return df.reindex(cal_index)
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
@staticmethod
def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:
    """Return the position in ``calendar_list`` of the smallest index value of ``df``.

    ``list.index`` raises ValueError when that value is not in the calendar.
    """
    first_date = df.index.min()
    return calendar_list.index(first_date)
|
||||
|
||||
Examples
|
||||
---------
|
||||
def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
|
||||
if df.empty:
|
||||
logger.warning(f"{features_dir.name} data is None or empty")
|
||||
return
|
||||
# align index
|
||||
_df = self.data_merge_calendar(df, self._calendars_list)
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if self._mode == self.UPDATE_MODE:
|
||||
# update
|
||||
with bin_path.open("ab") as fp:
|
||||
np.array(_df[field]).astype("<f").tofile(fp)
|
||||
elif self._mode == self.ALL_MODE:
|
||||
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
|
||||
else:
|
||||
raise ValueError(f"{self._mode} cannot support!")
|
||||
|
||||
# dump all stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
# dump one stock
|
||||
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
|
||||
"""
|
||||
logger.info("start dump features......")
|
||||
if calendar_path is not None:
|
||||
# read calendar from calendar file
|
||||
self._calendars_list = self._read_calendar(Path(calendar_path))
|
||||
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = file_or_data.iloc[0][self.symbol_field_name].lower()
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
df = self._get_source_data(file_or_data)
|
||||
else:
|
||||
raise ValueError(f"not support {type(file_or_data)}")
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{code} data is None or empty")
|
||||
return
|
||||
# features save dir
|
||||
features_dir = self._features_dir.joinpath(code)
|
||||
features_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._data_to_bin(df, calendar_list, features_dir)
|
||||
|
||||
if not self._calendars_list:
|
||||
self.dump_calendars()
|
||||
@abc.abstractmethod
def dump(self):
    """Run the dump; every concrete dumper must override this hook."""
    raise NotImplementedError("dump not implemented!")
|
||||
|
||||
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
|
||||
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.dump()
|
||||
|
||||
|
||||
class DumpDataAll(DumpDataBase):
|
||||
def _get_all_date(self):
|
||||
logger.info("start get all date......")
|
||||
all_datetime = set()
|
||||
date_range_list = []
|
||||
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for _ in executor.map(self._file_to_bin, self.csv_files):
|
||||
with ProcessPoolExecutor(max_workers=self.works) as executor:
|
||||
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
|
||||
self.csv_files, executor.map(_fun, self.csv_files)
|
||||
):
|
||||
all_datetime = all_datetime | _set_calendars
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
_begin_time = self._format_datetime(_begin_time)
|
||||
_end_time = self._format_datetime(_end_time)
|
||||
symbol = self.get_symbol_from_file(file_path)
|
||||
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
|
||||
p_bar.update()
|
||||
self._kwargs["all_datetime_set"] = all_datetime
|
||||
self._kwargs["date_range_list"] = date_range_list
|
||||
logger.info("end of get all date.\n")
|
||||
|
||||
def _dump_calendars(self):
    """Build the sorted calendar from the collected dates and save it.

    Reads ``self._kwargs["all_datetime_set"]``, converts every entry to
    ``pd.Timestamp``, stores the sorted list on ``self._calendars_list`` and
    hands it to ``self.save_calendars``.
    """
    logger.info("start dump calendars......")
    timestamps = [pd.Timestamp(_dt) for _dt in self._kwargs["all_datetime_set"]]
    timestamps.sort()
    self._calendars_list = timestamps
    self.save_calendars(self._calendars_list)
    logger.info("end of calendars dump.\n")
|
||||
|
||||
def _dump_instruments(self):
    """Save the instrument records held in ``self._kwargs["date_range_list"]``."""
    logger.info("start dump instruments......")
    date_ranges = self._kwargs["date_range_list"]
    self.save_instruments(date_ranges)
    logger.info("end of instruments dump.\n")
|
||||
|
||||
def _dump_features(self):
    """Run ``self._dump_bin`` on every CSV file in a process pool.

    Each file is dumped against ``self._calendars_list``; the progress bar
    advances once per finished file.
    """
    logger.info("start dump features......")
    dump_one_file = partial(self._dump_bin, calendar_list=self._calendars_list)
    with tqdm(total=len(self.csv_files)) as progress, ProcessPoolExecutor(max_workers=self.works) as pool:
        for _ in pool.map(dump_one_file, self.csv_files):
            progress.update()

    logger.info("end of features dump.\n")
|
||||
|
||||
def dump_calendars(self):
|
||||
"""dump calendars
|
||||
def dump(self):
    """Run the full dump: collect dates, then write calendars, instruments and features."""
    steps = (
        self._get_all_date,
        self._dump_calendars,
        self._dump_instruments,
        self._dump_features,
    )
    for step in steps:
        step()
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
logger.info("start dump calendars......")
|
||||
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
|
||||
all_datetime = set()
|
||||
with tqdm(total=len(self.csv_files)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as executor:
|
||||
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
|
||||
all_datetime = all_datetime | set(temp_datetime)
|
||||
p_bar.update()
|
||||
|
||||
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
|
||||
self._calendars_dir.mkdir(parents=True, exist_ok=True)
|
||||
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
|
||||
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
|
||||
logger.info("end of calendars dump.\n")
|
||||
|
||||
def dump_instruments(self):
|
||||
"""dump instruments
|
||||
|
||||
Notes
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
|
||||
"""
|
||||
class DumpDataFix(DumpDataAll):
|
||||
def _dump_instruments(self):
|
||||
logger.info("start dump instruments......")
|
||||
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
|
||||
_result_list = []
|
||||
_fun = partial(self._get_date_for_df, is_begin_end=True)
|
||||
with tqdm(total=len(symbol_list)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.works) as execute:
|
||||
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
|
||||
if res:
|
||||
begin_time = res[0]
|
||||
end_time = res[-1]
|
||||
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
|
||||
_fun = partial(self._get_date, is_begin_end=True)
|
||||
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
|
||||
with tqdm(total=len(new_stock_files)) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=self.works) as execute:
|
||||
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
|
||||
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
|
||||
symbol = self.get_symbol_from_file(file_path).upper()
|
||||
_dt_map = self._old_instruments.setdefault(symbol, dict())
|
||||
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
|
||||
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
|
||||
p_bar.update()
|
||||
|
||||
self._instruments_dir.mkdir(parents=True, exist_ok=True)
|
||||
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
|
||||
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
|
||||
self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
|
||||
logger.info("end of instruments dump.\n")
|
||||
|
||||
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
|
||||
"""dump data
|
||||
def dump(self):
    """Load the existing calendar and instruments, then dump instruments and features.

    The calendar is read from ``<calendars_dir>/<freq>.txt`` and the
    instruments from ``<instruments_dir>/<INSTRUMENTS_FILE_NAME>``; the
    instruments frame is kept as a dict keyed by its index.
    """
    calendar_file = self._calendars_dir.joinpath(f"{self.freq}.txt")
    self._calendars_list = self._read_calendars(calendar_file)

    instruments_file = self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
    instruments_df = self._read_instruments(instruments_file)
    # noinspection PyAttributeOutsideInit
    self._old_instruments = instruments_df.to_dict(orient="index")  # type: dict

    self._dump_instruments()
    self._dump_features()
|
||||
|
||||
|
||||
class DumpDataUpdate(DumpDataBase):
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
qlib_dir: str,
|
||||
backup_dir: str = None,
|
||||
freq: str = "day",
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
file_suffix: str = ".csv",
|
||||
symbol_field_name: str = "symbol",
|
||||
exclude_fields: str = "",
|
||||
include_fields: str = "",
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
include_fields: str
|
||||
csv_path: str
|
||||
stock data path or directory
|
||||
qlib_dir: str
|
||||
qlib(dump) data director
|
||||
backup_dir: str, default None
|
||||
if backup_dir is not None, backup qlib_dir to backup_dir
|
||||
freq: str, default "day"
|
||||
transaction frequency
|
||||
max_workers: int, default None
|
||||
number of threads
|
||||
date_field_name: str, default "date"
|
||||
the name of the date field in the csv
|
||||
file_suffix: str, default ".csv"
|
||||
file suffix
|
||||
symbol_field_name: str, default "symbol"
|
||||
symbol field name
|
||||
include_fields: tuple
|
||||
dump fields
|
||||
|
||||
exclude_fields: str
|
||||
exclude_fields: tuple
|
||||
fields not dumped
|
||||
|
||||
Examples
|
||||
---------
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
|
||||
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
|
||||
limit_nums: int
|
||||
Use when debugging, default None
|
||||
"""
|
||||
if isinstance(exclude_fields, str):
|
||||
exclude_fields = exclude_fields.split(",")
|
||||
if isinstance(include_fields, str):
|
||||
include_fields = include_fields.split(",")
|
||||
self.dump_calendars()
|
||||
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
|
||||
self.dump_instruments()
|
||||
super().__init__(
|
||||
csv_path,
|
||||
qlib_dir,
|
||||
backup_dir,
|
||||
freq,
|
||||
max_workers,
|
||||
date_field_name,
|
||||
file_suffix,
|
||||
symbol_field_name,
|
||||
exclude_fields,
|
||||
include_fields,
|
||||
)
|
||||
self._mode = self.UPDATE_MODE
|
||||
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
|
||||
self._update_instruments = self._read_instruments(
|
||||
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
|
||||
).to_dict(
|
||||
orient="index"
|
||||
) # type: dict
|
||||
|
||||
# load all csv files
|
||||
self._all_data = self._load_all_source_data() # type: pd.DataFrame
|
||||
self._update_calendars = sorted(
|
||||
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
|
||||
)
|
||||
self._new_calendar_list = self._old_calendar_list + self._update_calendars
|
||||
|
||||
def _load_all_source_data(self):
    """Read every file in ``self.csv_files`` and concatenate them into one DataFrame.

    Files are read concurrently in a thread pool. When ``self._include_fields``
    is set, only those columns are read. If a file has no
    ``self.symbol_field_name`` column, one is added and filled with the symbol
    derived from the file path. Frames without rows are skipped.

    Returns
    -------
    pd.DataFrame
        All non-empty frames concatenated with ``sort=False``.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` when no non-empty frame was loaded.
    """
    # NOTE: Need more memory
    logger.info("start load all source data....")
    all_df = []

    def _read_csv(file_path: Path):
        if self._include_fields:
            _df = pd.read_csv(file_path, usecols=self._include_fields)
        else:
            _df = pd.read_csv(file_path)
        if self.symbol_field_name not in _df.columns:
            _df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
        return _df

    with tqdm(total=len(self.csv_files)) as p_bar:
        with ThreadPoolExecutor(max_workers=self.works) as executor:
            for df in executor.map(_read_csv, self.csv_files):
                # A DataFrame has no unambiguous truth value (``if df:`` raises
                # ValueError), so test emptiness explicitly.
                if not df.empty:
                    all_df.append(df)
                p_bar.update()

    logger.info("end of load all data.\n")
    return pd.concat(all_df, sort=False)
|
||||
|
||||
def _dump_calendars(self):
    """Do nothing; ``dump`` in this class saves the calendar through ``save_calendars``."""
    return None
|
||||
|
||||
def _dump_instruments(self):
    """Do nothing; ``dump`` in this class saves the instruments through ``save_instruments``."""
    return None
|
||||
|
||||
def _dump_features(self):
    """Dump features for every symbol in ``self._all_data`` using a process pool.

    The loaded data is grouped by ``self.symbol_field_name``. A symbol already
    present in ``self._update_instruments`` has its ``"end_time"`` replaced and
    is dumped against ``self._update_calendars``; a symbol not seen before gets
    a new ``"start_time"``/``"end_time"`` record and is dumped against
    ``self._new_calendar_list``. Symbols whose date range is not a pair of
    ``pd.Timestamp`` values are skipped.

    Exceptions raised by workers are not propagated: their tracebacks are
    collected per symbol and written to the log.
    """
    logger.info("start dump features......")
    error_code = {}
    with ProcessPoolExecutor(max_workers=self.works) as executor:
        futures = {}
        for _code, _df in self._all_data.groupby(self.symbol_field_name):
            _code = str(_code).upper()
            _start, _end = self._get_date(_df, is_begin_end=True)
            if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
                continue
            if _code in self._update_instruments:
                # NOTE(review): other dumpers in this file use
                # self.INSTRUMENTS_END_FIELD and self._format_datetime here;
                # confirm the literal key and raw Timestamp are intended.
                self._update_instruments[_code]["end_time"] = _end
                futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
            else:
                # new stock
                _dt_range = self._update_instruments.setdefault(_code, dict())
                _dt_range["start_time"] = _start
                _dt_range["end_time"] = _end
                futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code

        # as_completed() returns an iterator without a length, so tqdm needs an
        # explicit total to display progress and ETA.
        for _future in tqdm(as_completed(futures), total=len(futures)):
            try:
                _future.result()
            except Exception:
                # Keep going: record the failure against its symbol.
                error_code[futures[_future]] = traceback.format_exc()
        logger.info(f"dump bin errors: {error_code}")

    logger.info("end of features dump.\n")
|
||||
|
||||
def dump(self):
    """Save the new calendar, dump features, then save the updated instruments.

    The instruments frame is built after ``_dump_features`` has run, from
    ``self._update_instruments`` with one row per key.
    """
    self.save_calendars(self._new_calendar_list)
    self._dump_features()
    instruments_df = pd.DataFrame.from_dict(self._update_instruments, orient="index")
    self.save_instruments(instruments_df)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(DumpData)
|
||||
fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})
|
||||
|
||||
@@ -55,7 +55,7 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
|
||||
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
@@ -63,18 +63,25 @@ class GetData:
|
||||
target_dir: str
|
||||
data save directory
|
||||
name: str
|
||||
dataset name, value from [qlib_data_cn, qlib_data_cn_simple], by default qlib_data_cn
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version latest
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"{name}_{version}.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
# TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system
|
||||
if region.lower() == "us":
|
||||
logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system")
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestDataset(unittest.TestCase):
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def testCSI300(self):
|
||||
|
||||
@@ -149,7 +149,7 @@ class TestAllFlow(unittest.TestCase):
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
|
||||
GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def test_0_train(self):
|
||||
|
||||
@@ -14,7 +14,7 @@ from qlib.data import D
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
from dump_bin import DumpData
|
||||
from dump_bin import DumpDataAll, DumpDataFix
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).parent.joinpath("test_dump_data")
|
||||
@@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
GetData().csv_data_cn(SOURCE_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
|
||||
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
|
||||
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
|
||||
provider_uri = str(QLIB_DIR.resolve())
|
||||
qlib.init(
|
||||
@@ -49,8 +49,10 @@ class TestDumpData(unittest.TestCase):
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(DATA_DIR.resolve()))
|
||||
|
||||
def test_0_dump_calendars(self):
|
||||
self.DUMP_DATA.dump_calendars()
|
||||
def test_0_dump_bin(self):
    """Run the shared dumper's ``dump()``; the numbered tests that follow inspect the result."""
    dumper = self.DUMP_DATA
    dumper.dump()
|
||||
|
||||
def test_1_dump_calendars(self):
|
||||
ori_calendars = set(
|
||||
map(
|
||||
pd.Timestamp,
|
||||
@@ -60,23 +62,21 @@ class TestDumpData(unittest.TestCase):
|
||||
res_calendars = set(D.calendar())
|
||||
assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
|
||||
|
||||
def test_1_dump_instruments(self):
|
||||
self.DUMP_DATA.dump_instruments()
|
||||
def test_2_dump_instruments(self):
    """Check that the dumped instruments are exactly the source CSV symbols."""
    ori_ins = {csv_file.name[:-4].upper() for csv_file in SOURCE_DIR.glob("*.csv")}
    res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
    # Compare in both directions so that missing instruments and unexpected
    # extra instruments both fail the test.
    assert len(ori_ins - res_ins) == len(res_ins - ori_ins) == 0, "dump instruments failed"
|
||||
|
||||
def test_2_dump_features(self):
|
||||
self.DUMP_DATA.dump_features(include_fields=self.FIELDS)
|
||||
def test_3_dump_features(self):
    """Load the dumped features, keep the first stock's rows, and check data and columns."""
    features = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
    first_stock = self.STOCK_NAMES[0]
    # Stored on the class so that later tests can read it.
    TestDumpData.SIMPLE_DATA = features.loc(axis=0)[first_stock, :]
    non_null_rows = features.dropna()
    self.assertFalse(non_null_rows.empty, "features data failed")
    self.assertListEqual(list(features.columns), self.QLIB_FIELDS, "features columns failed")
|
||||
|
||||
def test_3_dump_features_simple(self):
|
||||
def test_4_dump_features_simple(self):
|
||||
stock = self.STOCK_NAMES[0]
|
||||
dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR)
|
||||
dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt"))
|
||||
dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
|
||||
dump_data.dump()
|
||||
|
||||
df = D.features([stock], self.QLIB_FIELDS)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
|
||||
|
||||
def test_0_qlib_data(self):
|
||||
|
||||
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR)
|
||||
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
|
||||
df = D.features(D.instruments("csi300"), self.FIELDS)
|
||||
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
|
||||
self.assertFalse(df.dropna().empty, "get qlib data failed")
|
||||
|
||||
Reference in New Issue
Block a user