1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 10:01:19 +08:00

Merge remote-tracking branch 'origin/main' into main

This commit is contained in:
Young
2020-11-19 08:27:42 +00:00
27 changed files with 1836 additions and 503 deletions

View File

@@ -52,7 +52,7 @@ jobs:
- name: Test data downloads and examples
run: |
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
# cd examples
# estimator -c estimator/estimator_config.yaml
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html

View File

@@ -91,7 +91,7 @@ Also, users can install ``Qlib`` by the source code according to the following s
## Data Preparation
Load and prepare data by running the following code:
```bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
```
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in

View File

@@ -34,7 +34,7 @@ Qlib Format Dataset
.. code-block:: bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory.
@@ -59,7 +59,7 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
.. code-block:: bash
python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.

View File

@@ -40,7 +40,7 @@ Load and prepare data by running the following code:
.. code-block::
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
This dataset is created by public data collected by crawler scripts in ``scripts/data_collector/``, which have been released in the same repository. Users could create the same dataset with it.

View File

@@ -14,7 +14,7 @@ Please follow the steps below to initialize ``Qlib``.
- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
.. code-block:: bash
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,

View File

@@ -41,7 +41,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
@@ -219,4 +219,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -26,7 +26,7 @@ if __name__ == "__main__":
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data_cn(target_dir=provider_uri)
GetData().qlib_data(target_dir=provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)

View File

@@ -31,13 +31,13 @@
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
@@ -335,4 +335,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

61
scripts/README.md Normal file
View File

@@ -0,0 +1,61 @@
- [Download Qlib Data](#Download-Qlib-Data)
- [Download CN Data](#Download-CN-Data)
- [Downlaod US Data](#Downlaod-US-Data)
- [Download CN Simple Data](#Download-CN-Simple-Data)
- [Help](#Help)
- [Using in Qlib](#Using-in-Qlib)
- [US data](#US-data)
- [CN data](#CN-data)
## Download Qlib Data
### Download CN Data
```bash
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
```
### Downlaod US Data
> The US stock code contains 'PRN', and the directory cannot be created on Windows system: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
```bash
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
```
### Download CN Simple Data
```bash
python get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --region cn
```
### Help
```bash
python get_data.py qlib_data --help
```
## Using in Qlib
> For more information: https://qlib.readthedocs.io/en/latest/start/initialization.html
### US data
```python
import qlib
from qlib.config import REG_US
provider_uri = "~/.qlib/qlib_data/us_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_US)
```
### CN data
```python
import qlib
from qlib.config import REG_CN
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
qlib.init(provider_uri=provider_uri, region=REG_CN)
```

144
scripts/check_dump_bin.py Normal file
View File

@@ -0,0 +1,144 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
import qlib
from qlib.data import D
import fire
import datacompy
import pandas as pd
from tqdm import tqdm
from loguru import logger
class CheckBin:
NOT_IN_FEATURES = "not in features"
COMPARE_FALSE = "compare False"
COMPARE_TRUE = "compare True"
COMPARE_ERROR = "compare error"
def __init__(
self,
qlib_dir: str,
csv_path: str,
check_fields: str = None,
freq: str = "day",
symbol_field_name: str = "symbol",
date_field_name: str = "date",
file_suffix: str = ".csv",
max_workers: int = 16,
):
"""
Parameters
----------
qlib_dir : str
qlib dir
csv_path : str
origin csv path
check_fields : str, optional
check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin
freq : str, optional
freq, value from ["day", "1m"]
symbol_field_name: str, optional
symbol field name, by default "symbol"
date_field_name: str, optional
date field name, by default "date"
file_suffix: str, optional
csv file suffix, by default ".csv"
max_workers: int, optional
max workers, by default 16
"""
self.qlib_dir = Path(qlib_dir).expanduser()
bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
qlib.init(
provider_uri=str(self.qlib_dir.resolve()),
mount_path=str(self.qlib_dir.resolve()),
auto_mount=False,
redis_port=-1,
)
csv_path = Path(csv_path).expanduser()
self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])
if check_fields is None:
check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin")))
else:
check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
self.check_fields = list(map(lambda x: x.strip(), check_fields))
self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
self.max_workers = max_workers
self.symbol_field_name = symbol_field_name
self.date_field_name = date_field_name
self.freq = freq
self.file_suffix = file_suffix
def _compare(self, file_path: Path):
symbol = file_path.name.strip(self.file_suffix)
if symbol.lower() not in self.qlib_symbols:
return self.NOT_IN_FEATURES
# qlib data
qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
# csv data
origin_df = pd.read_csv(file_path)
origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
if self.symbol_field_name not in origin_df.columns:
origin_df[self.symbol_field_name] = symbol
origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
origin_df.index.names = qlib_df.index.names
try:
compare = datacompy.Compare(
origin_df,
qlib_df,
on_index=True,
abs_tol=1e-08, # Optional, defaults to 0
rel_tol=1e-05, # Optional, defaults to 0
df1_name="Original", # Optional, defaults to 'df1'
df2_name="New", # Optional, defaults to 'df2'
)
_r = compare.matches(ignore_extra_columns=True)
return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
except Exception as e:
logger.warning(f"{symbol} compare error: {e}")
return self.COMPARE_ERROR
def check(self):
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data
"""
logger.info("start check......")
error_list = []
not_in_features = []
compare_false = []
with tqdm(total=len(self.csv_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
symbol = file_path.name.strip(self.file_suffix)
if _check_res == self.NOT_IN_FEATURES:
not_in_features.append(symbol)
elif _check_res == self.COMPARE_ERROR:
error_list.append(symbol)
elif _check_res == self.COMPARE_FALSE:
compare_false.append(symbol)
p_bar.update()
logger.info("end of check......")
if error_list:
logger.warning(f"compare error: {error_list}")
if not_in_features:
logger.warning(f"not in features: {not_in_features}")
if compare_false:
logger.warning(f"compare False: {compare_false}")
logger.info(
f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
)
if __name__ == "__main__":
fire.Fire(CheckBin)

View File

@@ -0,0 +1,22 @@
# CSI300/CSI100 History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: CSI300, CSI100
# help
python collector.py --help
```

View File

@@ -4,8 +4,9 @@
import re
import abc
import sys
import bisect
import importlib
from io import BytesIO
from typing import List
from pathlib import Path
import fire
@@ -16,7 +17,9 @@ from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import get_hs_calendar_list as get_calendar_list
from data_collector.index import IndexBase
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
@@ -24,64 +27,48 @@ NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
class CSIIndex:
REMOVE = "remove"
ADD = "add"
def __init__(self, qlib_dir=None):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
"""
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self._calendar_list = None
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
class CSIIndex(IndexBase):
@property
def calendar_list(self) -> list:
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
return get_calendar_list(bench_code=self.index_name.upper())
@property
def new_companies_url(self):
def new_companies_url(self) -> str:
return NEW_COMPANIES_URL.format(index_code=self.index_code)
@property
def changes_url(self):
def changes_url(self) -> str:
return INDEX_CHANGES_URL
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
raise NotImplementedError()
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def index_code(self):
raise NotImplementedError()
def index_code(self) -> str:
"""
Returns
-------
index code
"""
raise NotImplementedError("rewrite index_code")
@property
@abc.abstractmethod
def index_name(self):
raise NotImplementedError()
@property
@abc.abstractmethod
def html_table_index(self):
def html_table_index(self) -> int:
"""Which table of changes in html
CSI300: 0
@@ -90,33 +77,19 @@ class CSIIndex:
"""
raise NotImplementedError()
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
"""get trading date by shift
Parameters
----------
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
"""
left_index = bisect.bisect_left(self.calendar_list, trading_date)
try:
res = self.calendar_list[left_index + shift]
except IndexError:
res = trading_date
return res
def _get_changes(self) -> pd.DataFrame:
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
logger.info("get companies changes......")
res = []
@@ -124,10 +97,21 @@ class CSIIndex:
_df = self._read_change_from_url(_url)
res.append(_df)
logger.info("get companies changes finish")
return pd.concat(res)
return pd.concat(res, sort=False)
@staticmethod
def normalize_symbol(symbol):
def normalize_symbol(symbol: str) -> str:
"""
Parameters
----------
symbol: str
symbol
Returns
-------
symbol
"""
symbol = f"{int(symbol):06}"
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
@@ -141,7 +125,14 @@ class CSIIndex:
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
resp = requests.get(url)
_text = resp.text
@@ -151,8 +142,8 @@ class CSIIndex:
add_date = pd.Timestamp("-".join(date_list[0]))
else:
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
add_date = self._get_trading_date_by_shift(_date, shift=0)
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
logger.info(f"get {add_date} changes")
try:
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
@@ -168,12 +159,12 @@ class CSIIndex:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = ["symbol"]
_df.columns = [self.SYMBOL_FIELD_NAME]
_df["type"] = _type
_df["date"] = _date
_df[self.DATE_FIELD_NAME] = _date
tmp.append(_df)
df = pd.concat(tmp)
except Exception:
except Exception as e:
df = None
_tmp_count = 0
for _df in pd.read_html(resp.content):
@@ -188,9 +179,9 @@ class CSIIndex:
(_df.iloc[2:, 2], self.ADD, add_date),
]:
_tmp_df = pd.DataFrame()
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
_tmp_df["type"] = _type
_tmp_df["date"] = _date
_tmp_df[self.DATE_FIELD_NAME] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
df.to_csv(
@@ -203,20 +194,33 @@ class CSIIndex:
break
return df
def _get_change_notices_url(self) -> list:
def _get_change_notices_url(self) -> List[str]:
"""get change notices url
Returns
-------
[url1, url2]
"""
resp = requests.get(self.changes_url)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")
def _get_new_companies(self):
def get_new_companies(self) -> pd.DataFrame:
"""
logger.info("get new companies")
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
logger.info("get new companies......")
context = requests.get(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
@@ -225,51 +229,19 @@ class CSIIndex:
_io = BytesIO(context)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = ["end_date", "symbol"]
df["symbol"] = df["symbol"].map(self.normalize_symbol)
df["end_date"] = pd.to_datetime(df["end_date"])
df["start_date"] = self.bench_start_date
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
df[self.START_DATE_FIELD] = self.bench_start_date
logger.info("end of get new companies.")
return df
def parse_instruments(self):
"""parse csi300.txt
Examples
-------
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = ["symbol", "start_date", "end_date"]
changers_df = self._get_changes()
new_df = self._get_new_companies()
logger.info("parse history companies by changes......")
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
if _row.type == self.ADD:
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
new_df.loc[
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
] = _row.date
else:
_tmp_df = pd.DataFrame(
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info(f"parse {self.index_name.lower()} companies finished.")
class CSI300(CSIIndex):
@property
def index_code(self):
return "000300"
@property
def index_name(self):
return "csi300"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2005-01-01")
@@ -284,10 +256,6 @@ class CSI100(CSIIndex):
def index_code(self):
return "000903"
@property
def index_name(self):
return "csi100"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2006-05-29")
@@ -297,19 +265,39 @@ class CSI100(CSIIndex):
return 1
def parse_instruments(qlib_dir: str):
def get_instruments(
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
index_name: str
index name, value from ["csi100", "csi300"]
method: str
method, value from ["parse_instruments", "save_new_companies"]
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
Examples
-------
# parse instruments
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
qlib_dir.mkdir(exist_ok=True, parents=True)
CSI300(qlib_dir).parse_instruments()
CSI100(qlib_dir).parse_instruments()
_cur_module = importlib.import_module("collector")
obj = getattr(_cur_module, f"{index_name.upper()}")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)
getattr(obj, method)()
if __name__ == "__main__":
fire.Fire(parse_instruments)
fire.Fire(get_instruments)

View File

@@ -1,14 +0,0 @@
# CSI300 History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
```

View File

@@ -0,0 +1,202 @@
import sys
import abc
from pathlib import Path
from typing import List
import pandas as pd
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent))
from data_collector.utils import get_trading_date_by_shift
class IndexBase:
DEFAULT_END_DATE = pd.Timestamp("2099-12-31")
SYMBOL_FIELD_NAME = "symbol"
DATE_FIELD_NAME = "date"
START_DATE_FIELD = "start_date"
END_DATE_FIELD = "end_ate"
CHANGE_TYPE_FIELD = "type"
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
REMOVE = "remove"
ADD = "add"
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
"""
Parameters
----------
index_name: str
index name
qlib_dir: str
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
"""
self.index_name = index_name
if qlib_dir is None:
qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data")
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
self._request_retry = request_retry
self._retry_sleep = retry_sleep
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
raise NotImplementedError("rewrite calendar_list")
@abc.abstractmethod
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
raise NotImplementedError("rewrite get_new_companies")
@abc.abstractmethod
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
raise NotImplementedError("rewrite get_changes")
def save_new_companies(self):
"""save new companies
Examples
-------
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
"""
df = self.get_new_companies()
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
)
def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:
"""get changes with history companies
Parameters
----------
history_companies : pd.DataFrame
symbol date
SH600000 2020-11-11
dtypes:
symbol: str
date: pd.Timestamp
Return
--------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
logger.info("parse changes from history companies......")
last_code = []
result_df_list = []
_columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]
for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):
_currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][
self.SYMBOL_FIELD_NAME
].tolist()
if last_code:
add_code = list(set(last_code) - set(_currenet_code))
remote_code = list(set(_currenet_code) - set(last_code))
for _code in add_code:
result_df_list.append(
pd.DataFrame(
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],
columns=_columns,
)
)
for _code in remote_code:
result_df_list.append(
pd.DataFrame(
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],
columns=_columns,
)
)
last_code = _currenet_code
df = pd.concat(result_df_list)
logger.info("end of parse changes from history companies.")
return df
def parse_instruments(self):
"""parse instruments, eg: csi300.txt
Examples
-------
$ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
changers_df = self.get_changes()
new_df = self.get_new_companies().copy()
logger.info("parse history companies by changes......")
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
if _row.type == self.ADD:
min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()
new_df.loc[
(new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol),
self.START_DATE_FIELD,
] = _row.date
else:
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info(f"parse {self.index_name.lower()} companies finished.")

View File

@@ -0,0 +1,22 @@
# NASDAQ100/SP500/SP400/DJIA History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: SP500, NASDAQ100, DJIA, SP400
# help
python collector.py --help
```

View File

@@ -0,0 +1,278 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import sys
import importlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List
import fire
import requests
import pandas as pd
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.index import IndexBase
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
WIKI_URL = "https://en.wikipedia.org/wiki"
WIKI_INDEX_NAME_MAP = {
"NASDAQ100": "NASDAQ-100",
"SP500": "List_of_S%26P_500_companies",
"SP400": "List_of_S%26P_400_companies",
"DJIA": "Dow_Jones_Industrial_Average",
}
class WIKIIndex(IndexBase):
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
super(WIKIIndex, self).__init__(
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
)
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@abc.abstractmethod
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
raise NotImplementedError("rewrite get_changes")
@property
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
_calendar_list = getattr(self, "_calendar_list", None)
if _calendar_list is None:
_calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL")))
setattr(self, "_calendar_list", _calendar_list)
return _calendar_list
def _request_new_companies(self) -> requests.Response:
resp = requests.get(self._target_url)
if resp.status_code != 200:
raise ValueError(f"request error: {self._target_url}")
return resp
def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
_df = df.copy()
_df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()
_df[self.START_DATE_FIELD] = self.bench_start_date
_df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
return _df.loc[:, self.INSTRUMENTS_COLUMNS]
def get_new_companies(self):
logger.info(f"get new companies {self.index_name} ......")
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
df_list = pd.read_html(_data.text)
for _df in df_list:
_df = self.filter_df(_df)
if (_df is not None) and (not _df.empty):
_df.columns = [self.SYMBOL_FIELD_NAME]
_df = self.set_default_date_range(_df)
logger.info(f"end of get new companies {self.index_name} ......")
return _df
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
raise NotImplementedError("rewrite filter_df")
class NASDAQ100Index(WIKIIndex):
HISTORY_COMPANIES_URL = (
"https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
)
MAX_WORKERS = 16
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if not (set(df.columns) - {"Company", "Ticker"}):
return df.loc[:, ["Ticker"]].copy()
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2003-01-02")
@deco_retry
def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
trade_date = trade_date.strftime("%Y-%m-%d")
cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl")
if cache_path.exists() and use_cache:
df = pd.read_pickle(cache_path)
else:
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
resp = requests.post(url)
if resp.status_code != 200:
raise ValueError(f"request error: {url}")
df = pd.DataFrame(resp.json()["aaData"])
df[self.DATE_FIELD_NAME] = trade_date
df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
if not df.empty:
df.to_pickle(cache_path)
return df
def get_history_companies(self):
logger.info(f"start get history companies......")
all_history = []
error_list = []
with tqdm(total=len(self.calendar_list)) as p_bar:
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
for _trading_date, _df in zip(
self.calendar_list, executor.map(self._request_history_companies, self.calendar_list)
):
if _df.empty:
error_list.append(_trading_date)
else:
all_history.append(_df)
p_bar.update()
if error_list:
logger.warning(f"get error: {error_list}")
logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}")
logger.info(f"end of get history companies.")
return pd.concat(all_history, sort=False)
def get_changes(self):
return self.get_changes_with_history_companies(self.get_history_companies())
class DJIAIndex(WIKIIndex):
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2000-01-01")
def get_changes(self) -> pd.DataFrame:
pass
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Symbol" in df.columns:
_df = df.loc[:, ["Symbol"]].copy()
_df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1])
return _df
def parse_instruments(self):
logger.warning(f"No suitable data source has been found!")
class SP500Index(WIKIIndex):
WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("1999-01-01")
def get_changes(self) -> pd.DataFrame:
logger.info(f"get sp500 history changes......")
# NOTE: may update the index of the table
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
changes_df = changes_df.iloc[:, [0, 1, 3]]
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
_result = []
for _type in [self.ADD, self.REMOVE]:
_df = changes_df.copy()
_df[self.CHANGE_TYPE_FIELD] = _type
_df[self.SYMBOL_FIELD_NAME] = _df[_type]
_df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
if _type == self.ADD:
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
lambda x: get_trading_date_by_shift(self.calendar_list, x, 0)
)
else:
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
lambda x: get_trading_date_by_shift(self.calendar_list, x, -1)
)
_result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
logger.info(f"end of get sp500 history changes.")
return pd.concat(_result, sort=False)
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Symbol" in df.columns:
return df.loc[:, ["Symbol"]].copy()
class SP400Index(WIKIIndex):
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2000-01-01")
def get_changes(self) -> pd.DataFrame:
pass
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Ticker symbol" in df.columns:
return df.loc[:, ["Ticker symbol"]].copy()
def parse_instruments(self):
logger.warning(f"No suitable data source has been found!")
def get_instruments(
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
index_name: str
index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
method: str
method, value from ["parse_instruments", "save_new_companies"]
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
Examples
-------
# parse instruments
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("collector")
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)
getattr(obj, method)()
if __name__ == "__main__":
fire.Fire(get_instruments)

View File

@@ -0,0 +1,6 @@
logure
fire
requests
pandas
lxml
loguru

View File

@@ -3,56 +3,69 @@
import re
import time
import bisect
import pickle
import requests
import functools
from pathlib import Path
import pandas as pd
from lxml import etree
from loguru import logger
from yahooquery import Ticker
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
# NOTE: Use the time series of SH600000 as the sequence of all stocks
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
"US_ALL": "^GSPC",
}
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_US_SYMBOLS = None
_CALENDAR_MAP = {}
# NOTE: Until 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900
def get_hs_calendar_list(bench_code="CSI300") -> list:
def get_calendar_list(bench_code="CSI300") -> list:
"""get SH/SZ history calendar list
Parameters
----------
bench_code: str
value from ["CSI300", "CSI500", "ALL"]
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
Returns
-------
history calendar list
"""
logger.info(f"get calendar list: {bench_code}......")
def _get_calendar(url):
_value_list = requests.get(url).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
if bench_code.startswith("US_"):
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
else:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
_CALENDAR_MAP[bench_code] = calendar
logger.info(f"end of get calendar list: {bench_code}.")
return calendar
@@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list:
def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
time.sleep(3)
return _res
if _HS_SYMBOLS is None:
@@ -99,6 +113,84 @@ def get_hs_stock_symbols() -> list:
return _HS_SYMBOLS
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get US stock symbols
Returns
-------
stock symbols
"""
global _US_SYMBOLS
@deco_retry
def _get_eastmoney():
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
resp = requests.get(url)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
if len(_symbols) < 8000:
raise ValueError("request error")
return _symbols
@deco_retry
def _get_nasdaq():
_res_symbols = []
for _name in ["otherlisted", "nasdaqtraded"]:
url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt"
df = pd.read_csv(url, sep="|")
df = df.rename(columns={"ACT Symbol": "Symbol"})
_symbols = df["Symbol"].dropna()
_symbols = _symbols.str.replace("$", "-P", regex=False)
_symbols = _symbols.str.replace(".W", "-WT", regex=False)
_symbols = _symbols.str.replace(".U", "-UN", regex=False)
_symbols = _symbols.str.replace(".R", "-RI", regex=False)
_symbols = _symbols.str.replace(".", "-", regex=False)
_res_symbols += _symbols.unique().tolist()
return _res_symbols
@deco_retry
def _get_nyse():
url = "https://www.nyse.com/api/quotes/filter"
_parms = {
"instrumentType": "EQUITY",
"pageNumber": 1,
"sortColumn": "NORMALIZED_TICKER",
"sortOrder": "ASC",
"maxResultsPerPage": 10000,
"filterToken": "",
}
resp = requests.post(url, json=_parms)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
except Exception as e:
logger.warning(f"request error: {e}")
_symbols = []
return _symbols
if _US_SYMBOLS is None:
_all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()
if qlib_data_path is not None:
for _index in ["nasdaq100", "sp500"]:
ins_df = pd.read_csv(
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
sep="\t",
names=["symbol", "start_date", "end_date"],
)
_all_symbols += ins_df["symbol"].unique().tolist()
_US_SYMBOLS = sorted(
set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))
)
return _US_SYMBOLS
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
"""symbol suffix to prefix
@@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
return res.upper() if capital else res.lower()
def deco_retry(retry: int = 5, retry_sleep: int = 3):
def deco_func(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_retry = 5 if callable(retry) else retry
_result = None
for _i in range(1, _retry + 1):
try:
_result = func(*args, **kwargs)
break
except Exception as e:
logger.warning(f"{func.__name__}: {_i} :{e}")
if _i == _retry:
raise
time.sleep(retry_sleep)
return _result
return wrapper
return deco_func(retry) if callable(retry) else deco_func
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
"""get trading date by shift
Parameters
----------
trading_list: list
trading calendar list
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
"""
trading_date = pd.Timestamp(trading_date)
left_index = bisect.bisect_left(trading_list, trading_date)
try:
res = trading_list[left_index + shift]
except IndexError:
res = trading_date
return res
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -18,31 +18,29 @@ pip install -r requirements.txt
## Collector Data
### Download data -> Normalize data -> Dump data
### Download data and Normalize data
```bash
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
```
### Download Data From Yahoo Finance
### Download Data
```bash
python collector.py download_data --source_dir ~/.qlib/stock_data/source
python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
```
### Normalize Yahoo Finance Data
### Normalize Data
```bash
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
```
### Manual Ajust Yahoo Finance Data
### Help
```bash
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
pythono collector.py collector_data --help
```
### Dump Yahoo Finance Data
## Parameters
```bash
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
```
- interval: 1m or 1d
- region: CN or US

View File

@@ -1,8 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import sys
import copy
import time
import datetime
import importlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -13,33 +17,103 @@ import pandas as pd
from tqdm import tqdm
from loguru import logger
from yahooquery import Ticker
from dateutil.tz import tzlocal
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from dump_bin import DumpData
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
MIN_NUMBERS_TRADING = 252 / 4
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
REGION_CN = "CN"
REGION_US = "US"
class YahooCollector:
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0):
START_DATETIME = pd.Timestamp("2000-01-01")
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5))
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="1d",
max_workers=4,
max_collector_count=5,
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
):
"""
Parameters
----------
save_dir: str
stock save dir
max_workers: int
workers, default 4
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
self._stock_list = None
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
self.stock_list = self.stock_list[: int(limit_nums)]
except Exception as e:
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
self.max_workers = max_workers
self._asynchronous = asynchronous
self._max_collector_count = max_collector_count
self._mini_symbol_map = {}
self._interval = interval
self._check_small_data = check_data_length
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
if self._interval == "1m":
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == "1d":
self._start_datetime = max(self._start_datetime, self.START_DATETIME)
else:
raise ValueError(f"interval error: {self._interval}")
self._start_datetime = self.convert_datetime(self._start_datetime)
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
@property
def stock_list(self):
if self._stock_list is None:
self._stock_list = get_hs_stock_symbols()
return self._stock_list
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewirte min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewirte get_stock_list")
@property
@abc.abstractclassmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
def convert_datetime(self, dt: pd.Timestamp):
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
def _sleep(self):
time.sleep(self._delay)
@@ -57,63 +131,95 @@ class YahooCollector:
if df.empty:
raise ValueError("df is empty")
symbol_s = symbol.split(".")
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
symbol = self.normalize_symbol(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
df.to_csv(stock_path, index=False)
if stock_path.exists():
with stock_path.open("a") as fp:
df.to_csv(fp, index=False, header=None)
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
def _temp_save_small_data(self, symbol, df):
if len(df) <= MIN_NUMBERS_TRADING:
logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!")
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
_temp = self._mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return None
else:
if symbol in self._mini_symbol_map:
self._mini_symbol_map.pop(symbol)
return symbol
def _get_from_remote(self, symbol):
def _get_simple(start_, end_):
self._sleep()
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
else:
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
except Exception as e:
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
_result = None
if self._interval == "1d":
_result = _get_simple(self._start_datetime, self._end_datetime)
elif self._interval == "1m":
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
_end_date = self._end_datetime.date()
if _start_date >= _end_date:
_result = _get_simple(self._start_datetime, self._end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None:
_res.append(_resp)
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
_get_multi(_s, _e)
for _start in pd.date_range(_start_date, _end_date, closed="left"):
_end = _start + pd.Timedelta(days=1)
self._sleep()
_get_multi(_start, _end)
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
raise ValueError(f"cannot support {self._interval}")
return _result
def _get_data(self, symbol):
_result = None
df = self._get_from_remote(symbol)
if isinstance(df, pd.DataFrame):
if not df.empty:
if self._check_small_data:
if self._save_small_data(symbol, df) is not None:
_result = symbol
self.save_stock(symbol, df)
else:
_result = symbol
self.save_stock(symbol, df)
return _result
def _collector(self, stock_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
futures = {}
p_bar = tqdm(total=len(stock_list))
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
self._sleep()
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
period="max"
)
if isinstance(resp, dict):
for symbol, df in resp.items():
if isinstance(df, pd.DataFrame):
self._temp_save_small_data(self, df)
futures[
worker.submit(
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
)
] = symbol
else:
error_symbol.append(symbol)
else:
for symbol, df in resp.reset_index().groupby("symbol"):
self._temp_save_small_data(self, df)
futures[worker.submit(self.save_stock, symbol, df)] = symbol
p_bar.update(self.max_workers)
p_bar.close()
with tqdm(total=len(futures.values())) as p_bar:
for future in as_completed(futures):
try:
future.result()
except Exception as e:
logger.error(e)
error_symbol.append(futures[future])
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with tqdm(total=len(stock_list)) as p_bar:
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
if _result is None:
error_symbol.append(_symbol)
p_bar.update()
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
error_symbol.extend(self._mini_symbol_map.keys())
return error_symbol
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
@@ -126,81 +232,140 @@ class YahooCollector:
stock_list = self._collector(stock_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self._mini_symbol_map.items():
self.save_stock(_symbol, max(_df_list, key=len))
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
if self._mini_symbol_map:
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
self.download_index_data()
@abc.abstractmethod
def download_index_data(self):
"""download index data"""
raise NotImplementedError("rewrite download_index_data")
@abc.abstractmethod
def normalize_symbol(self, symbol: str):
"""normalize symbol"""
raise NotImplementedError("rewrite normalize_symbol")
class YahooCollectorCN(YahooCollector):
@property
def min_numbers_trading(self):
if self._interval == "1m":
return 60 * 4 * 5
elif self._interval == "1d":
return 252 / 4
def get_stock_list(self):
logger.info("get HS stock symbos......")
symbols = get_hs_stock_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def download_index_data(self):
# TODO: from MSN
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
)
)
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
# FIXME: 1m
if self._interval == "1d":
_format = "%Y%m%d"
_begin = self._start_datetime.strftime(_format)
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
try:
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
"data"
]["klines"],
)
)
except Exception as e:
logger.warning(f"get {_index_name} error: {e}")
continue
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
else:
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: downlaod_index_data")
def normalize_symbol(self, symbol):
symbol_s = symbol.split(".")
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
return symbol
@property
def _timezone(self):
return "Asia/Shanghai"
class Run:
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
class YahooCollectorUS(YahooCollector):
@property
def min_numbers_trading(self):
if self._interval == "1m":
return 60 * 6.5 * 5
elif self._interval == "1d":
return 252 / 4
def get_stock_list(self):
logger.info("get US stock symbols......")
symbols = get_us_stock_symbols() + [
"^GSPC",
"^NDX",
"^DJI",
]
logger.info(f"get {len(symbols)} symbols.")
return symbols
def download_index_data(self):
pass
def normalize_symbol(self, symbol):
return symbol.upper()
@property
def _timezone(self):
return "America/New_York"
class YahooNormalize:
COLUMNS = ["open", "close", "high", "low", "volume"]
def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
qlib_dir: str
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 16
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
self.qlib_dir.mkdir(parents=True, exist_ok=True)
self.max_workers = max_workers
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._max_workers = max_workers
self._calendar_list = self._get_calendar_list()
def normalize_data(self):
"""normalize data
logger.info("normalize data......")
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
"""
def _normalize(file_path: Path):
columns = ["open", "close", "high", "low", "volume"]
df = pd.read_csv(file_path)
def _normalize(source_path: Path):
columns = copy.deepcopy(self.COLUMNS)
df = pd.read_csv(source_path)
df.set_index("date", inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
# using China stock market data calendar
df = df.reindex(pd.Index(get_calendar_list("ALL")))
if self._calendar_list is not None:
df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
df["factor"] = df["adjclose"] / df["close"]
for _col in columns:
@@ -213,22 +378,17 @@ class Run:
columns += ["change", "factor"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
df.index.names = ["date"]
df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name))
df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name))
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.source_dir.glob("*.csv"))
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_normalize, file_list):
p_bar.update()
def manual_adj_data(self):
"""manual adjust data
Examples
--------
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
"""
"""adjust data"""
logger.info("manual adjust data......")
def _adj(file_path: Path):
df = pd.read_csv(file_path)
@@ -244,59 +404,166 @@ class Run:
df[_col] = df[_col] / _close
else:
pass
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False)
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.normalize_dir.glob("*.csv"))
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._target_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_adj, file_list):
p_bar.update()
def dump_data(self):
"""dump yahoo data
Examples
---------
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
"""
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
include_fields="close,open,high,low,volume,change,factor"
)
def download_data(self, asynchronous=False, max_collector_count=5, delay=0):
"""download data from Internet
Examples
---------
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
"""
YahooCollector(
self.source_dir,
max_workers=self.max_workers,
asynchronous=asynchronous,
max_collector_count=max_collector_count,
delay=delay,
).collector_data()
def download_index_data(self):
YahooCollector(self.source_dir).download_index_data()
def download_bench_data(self):
"""download bench stock data(SH000300)"""
def collector_data(self):
"""download -> normalize -> dump data
Examples
-------
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
"""
self.download_data()
def normalize(self):
self.normalize_data()
self.manual_adj_data()
self.dump_data()
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
class YahooNormalizeUS(YahooNormalize):
def _get_calendar_list(self):
# TODO: from MSN
return get_calendar_list("US_ALL")
class YahooNormalizeCN(YahooNormalize):
def _get_calendar_list(self):
# TODO: from MSN
return get_calendar_list("ALL")
class Run:
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
region: str
region, value from ["CN", "US"], default "CN"
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
self.region = region
def download_data(
self,
max_collector_count=5,
delay=0,
start=None,
end=None,
interval="1d",
check_data_length=False,
limit_nums=None,
):
"""download data from Internet
Parameters
----------
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
Examples
---------
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
_class = getattr(self._cur_module, f"YahooCollector{self.region.upper()}")
_class(
self.source_dir,
max_workers=self.max_workers,
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
).collector_data()
def normalize_data(self):
"""normalize data
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
"""
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}")
_class(self.source_dir, self.normalize_dir, self.max_workers).normalize()
def collector_data(
self,
max_collector_count=5,
delay=0,
start=None,
end=None,
interval="1d",
check_data_length=False,
limit_nums=None,
):
"""download -> normalize
Parameters
----------
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
limit_nums: int
using for debug, by default None
Examples
-------
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
"""
self.download_data(
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
)
self.normalize_data()
if __name__ == "__main__":

View File

@@ -1,10 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import shutil
import traceback
from pathlib import Path
from typing import Iterable, List, Union
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor
import fire
import numpy as np
@@ -13,8 +16,20 @@ from tqdm import tqdm
from loguru import logger
class DumpData(object):
FILE_SUFFIX = ".csv"
class DumpDataBase:
INSTRUMENTS_START_FIELD = "start_datetime"
INSTRUMENTS_END_FIELD = "end_datetime"
CALENDARS_DIR_NAME = "calendars"
FEATURES_DIR_NAME = "features"
INSTRUMENTS_DIR_NAME = "instruments"
DUMP_FILE_SUFFIX = ".bin"
DAILY_FORMAT = "%Y-%m-%d"
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
UPDATE_MODE = "update"
ALL_MODE = "all"
def __init__(
self,
@@ -22,8 +37,13 @@ class DumpData(object):
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
works: int = None,
max_workers: int = 16,
date_field_name: str = "date",
file_suffix: str = ".csv",
symbol_field_name: str = "symbol",
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
):
"""
@@ -37,80 +57,101 @@ class DumpData(object):
if backup_dir is not None, backup qlib_dir to backup_dir
freq: str, default "day"
transaction frequency
works: int, default None
max_workers: int, default None
number of threads
date_field_name: str, default "date"
the name of the date field in the csv
file_suffix: str, default ".csv"
file suffix
symbol_field_name: str, default "symbol"
symbol field name
include_fields: tuple
dump fields
exclude_fields: tuple
fields not dumped
limit_nums: int
Use when debugging, default None
"""
csv_path = Path(csv_path).expanduser()
self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path])
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
if limit_nums is not None:
self.csv_files = self.csv_files[: int(limit_nums)]
self.qlib_dir = Path(qlib_dir).expanduser()
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
if backup_dir is not None:
self._backup_qlib_dir(Path(backup_dir).expanduser())
self.freq = freq
self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S"
self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT
self.works = works
self.works = max_workers
self.date_field_name = date_field_name
self._calendars_dir = self.qlib_dir.joinpath("calendars")
self._features_dir = self.qlib_dir.joinpath("features")
self._instruments_dir = self.qlib_dir.joinpath("instruments")
self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME)
self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME)
self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME)
self._calendars_list = []
self._include_fields = ()
self._exclude_fields = ()
self._mode = self.ALL_MODE
self._kwargs = {}
def _backup_qlib_dir(self, target_dir: Path):
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve()))
def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False):
df = pd.read_csv(str(file_path.resolve()))
if df.empty or self.date_field_name not in df.columns.tolist():
return []
if is_begin_end:
return [df[self.date_field_name].min(), df[self.date_field_name].max()]
return df[self.date_field_name].tolist()
def _format_datetime(self, datetime_d: [str, pd.Timestamp]):
datetime_d = pd.Timestamp(datetime_d)
return datetime_d.strftime(self.calendar_format)
def _get_source_data(self, file_path: Path):
df = pd.read_csv(str(file_path.resolve()))
def _get_date(
self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False
) -> Iterable[pd.Timestamp]:
if not isinstance(file_or_df, pd.DataFrame):
df = self._get_source_data(file_or_df)
else:
df = file_or_df
if df.empty or self.date_field_name not in df.columns.tolist():
_calendars = pd.Series()
else:
_calendars = df[self.date_field_name]
if is_begin_end and as_set:
return (_calendars.min(), _calendars.max()), set(_calendars)
elif is_begin_end:
return _calendars.min(), _calendars.max()
elif as_set:
return set(_calendars)
else:
return _calendars.tolist()
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
# df.drop_duplicates([self.date_field_name], inplace=True)
return df
def _file_to_bin(self, file_path: Path = None):
code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower()
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
# read csv file
df = self._get_source_data(file_path)
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
]
cal_df.set_index(self.date_field_name, inplace=True)
df.set_index(self.date_field_name, inplace=True)
r_df = df.reindex(cal_df.index)
date_index = self._calendars_list.index(r_df.index.min())
for field in (
def get_symbol_from_file(self, file_path: Path) -> str:
return file_path.name[: -len(self.file_suffix)].strip().lower()
def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
self._include_fields
if self._include_fields
else set(r_df.columns) - set(self._exclude_fields)
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else r_df.columns
):
bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin")
if field not in r_df.columns:
continue
r = np.hstack([date_index, r_df[field]]).astype("<f")
r.tofile(str(bin_path.resolve()))
else df_columns
)
@staticmethod
def _read_calendar(calendar_path: Path):
def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]:
return sorted(
map(
pd.Timestamp,
@@ -118,133 +159,305 @@ class DumpData(object):
)
)
def dump_features(
self,
calendar_path: str = None,
include_fields: tuple = None,
exclude_fields: tuple = None,
):
"""dump features
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
return pd.read_csv(
instrument_path,
sep=self.INSTRUMENTS_SEP,
names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
)
Parameters
---------
calendar_path: str
calendar path
def save_calendars(self, calendars_data: list):
self._calendars_dir.mkdir(parents=True, exist_ok=True)
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data))
np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8")
include_fields: str
dump fields
def save_instruments(self, instruments_data: Union[list, pd.DataFrame]):
self._instruments_dir.mkdir(parents=True, exist_ok=True)
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
exclude_fields: str
fields not dumped
def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame:
# calendars
calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name])
calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64)
cal_df = calendars_df[
(calendars_df[self.date_field_name] >= df[self.date_field_name].min())
& (calendars_df[self.date_field_name] <= df[self.date_field_name].max())
]
# align index
cal_df.set_index(self.date_field_name, inplace=True)
df.set_index(self.date_field_name, inplace=True)
r_df = df.reindex(cal_df.index)
return r_df
Notes
---------
python dump_bin.py dump_features --csv_path <stock data directory or path> --qlib_dir <dump data directory>
@staticmethod
def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int:
return calendar_list.index(df.index.min())
Examples
---------
def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path):
if df.empty:
logger.warning(f"{features_dir.name} data is None or empty")
return
# align index
_df = self.data_merge_calendar(df, self._calendars_list)
date_index = self.get_datetime_index(_df, calendar_list)
for field in self.get_dump_fields(_df.columns):
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
if field not in _df.columns:
continue
if self._mode == self.UPDATE_MODE:
# update
with bin_path.open("ab") as fp:
np.array(_df[field]).astype("<f").tofile(fp)
elif self._mode == self.ALL_MODE:
np.hstack([date_index, _df[field]]).astype("<f").tofile(str(bin_path.resolve()))
else:
raise ValueError(f"{self._mode} cannot support!")
# dump all stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
# dump one stock
python dump_bin.py dump_features --csv_path ~/tmp/stock_data/sh600000.csv --qlib_dir ~/tmp/qlib_data --calendar_path ~/tmp/qlib_data/calendar/all.txt --exclude_fields date,code,timestamp,code_name
"""
logger.info("start dump features......")
if calendar_path is not None:
# read calendar from calendar file
self._calendars_list = self._read_calendar(Path(calendar_path))
def _dump_bin(self, file_or_data: [Path, pd.DataFrame], calendar_list: List[pd.Timestamp]):
if isinstance(file_or_data, pd.DataFrame):
if file_or_data.empty:
return
code = file_or_data.iloc[0][self.symbol_field_name].lower()
df = file_or_data
elif isinstance(file_or_data, Path):
code = self.get_symbol_from_file(file_or_data)
df = self._get_source_data(file_or_data)
else:
raise ValueError(f"not support {type(file_or_data)}")
if df is None or df.empty:
logger.warning(f"{code} data is None or empty")
return
# features save dir
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
if not self._calendars_list:
self.dump_calendars()
@abc.abstractmethod
def dump(self):
raise NotImplementedError("dump not implemented!")
self._include_fields = tuple(map(str.strip, include_fields)) if include_fields else self._include_fields
self._exclude_fields = tuple(map(str.strip, exclude_fields)) if exclude_fields else self._exclude_fields
def __call__(self, *args, **kwargs):
self.dump()
class DumpDataAll(DumpDataBase):
def _get_all_date(self):
logger.info("start get all date......")
all_datetime = set()
date_range_list = []
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(self._file_to_bin, self.csv_files):
with ProcessPoolExecutor(max_workers=self.works) as executor:
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
self.csv_files, executor.map(_fun, self.csv_files)
):
all_datetime = all_datetime | _set_calendars
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
_begin_time = self._format_datetime(_begin_time)
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
self._kwargs["date_range_list"] = date_range_list
logger.info("end of get all date.\n")
def _dump_calendars(self):
logger.info("start dump calendars......")
self._calendars_list = sorted(map(pd.Timestamp, self._kwargs["all_datetime_set"]))
self.save_calendars(self._calendars_list)
logger.info("end of calendars dump.\n")
def _dump_instruments(self):
logger.info("start dump instruments......")
self.save_instruments(self._kwargs["date_range_list"])
logger.info("end of instruments dump.\n")
def _dump_features(self):
logger.info("start dump features......")
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
with tqdm(total=len(self.csv_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(_dump_func, self.csv_files):
p_bar.update()
logger.info("end of features dump.\n")
def dump_calendars(self):
"""dump calendars
def dump(self):
self._get_all_date()
self._dump_calendars()
self._dump_instruments()
self._dump_features()
Notes
---------
python dump_bin.py dump_calendars --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
logger.info("start dump calendars......")
calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
all_datetime = set()
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for temp_datetime in executor.map(self._get_date_for_df, self.csv_files):
all_datetime = all_datetime | set(temp_datetime)
p_bar.update()
self._calendars_list = sorted(map(pd.Timestamp, all_datetime))
self._calendars_dir.mkdir(parents=True, exist_ok=True)
result_calendar_list = list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list))
np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8")
logger.info("end of calendars dump.\n")
def dump_instruments(self):
"""dump instruments
Notes
---------
python dump_bin.py dump_instruments --csv_path <stock data directory or path> --qlib_dir <dump data directory>
Examples
---------
python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data
"""
class DumpDataFix(DumpDataAll):
def _dump_instruments(self):
logger.info("start dump instruments......")
symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files))
_result_list = []
_fun = partial(self._get_date_for_df, is_begin_end=True)
with tqdm(total=len(symbol_list)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as execute:
for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)):
if res:
begin_time = res[0]
end_time = res[-1]
_result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}")
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files))
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = self.get_symbol_from_file(file_path).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
p_bar.update()
self._instruments_dir.mkdir(parents=True, exist_ok=True)
to_path = str(self._instruments_dir.joinpath("all.txt").resolve())
np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8")
self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
logger.info("end of instruments dump.\n")
def dump(self, include_fields: str = None, exclude_fields: tuple = None):
"""dump data
def dump(self):
self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
# noinspection PyAttributeOutsideInit
self._old_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
) # type: dict
self._dump_instruments()
self._dump_features()
class DumpDataUpdate(DumpDataBase):
def __init__(
self,
csv_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
max_workers: int = 16,
date_field_name: str = "date",
file_suffix: str = ".csv",
symbol_field_name: str = "symbol",
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
):
"""
Parameters
----------
include_fields: str
csv_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
backup_dir: str, default None
if backup_dir is not None, backup qlib_dir to backup_dir
freq: str, default "day"
transaction frequency
max_workers: int, default None
number of threads
date_field_name: str, default "date"
the name of the date field in the csv
file_suffix: str, default ".csv"
file suffix
symbol_field_name: str, default "symbol"
symbol field name
include_fields: tuple
dump fields
exclude_fields: str
exclude_fields: tuple
fields not dumped
Examples
---------
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor
python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name
limit_nums: int
Use when debugging, default None
"""
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
include_fields = include_fields.split(",")
self.dump_calendars()
self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields)
self.dump_instruments()
super().__init__(
csv_path,
qlib_dir,
backup_dir,
freq,
max_workers,
date_field_name,
file_suffix,
symbol_field_name,
exclude_fields,
include_fields,
)
self._mode = self.UPDATE_MODE
self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
self._update_instruments = self._read_instruments(
self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
).to_dict(
orient="index"
) # type: dict
# load all csv files
self._all_data = self._load_all_source_data() # type: pd.DataFrame
self._update_calendars = sorted(
filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique())
)
self._new_calendar_list = self._old_calendar_list + self._update_calendars
def _load_all_source_data(self):
# NOTE: Need more memory
logger.info("start load all source data....")
all_df = []
def _read_csv(file_path: Path):
if self._include_fields:
_df = pd.read_csv(file_path, usecols=self._include_fields)
else:
_df = pd.read_csv(file_path)
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df
with tqdm(total=len(self.csv_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
if df:
all_df.append(df)
p_bar.update()
logger.info("end of load all data.\n")
return pd.concat(all_df, sort=False)
def _dump_calendars(self):
pass
def _dump_instruments(self):
pass
def _dump_features(self):
logger.info("start dump features......")
error_code = {}
with ProcessPoolExecutor(max_workers=self.works) as executor:
futures = {}
for _code, _df in self._all_data.groupby(self.symbol_field_name):
_code = str(_code).upper()
_start, _end = self._get_date(_df, is_begin_end=True)
if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)):
continue
if _code in self._update_instruments:
self._update_instruments[_code]["end_time"] = _end
futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code
else:
# new stock
_dt_range = self._update_instruments.setdefault(_code, dict())
_dt_range["start_time"] = _start
_dt_range["end_time"] = _end
futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code
for _future in tqdm(as_completed(futures)):
try:
_future.result()
except Exception:
error_code[futures[_future]] = traceback.format_exc()
logger.info(f"dump bin errors {error_code}")
logger.info("end of features dump.\n")
def dump(self):
self.save_calendars(self._new_calendar_list)
self._dump_features()
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index"))
if __name__ == "__main__":
fire.Fire(DumpData)
fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate})

View File

@@ -55,7 +55,7 @@ class GetData:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
"""download cn qlib data from remote
Parameters
@@ -63,18 +63,25 @@ class GetData:
target_dir: str
data save directory
name: str
dataset name, value from [qlib_data_cn, qlib_data_cn_simple], by default qlib_data_cn
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
version: str
data version, value from [v0, v1, ..., latest], by default latest
interval: str
data freq, value from [1d], by default 1d
region: str
data region, value from [cn, us], by default cn
Examples
---------
python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version latest
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
-------
"""
file_name = f"{name}_{version}.zip"
self._download_data(file_name, target_dir)
# TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system
if region.lower() == "us":
logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system")
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote

View File

@@ -18,7 +18,7 @@ class TestDataset(unittest.TestCase):
sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def testCSI300(self):

View File

@@ -149,7 +149,7 @@ class TestAllFlow(unittest.TestCase):
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri)
GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def test_0_train(self):

View File

@@ -14,7 +14,7 @@ from qlib.data import D
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
from dump_bin import DumpData
from dump_bin import DumpDataAll, DumpDataFix
DATA_DIR = Path(__file__).parent.joinpath("test_dump_data")
@@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
GetData().csv_data_cn(SOURCE_DIR)
TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR)
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
provider_uri = str(QLIB_DIR.resolve())
qlib.init(
@@ -49,8 +49,10 @@ class TestDumpData(unittest.TestCase):
def tearDownClass(cls) -> None:
shutil.rmtree(str(DATA_DIR.resolve()))
def test_0_dump_calendars(self):
self.DUMP_DATA.dump_calendars()
def test_0_dump_bin(self):
self.DUMP_DATA.dump()
def test_1_dump_calendars(self):
ori_calendars = set(
map(
pd.Timestamp,
@@ -60,23 +62,21 @@ class TestDumpData(unittest.TestCase):
res_calendars = set(D.calendar())
assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
def test_1_dump_instruments(self):
self.DUMP_DATA.dump_instruments()
def test_2_dump_instruments(self):
ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
assert len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0, "dump instruments failed"
def test_2_dump_features(self):
self.DUMP_DATA.dump_features(include_fields=self.FIELDS)
def test_3_dump_features(self):
df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]
self.assertFalse(df.dropna().empty, "features data failed")
self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed")
def test_3_dump_features_simple(self):
def test_4_dump_features_simple(self):
stock = self.STOCK_NAMES[0]
dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR)
dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt"))
dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
dump_data.dump()
df = D.features([stock], self.QLIB_FIELDS)

View File

@@ -37,7 +37,7 @@ class TestGetData(unittest.TestCase):
def test_0_qlib_data(self):
GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR)
GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest")
df = D.features(D.instruments("csi300"), self.FIELDS)
self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
self.assertFalse(df.dropna().empty, "get qlib data failed")