mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
support US index
This commit is contained in:
22
scripts/data_collector/cn_index/README.md
Normal file
22
scripts/data_collector/cn_index/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# CSI300/CSI100 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: CSI300, CSI100
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
import re
|
||||
import abc
|
||||
import sys
|
||||
import bisect
|
||||
import importlib
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
@@ -16,7 +17,9 @@ from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
@@ -24,64 +27,48 @@ NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
|
||||
class CSIIndex:
|
||||
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
|
||||
def __init__(self, qlib_dir=None):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
"""
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._calendar_list = None
|
||||
|
||||
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
@property
|
||||
def calendar_list(self) -> list:
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
return get_calendar_list(bench_code=self.index_name.upper())
|
||||
|
||||
@property
|
||||
def new_companies_url(self):
|
||||
def new_companies_url(self) -> str:
|
||||
return NEW_COMPANIES_URL.format(index_code=self.index_code)
|
||||
|
||||
@property
|
||||
def changes_url(self):
|
||||
def changes_url(self) -> str:
|
||||
return INDEX_CHANGES_URL
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
raise NotImplementedError()
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_code(self):
|
||||
raise NotImplementedError()
|
||||
def index_code(self) -> str:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_name(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def html_table_index(self):
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
CSI300: 0
|
||||
@@ -90,33 +77,19 @@ class CSIIndex:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
left_index = bisect.bisect_left(self.calendar_list, trading_date)
|
||||
try:
|
||||
res = self.calendar_list[left_index + shift]
|
||||
except IndexError:
|
||||
res = trading_date
|
||||
return res
|
||||
|
||||
def _get_changes(self) -> pd.DataFrame:
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
logger.info("get companies changes......")
|
||||
res = []
|
||||
@@ -124,10 +97,21 @@ class CSIIndex:
|
||||
_df = self._read_change_from_url(_url)
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res)
|
||||
return pd.concat(res, sort=False)
|
||||
|
||||
@staticmethod
|
||||
def normalize_symbol(symbol):
|
||||
def normalize_symbol(symbol: str) -> str:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
|
||||
Returns
|
||||
-------
|
||||
symbol
|
||||
"""
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
@@ -141,7 +125,14 @@ class CSIIndex:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = requests.get(url)
|
||||
_text = resp.text
|
||||
@@ -151,8 +142,8 @@ class CSIIndex:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = self._get_trading_date_by_shift(_date, shift=0)
|
||||
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
|
||||
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
@@ -168,12 +159,12 @@ class CSIIndex:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = ["symbol"]
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df["date"] = _date
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
df = None
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(resp.content):
|
||||
@@ -188,9 +179,9 @@ class CSIIndex:
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df["date"] = _date
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
@@ -203,20 +194,33 @@ class CSIIndex:
|
||||
break
|
||||
return df
|
||||
|
||||
def _get_change_notices_url(self) -> list:
|
||||
def _get_change_notices_url(self) -> List[str]:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
[url1, url2]
|
||||
"""
|
||||
resp = requests.get(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
def _get_new_companies(self):
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
logger.info("get new companies")
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
logger.info("get new companies......")
|
||||
context = requests.get(self.new_companies_url).content
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
|
||||
@@ -225,51 +229,19 @@ class CSIIndex:
|
||||
_io = BytesIO(context)
|
||||
df = pd.read_excel(_io)
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = ["end_date", "symbol"]
|
||||
df["symbol"] = df["symbol"].map(self.normalize_symbol)
|
||||
df["end_date"] = pd.to_datetime(df["end_date"])
|
||||
df["start_date"] = self.bench_start_date
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
|
||||
"""parse csi300.txt
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = ["symbol", "start_date", "end_date"]
|
||||
changers_df = self._get_changes()
|
||||
new_df = self._get_new_companies()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
|
||||
if _row.type == self.ADD:
|
||||
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
|
||||
new_df.loc[
|
||||
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame(
|
||||
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
|
||||
)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
|
||||
|
||||
class CSI300(CSIIndex):
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000300"
|
||||
|
||||
@property
|
||||
def index_name(self):
|
||||
return "csi300"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2005-01-01")
|
||||
@@ -284,10 +256,6 @@ class CSI100(CSIIndex):
|
||||
def index_code(self):
|
||||
return "000903"
|
||||
|
||||
@property
|
||||
def index_name(self):
|
||||
return "csi100"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2006-05-29")
|
||||
@@ -297,19 +265,39 @@ class CSI100(CSIIndex):
|
||||
return 1
|
||||
|
||||
|
||||
def parse_instruments(qlib_dir: str):
|
||||
def get_instruments(
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
index_name: str
|
||||
index name, value from ["csi100", "csi300"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
qlib_dir.mkdir(exist_ok=True, parents=True)
|
||||
CSI300(qlib_dir).parse_instruments()
|
||||
CSI100(qlib_dir).parse_instruments()
|
||||
_cur_module = importlib.import_module("collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(parse_instruments)
|
||||
fire.Fire(get_instruments)
|
||||
@@ -1,14 +0,0 @@
|
||||
# CSI300 History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
|
||||
202
scripts/data_collector/index.py
Normal file
202
scripts/data_collector/index.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import sys
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent))
|
||||
|
||||
|
||||
from data_collector.utils import get_trading_date_by_shift
|
||||
|
||||
|
||||
class IndexBase:
|
||||
DEFAULT_END_DATE = pd.Timestamp("2099-12-31")
|
||||
SYMBOL_FIELD_NAME = "symbol"
|
||||
DATE_FIELD_NAME = "date"
|
||||
START_DATE_FIELD = "start_date"
|
||||
END_DATE_FIELD = "end_ate"
|
||||
CHANGE_TYPE_FIELD = "type"
|
||||
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
|
||||
REMOVE = "remove"
|
||||
ADD = "add"
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index_name: str
|
||||
index name
|
||||
qlib_dir: str
|
||||
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
"""
|
||||
self.index_name = index_name
|
||||
if qlib_dir is None:
|
||||
qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
|
||||
self.instruments_dir.mkdir(exist_ok=True, parents=True)
|
||||
self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve()
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._request_retry = request_retry
|
||||
self._retry_sleep = retry_sleep
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
raise NotImplementedError("rewrite calendar_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
start_date: pd.Timestamp
|
||||
end_date: pd.Timestamp
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_new_companies")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
def save_new_companies(self):
|
||||
"""save new companies
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
|
||||
def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:
|
||||
"""get changes with history companies
|
||||
|
||||
Parameters
|
||||
----------
|
||||
history_companies : pd.DataFrame
|
||||
symbol date
|
||||
SH600000 2020-11-11
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
|
||||
Return
|
||||
--------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
|
||||
"""
|
||||
logger.info("parse changes from history companies......")
|
||||
last_code = []
|
||||
result_df_list = []
|
||||
_columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]
|
||||
for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):
|
||||
_currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][
|
||||
self.SYMBOL_FIELD_NAME
|
||||
].tolist()
|
||||
if last_code:
|
||||
add_code = list(set(last_code) - set(_currenet_code))
|
||||
remote_code = list(set(_currenet_code) - set(last_code))
|
||||
for _code in add_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
for _code in remote_code:
|
||||
result_df_list.append(
|
||||
pd.DataFrame(
|
||||
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],
|
||||
columns=_columns,
|
||||
)
|
||||
)
|
||||
last_code = _currenet_code
|
||||
df = pd.concat(result_df_list)
|
||||
logger.info("end of parse changes from history companies.")
|
||||
return df
|
||||
|
||||
def parse_instruments(self):
|
||||
"""parse instruments, eg: csi300.txt
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
changers_df = self.get_changes()
|
||||
new_df = self.get_new_companies().copy()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
|
||||
if _row.type == self.ADD:
|
||||
min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()
|
||||
new_df.loc[
|
||||
(new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol),
|
||||
self.START_DATE_FIELD,
|
||||
] = _row.date
|
||||
else:
|
||||
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
|
||||
new_df = new_df.append(_tmp_df, sort=False)
|
||||
|
||||
new_df.loc[:, instruments_columns].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
logger.info(f"parse {self.index_name.lower()} companies finished.")
|
||||
22
scripts/data_collector/us_index/README.md
Normal file
22
scripts/data_collector/us_index/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
# NASDAQ100/SP500/SP400/DJIA History Companies Collection
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
# index_name support: SP500, NASDAQ100, DJIA, SP400
|
||||
# help
|
||||
python collector.py --help
|
||||
```
|
||||
|
||||
278
scripts/data_collector/us_index/collector.py
Normal file
278
scripts/data_collector/us_index/collector.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
from data_collector.index import IndexBase
|
||||
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
||||
|
||||
|
||||
WIKI_URL = "https://en.wikipedia.org/wiki"
|
||||
|
||||
WIKI_INDEX_NAME_MAP = {
|
||||
"NASDAQ100": "NASDAQ-100",
|
||||
"SP500": "List_of_S%26P_500_companies",
|
||||
"SP400": "List_of_S%26P_400_companies",
|
||||
"DJIA": "Dow_Jones_Industrial_Average",
|
||||
}
|
||||
|
||||
|
||||
class WIKIIndex(IndexBase):
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
symbol date type
|
||||
SH600000 2019-11-11 add
|
||||
SH600000 2020-11-10 remove
|
||||
dtypes:
|
||||
symbol: str
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
|
||||
Returns
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
_calendar_list = getattr(self, "_calendar_list", None)
|
||||
if _calendar_list is None:
|
||||
_calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL")))
|
||||
setattr(self, "_calendar_list", _calendar_list)
|
||||
return _calendar_list
|
||||
|
||||
def _request_new_companies(self) -> requests.Response:
|
||||
resp = requests.get(self._target_url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {self._target_url}")
|
||||
|
||||
return resp
|
||||
|
||||
def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
_df = df.copy()
|
||||
_df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()
|
||||
_df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
_df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
|
||||
return _df.loc[:, self.INSTRUMENTS_COLUMNS]
|
||||
|
||||
def get_new_companies(self):
|
||||
logger.info(f"get new companies {self.index_name} ......")
|
||||
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
|
||||
df_list = pd.read_html(_data.text)
|
||||
for _df in df_list:
|
||||
_df = self.filter_df(_df)
|
||||
if (_df is not None) and (not _df.empty):
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df = self.set_default_date_range(_df)
|
||||
logger.info(f"end of get new companies {self.index_name} ......")
|
||||
return _df
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
raise NotImplementedError("rewrite filter_df")
|
||||
|
||||
|
||||
class NASDAQ100Index(WIKIIndex):
|
||||
|
||||
HISTORY_COMPANIES_URL = (
|
||||
"https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
|
||||
)
|
||||
MAX_WORKERS = 16
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if not (set(df.columns) - {"Company", "Ticker"}):
|
||||
return df.loc[:, ["Ticker"]].copy()
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2003-01-02")
|
||||
|
||||
@deco_retry
|
||||
def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
|
||||
trade_date = trade_date.strftime("%Y-%m-%d")
|
||||
cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl")
|
||||
if cache_path.exists() and use_cache:
|
||||
df = pd.read_pickle(cache_path)
|
||||
else:
|
||||
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
|
||||
resp = requests.post(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(f"request error: {url}")
|
||||
df = pd.DataFrame(resp.json()["aaData"])
|
||||
df[self.DATE_FIELD_NAME] = trade_date
|
||||
df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
|
||||
if not df.empty:
|
||||
df.to_pickle(cache_path)
|
||||
return df
|
||||
|
||||
def get_history_companies(self):
|
||||
logger.info(f"start get history companies......")
|
||||
all_history = []
|
||||
error_list = []
|
||||
with tqdm(total=len(self.calendar_list)) as p_bar:
|
||||
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
|
||||
for _trading_date, _df in zip(
|
||||
self.calendar_list, executor.map(self._request_history_companies, self.calendar_list)
|
||||
):
|
||||
if _df.empty:
|
||||
error_list.append(_trading_date)
|
||||
else:
|
||||
all_history.append(_df)
|
||||
p_bar.update()
|
||||
|
||||
if error_list:
|
||||
logger.warning(f"get error: {error_list}")
|
||||
logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}")
|
||||
logger.info(f"end of get history companies.")
|
||||
return pd.concat(all_history, sort=False)
|
||||
|
||||
def get_changes(self):
|
||||
return self.get_changes_with_history_companies(self.get_history_companies())
|
||||
|
||||
|
||||
class DJIAIndex(WIKIIndex):
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2000-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
pass
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Symbol" in df.columns:
|
||||
_df = df.loc[:, ["Symbol"]].copy()
|
||||
_df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1])
|
||||
return _df
|
||||
|
||||
def parse_instruments(self):
|
||||
logger.warning(f"No suitable data source has been found!")
|
||||
|
||||
|
||||
class SP500Index(WIKIIndex):
|
||||
WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
|
||||
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("1999-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
logger.info(f"get sp500 history changes......")
|
||||
# NOTE: may update the index of the table
|
||||
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
|
||||
changes_df = changes_df.iloc[:, [0, 1, 3]]
|
||||
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
|
||||
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
|
||||
_result = []
|
||||
for _type in [self.ADD, self.REMOVE]:
|
||||
_df = changes_df.copy()
|
||||
_df[self.CHANGE_TYPE_FIELD] = _type
|
||||
_df[self.SYMBOL_FIELD_NAME] = _df[_type]
|
||||
_df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
|
||||
if _type == self.ADD:
|
||||
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
|
||||
lambda x: get_trading_date_by_shift(self.calendar_list, x, 0)
|
||||
)
|
||||
else:
|
||||
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
|
||||
lambda x: get_trading_date_by_shift(self.calendar_list, x, -1)
|
||||
)
|
||||
_result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
|
||||
logger.info(f"end of get sp500 history changes.")
|
||||
return pd.concat(_result, sort=False)
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Symbol" in df.columns:
|
||||
return df.loc[:, ["Symbol"]].copy()
|
||||
|
||||
|
||||
class SP400Index(WIKIIndex):
|
||||
@property
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
return pd.Timestamp("2000-01-01")
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
pass
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if "Ticker symbol" in df.columns:
|
||||
return df.loc[:, ["Ticker symbol"]].copy()
|
||||
|
||||
def parse_instruments(self):
|
||||
logger.warning(f"No suitable data source has been found!")
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str
|
||||
qlib data dir, default "Path(__file__).parent/qlib_data"
|
||||
index_name: str
|
||||
index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(get_instruments)
|
||||
6
scripts/data_collector/us_index/requirements.txt
Normal file
6
scripts/data_collector/us_index/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
logure
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
lxml
|
||||
loguru
|
||||
Reference in New Issue
Block a user