1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

support US index

This commit is contained in:
zhupr
2020-11-16 16:29:53 +08:00
committed by you-n-g
parent ae300592a0
commit c6d557c33e
8 changed files with 639 additions and 135 deletions

View File

@@ -0,0 +1,22 @@
# CSI300/CSI100 History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: CSI300, CSI100
# help
python collector.py --help
```

View File

@@ -4,8 +4,9 @@
import re
import abc
import sys
import bisect
import importlib
from io import BytesIO
from typing import List
from pathlib import Path
import fire
@@ -16,7 +17,9 @@ from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import get_hs_calendar_list as get_calendar_list
from data_collector.index import IndexBase
from data_collector.utils import get_calendar_list, get_trading_date_by_shift
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
@@ -24,64 +27,48 @@ NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
class CSIIndex:
REMOVE = "remove"
ADD = "add"
def __init__(self, qlib_dir=None):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
"""
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self._calendar_list = None
self.cache_dir = Path("~/.cache/csi").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
class CSIIndex(IndexBase):
@property
def calendar_list(self) -> list:
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
return get_calendar_list(bench_code=self.index_name.upper())
@property
def new_companies_url(self):
def new_companies_url(self) -> str:
return NEW_COMPANIES_URL.format(index_code=self.index_code)
@property
def changes_url(self):
def changes_url(self) -> str:
return INDEX_CHANGES_URL
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
raise NotImplementedError()
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def index_code(self):
raise NotImplementedError()
def index_code(self) -> str:
"""
Returns
-------
index code
"""
raise NotImplementedError("rewrite index_code")
@property
@abc.abstractmethod
def index_name(self):
raise NotImplementedError()
@property
@abc.abstractmethod
def html_table_index(self):
def html_table_index(self) -> int:
"""Which table of changes in html
CSI300: 0
@@ -90,33 +77,19 @@ class CSIIndex:
"""
raise NotImplementedError()
def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1):
"""get trading date by shift
Parameters
----------
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
"""
left_index = bisect.bisect_left(self.calendar_list, trading_date)
try:
res = self.calendar_list[left_index + shift]
except IndexError:
res = trading_date
return res
def _get_changes(self) -> pd.DataFrame:
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
logger.info("get companies changes......")
res = []
@@ -124,10 +97,21 @@ class CSIIndex:
_df = self._read_change_from_url(_url)
res.append(_df)
logger.info("get companies changes finish")
return pd.concat(res)
return pd.concat(res, sort=False)
@staticmethod
def normalize_symbol(symbol):
def normalize_symbol(symbol: str) -> str:
"""
Parameters
----------
symbol: str
symbol
Returns
-------
symbol
"""
symbol = f"{int(symbol):06}"
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
@@ -141,7 +125,14 @@ class CSIIndex:
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
resp = requests.get(url)
_text = resp.text
@@ -151,8 +142,8 @@ class CSIIndex:
add_date = pd.Timestamp("-".join(date_list[0]))
else:
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
add_date = self._get_trading_date_by_shift(_date, shift=0)
remove_date = self._get_trading_date_by_shift(add_date, shift=-1)
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
logger.info(f"get {add_date} changes")
try:
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
@@ -168,12 +159,12 @@ class CSIIndex:
_df = df_map[_s_name]
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
_df = _df.applymap(self.normalize_symbol)
_df.columns = ["symbol"]
_df.columns = [self.SYMBOL_FIELD_NAME]
_df["type"] = _type
_df["date"] = _date
_df[self.DATE_FIELD_NAME] = _date
tmp.append(_df)
df = pd.concat(tmp)
except Exception:
except Exception as e:
df = None
_tmp_count = 0
for _df in pd.read_html(resp.content):
@@ -188,9 +179,9 @@ class CSIIndex:
(_df.iloc[2:, 2], self.ADD, add_date),
]:
_tmp_df = pd.DataFrame()
_tmp_df["symbol"] = _s.map(self.normalize_symbol)
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
_tmp_df["type"] = _type
_tmp_df["date"] = _date
_tmp_df[self.DATE_FIELD_NAME] = _date
tmp.append(_tmp_df)
df = pd.concat(tmp)
df.to_csv(
@@ -203,20 +194,33 @@ class CSIIndex:
break
return df
def _get_change_notices_url(self) -> list:
def _get_change_notices_url(self) -> List[str]:
"""get change notices url
Returns
-------
[url1, url2]
"""
resp = requests.get(self.changes_url)
html = etree.HTML(resp.text)
return html.xpath("//*[@id='itemContainer']//li/a/@href")
def _get_new_companies(self):
def get_new_companies(self) -> pd.DataFrame:
"""
logger.info("get new companies")
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
logger.info("get new companies......")
context = requests.get(self.new_companies_url).content
with self.cache_dir.joinpath(
f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}"
@@ -225,51 +229,19 @@ class CSIIndex:
_io = BytesIO(context)
df = pd.read_excel(_io)
df = df.iloc[:, [0, 4]]
df.columns = ["end_date", "symbol"]
df["symbol"] = df["symbol"].map(self.normalize_symbol)
df["end_date"] = pd.to_datetime(df["end_date"])
df["start_date"] = self.bench_start_date
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
df[self.START_DATE_FIELD] = self.bench_start_date
logger.info("end of get new companies.")
return df
def parse_instruments(self):
"""parse csi300.txt
Examples
-------
$ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = ["symbol", "start_date", "end_date"]
changers_df = self._get_changes()
new_df = self._get_new_companies()
logger.info("parse history companies by changes......")
for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False):
if _row.type == self.ADD:
min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min()
new_df.loc[
(new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date"
] = _row.date
else:
_tmp_df = pd.DataFrame(
[[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"]
)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info(f"parse {self.index_name.lower()} companies finished.")
class CSI300(CSIIndex):
@property
def index_code(self):
return "000300"
@property
def index_name(self):
return "csi300"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2005-01-01")
@@ -284,10 +256,6 @@ class CSI100(CSIIndex):
def index_code(self):
return "000903"
@property
def index_name(self):
return "csi100"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2006-05-29")
@@ -297,19 +265,39 @@ class CSI100(CSIIndex):
return 1
def parse_instruments(qlib_dir: str):
def get_instruments(
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
index_name: str
index name, value from ["csi100", "csi300"]
method: str
method, value from ["parse_instruments", "save_new_companies"]
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
Examples
-------
# parse instruments
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
qlib_dir.mkdir(exist_ok=True, parents=True)
CSI300(qlib_dir).parse_instruments()
CSI100(qlib_dir).parse_instruments()
_cur_module = importlib.import_module("collector")
obj = getattr(_cur_module, f"{index_name.upper()}")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)
getattr(obj, method)()
if __name__ == "__main__":
fire.Fire(parse_instruments)
fire.Fire(get_instruments)

View File

@@ -1,14 +0,0 @@
# CSI300 History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data
```

View File

@@ -0,0 +1,202 @@
import sys
import abc
from pathlib import Path
from typing import List
import pandas as pd
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent))
from data_collector.utils import get_trading_date_by_shift
class IndexBase:
DEFAULT_END_DATE = pd.Timestamp("2099-12-31")
SYMBOL_FIELD_NAME = "symbol"
DATE_FIELD_NAME = "date"
START_DATE_FIELD = "start_date"
END_DATE_FIELD = "end_ate"
CHANGE_TYPE_FIELD = "type"
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
REMOVE = "remove"
ADD = "add"
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
"""
Parameters
----------
index_name: str
index name
qlib_dir: str
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
"""
self.index_name = index_name
if qlib_dir is None:
qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data")
self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments")
self.instruments_dir.mkdir(exist_ok=True, parents=True)
self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve()
self.cache_dir.mkdir(exist_ok=True, parents=True)
self._request_retry = request_retry
self._retry_sleep = retry_sleep
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@property
@abc.abstractmethod
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
raise NotImplementedError("rewrite calendar_list")
@abc.abstractmethod
def get_new_companies(self) -> pd.DataFrame:
"""
Returns
-------
pd.DataFrame:
symbol start_date end_date
SH600000 2000-01-01 2099-12-31
dtypes:
symbol: str
start_date: pd.Timestamp
end_date: pd.Timestamp
"""
raise NotImplementedError("rewrite get_new_companies")
@abc.abstractmethod
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
raise NotImplementedError("rewrite get_changes")
def save_new_companies(self):
"""save new companies
Examples
-------
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
"""
df = self.get_new_companies()
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
)
def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame:
"""get changes with history companies
Parameters
----------
history_companies : pd.DataFrame
symbol date
SH600000 2020-11-11
dtypes:
symbol: str
date: pd.Timestamp
Return
--------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
logger.info("parse changes from history companies......")
last_code = []
result_df_list = []
_columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD]
for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)):
_currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][
self.SYMBOL_FIELD_NAME
].tolist()
if last_code:
add_code = list(set(last_code) - set(_currenet_code))
remote_code = list(set(_currenet_code) - set(last_code))
for _code in add_code:
result_df_list.append(
pd.DataFrame(
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]],
columns=_columns,
)
)
for _code in remote_code:
result_df_list.append(
pd.DataFrame(
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]],
columns=_columns,
)
)
last_code = _currenet_code
df = pd.concat(result_df_list)
logger.info("end of parse changes from history companies.")
return df
def parse_instruments(self):
"""parse instruments, eg: csi300.txt
Examples
-------
$ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
"""
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
changers_df = self.get_changes()
new_df = self.get_new_companies().copy()
logger.info("parse history companies by changes......")
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
if _row.type == self.ADD:
min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min()
new_df.loc[
(new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol),
self.START_DATE_FIELD,
] = _row.date
else:
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
new_df = new_df.append(_tmp_df, sort=False)
new_df.loc[:, instruments_columns].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info(f"parse {self.index_name.lower()} companies finished.")

View File

@@ -0,0 +1,22 @@
# NASDAQ100/SP500/SP400/DJIA History Companies Collection
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
# index_name support: SP500, NASDAQ100, DJIA, SP400
# help
python collector.py --help
```

View File

@@ -0,0 +1,278 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import sys
import importlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List
import fire
import requests
import pandas as pd
from tqdm import tqdm
from loguru import logger
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.index import IndexBase
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
WIKI_URL = "https://en.wikipedia.org/wiki"
WIKI_INDEX_NAME_MAP = {
"NASDAQ100": "NASDAQ-100",
"SP500": "List_of_S%26P_500_companies",
"SP400": "List_of_S%26P_400_companies",
"DJIA": "Dow_Jones_Industrial_Average",
}
class WIKIIndex(IndexBase):
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
super(WIKIIndex, self).__init__(
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
)
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
@property
@abc.abstractmethod
def bench_start_date(self) -> pd.Timestamp:
"""
Returns
-------
index start date
"""
raise NotImplementedError("rewrite bench_start_date")
@abc.abstractmethod
def get_changes(self) -> pd.DataFrame:
"""get companies changes
Returns
-------
pd.DataFrame:
symbol date type
SH600000 2019-11-11 add
SH600000 2020-11-10 remove
dtypes:
symbol: str
date: pd.Timestamp
type: str, value from ["add", "remove"]
"""
raise NotImplementedError("rewrite get_changes")
@property
def calendar_list(self) -> List[pd.Timestamp]:
"""get history trading date
Returns
-------
calendar list
"""
_calendar_list = getattr(self, "_calendar_list", None)
if _calendar_list is None:
_calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL")))
setattr(self, "_calendar_list", _calendar_list)
return _calendar_list
def _request_new_companies(self) -> requests.Response:
resp = requests.get(self._target_url)
if resp.status_code != 200:
raise ValueError(f"request error: {self._target_url}")
return resp
def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
_df = df.copy()
_df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()
_df[self.START_DATE_FIELD] = self.bench_start_date
_df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
return _df.loc[:, self.INSTRUMENTS_COLUMNS]
def get_new_companies(self):
logger.info(f"get new companies {self.index_name} ......")
_data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)()
df_list = pd.read_html(_data.text)
for _df in df_list:
_df = self.filter_df(_df)
if (_df is not None) and (not _df.empty):
_df.columns = [self.SYMBOL_FIELD_NAME]
_df = self.set_default_date_range(_df)
logger.info(f"end of get new companies {self.index_name} ......")
return _df
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
raise NotImplementedError("rewrite filter_df")
class NASDAQ100Index(WIKIIndex):
HISTORY_COMPANIES_URL = (
"https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
)
MAX_WORKERS = 16
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if not (set(df.columns) - {"Company", "Ticker"}):
return df.loc[:, ["Ticker"]].copy()
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2003-01-02")
@deco_retry
def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
trade_date = trade_date.strftime("%Y-%m-%d")
cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl")
if cache_path.exists() and use_cache:
df = pd.read_pickle(cache_path)
else:
url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date)
resp = requests.post(url)
if resp.status_code != 200:
raise ValueError(f"request error: {url}")
df = pd.DataFrame(resp.json()["aaData"])
df[self.DATE_FIELD_NAME] = trade_date
df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
if not df.empty:
df.to_pickle(cache_path)
return df
def get_history_companies(self):
logger.info(f"start get history companies......")
all_history = []
error_list = []
with tqdm(total=len(self.calendar_list)) as p_bar:
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
for _trading_date, _df in zip(
self.calendar_list, executor.map(self._request_history_companies, self.calendar_list)
):
if _df.empty:
error_list.append(_trading_date)
else:
all_history.append(_df)
p_bar.update()
if error_list:
logger.warning(f"get error: {error_list}")
logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}")
logger.info(f"end of get history companies.")
return pd.concat(all_history, sort=False)
def get_changes(self):
return self.get_changes_with_history_companies(self.get_history_companies())
class DJIAIndex(WIKIIndex):
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2000-01-01")
def get_changes(self) -> pd.DataFrame:
pass
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Symbol" in df.columns:
_df = df.loc[:, ["Symbol"]].copy()
_df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1])
return _df
def parse_instruments(self):
logger.warning(f"No suitable data source has been found!")
class SP500Index(WIKIIndex):
WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("1999-01-01")
def get_changes(self) -> pd.DataFrame:
logger.info(f"get sp500 history changes......")
# NOTE: may update the index of the table
changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
changes_df = changes_df.iloc[:, [0, 1, 3]]
changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])
_result = []
for _type in [self.ADD, self.REMOVE]:
_df = changes_df.copy()
_df[self.CHANGE_TYPE_FIELD] = _type
_df[self.SYMBOL_FIELD_NAME] = _df[_type]
_df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
if _type == self.ADD:
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
lambda x: get_trading_date_by_shift(self.calendar_list, x, 0)
)
else:
_df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
lambda x: get_trading_date_by_shift(self.calendar_list, x, -1)
)
_result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
logger.info(f"end of get sp500 history changes.")
return pd.concat(_result, sort=False)
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Symbol" in df.columns:
return df.loc[:, ["Symbol"]].copy()
class SP400Index(WIKIIndex):
@property
def bench_start_date(self) -> pd.Timestamp:
return pd.Timestamp("2000-01-01")
def get_changes(self) -> pd.DataFrame:
pass
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if "Ticker symbol" in df.columns:
return df.loc[:, ["Ticker symbol"]].copy()
def parse_instruments(self):
logger.warning(f"No suitable data source has been found!")
def get_instruments(
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
"""
Parameters
----------
qlib_dir: str
qlib data dir, default "Path(__file__).parent/qlib_data"
index_name: str
index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
method: str
method, value from ["parse_instruments", "save_new_companies"]
request_retry: int
request retry, by default 5
retry_sleep: int
request sleep, by default 3
Examples
-------
# parse instruments
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
# parse new companies
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
"""
_cur_module = importlib.import_module("collector")
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
)
getattr(obj, method)()
if __name__ == "__main__":
fire.Fire(get_instruments)

View File

@@ -0,0 +1,6 @@
logure
fire
requests
pandas
lxml
loguru