From c6d557c33e406b299cb4f8fa5d84e9683bf60ee7 Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 16 Nov 2020 16:29:53 +0800 Subject: [PATCH] support US index --- scripts/data_collector/cn_index/README.md | 22 ++ .../{csi => cn_index}/collector.py | 230 +++++++-------- .../{csi => cn_index}/requirements.txt | 0 scripts/data_collector/csi/README.md | 14 - scripts/data_collector/index.py | 202 +++++++++++++ scripts/data_collector/us_index/README.md | 22 ++ scripts/data_collector/us_index/collector.py | 278 ++++++++++++++++++ .../data_collector/us_index/requirements.txt | 6 + 8 files changed, 639 insertions(+), 135 deletions(-) create mode 100644 scripts/data_collector/cn_index/README.md rename scripts/data_collector/{csi => cn_index}/collector.py (58%) rename scripts/data_collector/{csi => cn_index}/requirements.txt (100%) delete mode 100644 scripts/data_collector/csi/README.md create mode 100644 scripts/data_collector/index.py create mode 100644 scripts/data_collector/us_index/README.md create mode 100644 scripts/data_collector/us_index/collector.py create mode 100644 scripts/data_collector/us_index/requirements.txt diff --git a/scripts/data_collector/cn_index/README.md b/scripts/data_collector/cn_index/README.md new file mode 100644 index 000000000..82f17eb5c --- /dev/null +++ b/scripts/data_collector/cn_index/README.md @@ -0,0 +1,22 @@ +# CSI300/CSI100 History Companies Collection + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. 
+python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + +# parse new companies +python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + +# index_name support: CSI300, CSI100 +# help +python collector.py --help +``` + diff --git a/scripts/data_collector/csi/collector.py b/scripts/data_collector/cn_index/collector.py similarity index 58% rename from scripts/data_collector/csi/collector.py rename to scripts/data_collector/cn_index/collector.py index af10c12d6..5af9785ec 100644 --- a/scripts/data_collector/csi/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -4,8 +4,9 @@ import re import abc import sys -import bisect +import importlib from io import BytesIO +from typing import List from pathlib import Path import fire @@ -16,7 +17,9 @@ from loguru import logger CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_hs_calendar_list as get_calendar_list + +from data_collector.index import IndexBase +from data_collector.utils import get_calendar_list, get_trading_date_by_shift NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls" @@ -24,64 +27,48 @@ NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" -class CSIIndex: - - REMOVE = "remove" - ADD = "add" - - def __init__(self, qlib_dir=None): - """ - - Parameters - ---------- - qlib_dir: str - qlib data dir, default "Path(__file__).parent/qlib_data" - """ - - if qlib_dir is None: - qlib_dir = CUR_DIR.joinpath("qlib_data") - self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments") - 
self.instruments_dir.mkdir(exist_ok=True, parents=True) - self._calendar_list = None - - self.cache_dir = Path("~/.cache/csi").expanduser().resolve() - self.cache_dir.mkdir(exist_ok=True, parents=True) - +class CSIIndex(IndexBase): @property - def calendar_list(self) -> list: + def calendar_list(self) -> List[pd.Timestamp]: """get history trading date Returns ------- + calendar list """ return get_calendar_list(bench_code=self.index_name.upper()) @property - def new_companies_url(self): + def new_companies_url(self) -> str: return NEW_COMPANIES_URL.format(index_code=self.index_code) @property - def changes_url(self): + def changes_url(self) -> str: return INDEX_CHANGES_URL @property @abc.abstractmethod def bench_start_date(self) -> pd.Timestamp: - raise NotImplementedError() + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") @property @abc.abstractmethod - def index_code(self): - raise NotImplementedError() + def index_code(self) -> str: + """ + Returns + ------- + index code + """ + raise NotImplementedError("rewrite index_code") @property @abc.abstractmethod - def index_name(self): - raise NotImplementedError() - - @property - @abc.abstractmethod - def html_table_index(self): + def html_table_index(self) -> int: """Which table of changes in html CSI300: 0 @@ -90,33 +77,19 @@ class CSIIndex: """ raise NotImplementedError() - def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1): - """get trading date by shift - - Parameters - ---------- - shift : int - shift, default is 1 - - trading_date : pd.Timestamp - trading date - Returns - ------- - - """ - left_index = bisect.bisect_left(self.calendar_list, trading_date) - try: - res = self.calendar_list[left_index + shift] - except IndexError: - res = trading_date - return res - - def _get_changes(self) -> pd.DataFrame: + def get_changes(self) -> pd.DataFrame: """get companies changes Returns ------- - + pd.DataFrame: + symbol date type + SH600000 
2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] """ logger.info("get companies changes......") res = [] @@ -124,10 +97,21 @@ class CSIIndex: _df = self._read_change_from_url(_url) res.append(_df) logger.info("get companies changes finish") - return pd.concat(res) + return pd.concat(res, sort=False) @staticmethod - def normalize_symbol(symbol): + def normalize_symbol(symbol: str) -> str: + """ + + Parameters + ---------- + symbol: str + symbol + + Returns + ------- + symbol + """ symbol = f"{int(symbol):06}" return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}" @@ -141,7 +125,14 @@ class CSIIndex: Returns ------- - + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] """ resp = requests.get(url) _text = resp.text @@ -151,8 +142,8 @@ class CSIIndex: add_date = pd.Timestamp("-".join(date_list[0])) else: _date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0])) - add_date = self._get_trading_date_by_shift(_date, shift=0) - remove_date = self._get_trading_date_by_shift(add_date, shift=-1) + add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0) + remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1) logger.info(f"get {add_date} changes") try: excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0] @@ -168,12 +159,12 @@ class CSIIndex: _df = df_map[_s_name] _df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]] _df = _df.applymap(self.normalize_symbol) - _df.columns = ["symbol"] + _df.columns = [self.SYMBOL_FIELD_NAME] _df["type"] = _type - _df["date"] = _date + _df[self.DATE_FIELD_NAME] = _date tmp.append(_df) df = pd.concat(tmp) - except Exception: + except Exception as e: df = None _tmp_count = 0 for _df in pd.read_html(resp.content): @@ -188,9 +179,9 @@ class CSIIndex: 
(_df.iloc[2:, 2], self.ADD, add_date), ]: _tmp_df = pd.DataFrame() - _tmp_df["symbol"] = _s.map(self.normalize_symbol) + _tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol) _tmp_df["type"] = _type - _tmp_df["date"] = _date + _tmp_df[self.DATE_FIELD_NAME] = _date tmp.append(_tmp_df) df = pd.concat(tmp) df.to_csv( @@ -203,20 +194,33 @@ class CSIIndex: break return df - def _get_change_notices_url(self) -> list: + def _get_change_notices_url(self) -> List[str]: """get change notices url Returns ------- - + [url1, url2] """ resp = requests.get(self.changes_url) html = etree.HTML(resp.text) return html.xpath("//*[@id='itemContainer']//li/a/@href") - def _get_new_companies(self): + def get_new_companies(self) -> pd.DataFrame: + """ - logger.info("get new companies") + Returns + ------- + pd.DataFrame: + + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 + + dtypes: + symbol: str + start_date: pd.Timestamp + end_date: pd.Timestamp + """ + logger.info("get new companies......") context = requests.get(self.new_companies_url).content with self.cache_dir.joinpath( f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}" @@ -225,51 +229,19 @@ class CSIIndex: _io = BytesIO(context) df = pd.read_excel(_io) df = df.iloc[:, [0, 4]] - df.columns = ["end_date", "symbol"] - df["symbol"] = df["symbol"].map(self.normalize_symbol) - df["end_date"] = pd.to_datetime(df["end_date"]) - df["start_date"] = self.bench_start_date + df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME] + df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol) + df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD]) + df[self.START_DATE_FIELD] = self.bench_start_date + logger.info("end of get new companies.") return df - def parse_instruments(self): - """parse csi300.txt - - Examples - ------- - $ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data - """ - logger.info(f"start parse 
{self.index_name.lower()} companies.....") - instruments_columns = ["symbol", "start_date", "end_date"] - changers_df = self._get_changes() - new_df = self._get_new_companies() - logger.info("parse history companies by changes......") - for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False): - if _row.type == self.ADD: - min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min() - new_df.loc[ - (new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date" - ] = _row.date - else: - _tmp_df = pd.DataFrame( - [[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"] - ) - new_df = new_df.append(_tmp_df, sort=False) - - new_df.loc[:, instruments_columns].to_csv( - self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None - ) - logger.info(f"parse {self.index_name.lower()} companies finished.") - class CSI300(CSIIndex): @property def index_code(self): return "000300" - @property - def index_name(self): - return "csi300" - @property def bench_start_date(self) -> pd.Timestamp: return pd.Timestamp("2005-01-01") @@ -284,10 +256,6 @@ class CSI100(CSIIndex): def index_code(self): return "000903" - @property - def index_name(self): - return "csi100" - @property def bench_start_date(self) -> pd.Timestamp: return pd.Timestamp("2006-05-29") @@ -297,19 +265,39 @@ class CSI100(CSIIndex): return 1 -def parse_instruments(qlib_dir: str): +def get_instruments( + qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 +): """ Parameters ---------- qlib_dir: str qlib data dir, default "Path(__file__).parent/qlib_data" + index_name: str + index name, value from ["csi100", "csi300"] + method: str + method, value from ["parse_instruments", "save_new_companies"] + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + + Examples + ------- + # 
parse instruments + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + + # parse new companies + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + """ - qlib_dir = Path(qlib_dir).expanduser().resolve() - qlib_dir.mkdir(exist_ok=True, parents=True) - CSI300(qlib_dir).parse_instruments() - CSI100(qlib_dir).parse_instruments() + _cur_module = importlib.import_module("collector") + obj = getattr(_cur_module, f"{index_name.upper()}")( + qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep + ) + getattr(obj, method)() if __name__ == "__main__": - fire.Fire(parse_instruments) + fire.Fire(get_instruments) diff --git a/scripts/data_collector/csi/requirements.txt b/scripts/data_collector/cn_index/requirements.txt similarity index 100% rename from scripts/data_collector/csi/requirements.txt rename to scripts/data_collector/cn_index/requirements.txt diff --git a/scripts/data_collector/csi/README.md b/scripts/data_collector/csi/README.md deleted file mode 100644 index 52100df81..000000000 --- a/scripts/data_collector/csi/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# CSI300 History Companies Collection - -## Requirements - -```bash -pip install -r requirements.txt -``` - -## Collector Data - -```bash -python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data -``` - diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py new file mode 100644 index 000000000..c5f3854fd --- /dev/null +++ b/scripts/data_collector/index.py @@ -0,0 +1,202 @@ +import sys +import abc +from pathlib import Path +from typing import List + +import pandas as pd +from tqdm import tqdm +from loguru import logger + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent)) + + +from data_collector.utils import get_trading_date_by_shift + + +class IndexBase: + DEFAULT_END_DATE = 
pd.Timestamp("2099-12-31") + SYMBOL_FIELD_NAME = "symbol" + DATE_FIELD_NAME = "date" + START_DATE_FIELD = "start_date" + END_DATE_FIELD = "end_date" + CHANGE_TYPE_FIELD = "type" + INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD] + REMOVE = "remove" + ADD = "add" + + def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): + """ + + Parameters + ---------- + index_name: str + index name + qlib_dir: str + qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data") + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + """ + self.index_name = index_name + if qlib_dir is None: + qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data") + self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments") + self.instruments_dir.mkdir(exist_ok=True, parents=True) + self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve() + self.cache_dir.mkdir(exist_ok=True, parents=True) + self._request_retry = request_retry + self._retry_sleep = retry_sleep + + @property + @abc.abstractmethod + def bench_start_date(self) -> pd.Timestamp: + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") + + @property + @abc.abstractmethod + def calendar_list(self) -> List[pd.Timestamp]: + """get history trading date + + Returns + ------- + calendar list + """ + raise NotImplementedError("rewrite calendar_list") + + @abc.abstractmethod + def get_new_companies(self) -> pd.DataFrame: + """ + + Returns + ------- + pd.DataFrame: + + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 + + dtypes: + symbol: str + start_date: pd.Timestamp + end_date: pd.Timestamp + """ + raise NotImplementedError("rewrite get_new_companies") + + @abc.abstractmethod + def get_changes(self) -> pd.DataFrame: + """get companies changes + + Returns + ------- + 
pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + """ + raise NotImplementedError("rewrite get_changes") + + def save_new_companies(self): + """save new companies + + Examples + ------- + $ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data + """ + df = self.get_new_companies() + df = df.drop_duplicates([self.SYMBOL_FIELD_NAME]) + df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv( + self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None + ) + + def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame: + """get changes with history companies + + Parameters + ---------- + history_companies : pd.DataFrame + symbol date + SH600000 2020-11-11 + + dtypes: + symbol: str + date: pd.Timestamp + + Return + -------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + + """ + logger.info("parse changes from history companies......") + last_code = [] + result_df_list = [] + _columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD] + for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)): + _currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][ + self.SYMBOL_FIELD_NAME + ].tolist() + if last_code: + add_code = list(set(last_code) - set(_currenet_code)) + remote_code = list(set(_currenet_code) - set(last_code)) + for _code in add_code: + result_df_list.append( + pd.DataFrame( + [[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]], + columns=_columns, + ) + ) + for _code in remote_code: + result_df_list.append( + pd.DataFrame( + 
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]], + columns=_columns, + ) + ) + last_code = _currenet_code + df = pd.concat(result_df_list) + logger.info("end of parse changes from history companies.") + return df + + def parse_instruments(self): + """parse instruments, eg: csi300.txt + + Examples + ------- + $ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data + """ + logger.info(f"start parse {self.index_name.lower()} companies.....") + instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD] + changers_df = self.get_changes() + new_df = self.get_new_companies().copy() + logger.info("parse history companies by changes......") + for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)): + if _row.type == self.ADD: + min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min() + new_df.loc[ + (new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol), + self.START_DATE_FIELD, + ] = _row.date + else: + _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns) + new_df = new_df.append(_tmp_df, sort=False) + + new_df.loc[:, instruments_columns].to_csv( + self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None + ) + logger.info(f"parse {self.index_name.lower()} companies finished.") diff --git a/scripts/data_collector/us_index/README.md b/scripts/data_collector/us_index/README.md new file mode 100644 index 000000000..99a0a09c3 --- /dev/null +++ b/scripts/data_collector/us_index/README.md @@ -0,0 +1,22 @@ +# NASDAQ100/SP500/SP400/DJIA History Companies Collection + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. 
+python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments + +# parse new companies +python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies + +# index_name support: SP500, NASDAQ100, DJIA, SP400 +# help +python collector.py --help +``` + diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py new file mode 100644 index 000000000..ea1e974a0 --- /dev/null +++ b/scripts/data_collector/us_index/collector.py @@ -0,0 +1,278 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import abc +import sys +import importlib +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +from typing import List + +import fire +import requests +import pandas as pd +from tqdm import tqdm +from loguru import logger + + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + +from data_collector.index import IndexBase +from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift + + +WIKI_URL = "https://en.wikipedia.org/wiki" + +WIKI_INDEX_NAME_MAP = { + "NASDAQ100": "NASDAQ-100", + "SP500": "List_of_S%26P_500_companies", + "SP400": "List_of_S%26P_400_companies", + "DJIA": "Dow_Jones_Industrial_Average", +} + + +class WIKIIndex(IndexBase): + def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): + super(WIKIIndex, self).__init__( + index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep + ) + + self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}" + + @property + @abc.abstractmethod + def bench_start_date(self) -> pd.Timestamp: + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") + + @abc.abstractmethod + def get_changes(self) -> pd.DataFrame: + """get companies 
changes + + Returns + ------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + """ + raise NotImplementedError("rewrite get_changes") + + @property + def calendar_list(self) -> List[pd.Timestamp]: + """get history trading date + + Returns + ------- + calendar list + """ + _calendar_list = getattr(self, "_calendar_list", None) + if _calendar_list is None: + _calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL"))) + setattr(self, "_calendar_list", _calendar_list) + return _calendar_list + + def _request_new_companies(self) -> requests.Response: + resp = requests.get(self._target_url) + if resp.status_code != 200: + raise ValueError(f"request error: {self._target_url}") + + return resp + + def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame: + _df = df.copy() + _df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip() + _df[self.START_DATE_FIELD] = self.bench_start_date + _df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE + return _df.loc[:, self.INSTRUMENTS_COLUMNS] + + def get_new_companies(self): + logger.info(f"get new companies {self.index_name} ......") + _data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)() + df_list = pd.read_html(_data.text) + for _df in df_list: + _df = self.filter_df(_df) + if (_df is not None) and (not _df.empty): + _df.columns = [self.SYMBOL_FIELD_NAME] + _df = self.set_default_date_range(_df) + logger.info(f"end of get new companies {self.index_name} ......") + return _df + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + raise NotImplementedError("rewrite filter_df") + + +class NASDAQ100Index(WIKIIndex): + + HISTORY_COMPANIES_URL = ( + "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD" + ) + MAX_WORKERS = 16 + + def 
filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if not (set(df.columns) - {"Company", "Ticker"}): + return df.loc[:, ["Ticker"]].copy() + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2003-01-02") + + @deco_retry + def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame: + trade_date = trade_date.strftime("%Y-%m-%d") + cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl") + if cache_path.exists() and use_cache: + df = pd.read_pickle(cache_path) + else: + url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date) + resp = requests.post(url) + if resp.status_code != 200: + raise ValueError(f"request error: {url}") + df = pd.DataFrame(resp.json()["aaData"]) + df[self.DATE_FIELD_NAME] = trade_date + df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True) + if not df.empty: + df.to_pickle(cache_path) + return df + + def get_history_companies(self): + logger.info(f"start get history companies......") + all_history = [] + error_list = [] + with tqdm(total=len(self.calendar_list)) as p_bar: + with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor: + for _trading_date, _df in zip( + self.calendar_list, executor.map(self._request_history_companies, self.calendar_list) + ): + if _df.empty: + error_list.append(_trading_date) + else: + all_history.append(_df) + p_bar.update() + + if error_list: + logger.warning(f"get error: {error_list}") + logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}") + logger.info(f"end of get history companies.") + return pd.concat(all_history, sort=False) + + def get_changes(self): + return self.get_changes_with_history_companies(self.get_history_companies()) + + +class DJIAIndex(WIKIIndex): + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2000-01-01") + + def get_changes(self) -> pd.DataFrame: + pass + + def filter_df(self, df: pd.DataFrame) -> 
pd.DataFrame: + if "Symbol" in df.columns: + _df = df.loc[:, ["Symbol"]].copy() + _df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1]) + return _df + + def parse_instruments(self): + logger.warning(f"No suitable data source has been found!") + + +class SP500Index(WIKIIndex): + WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies" + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("1999-01-01") + + def get_changes(self) -> pd.DataFrame: + logger.info(f"get sp500 history changes......") + # NOTE: may update the index of the table + changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1] + changes_df = changes_df.iloc[:, [0, 1, 3]] + changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE] + changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME]) + _result = [] + for _type in [self.ADD, self.REMOVE]: + _df = changes_df.copy() + _df[self.CHANGE_TYPE_FIELD] = _type + _df[self.SYMBOL_FIELD_NAME] = _df[_type] + _df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True) + if _type == self.ADD: + _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply( + lambda x: get_trading_date_by_shift(self.calendar_list, x, 0) + ) + else: + _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply( + lambda x: get_trading_date_by_shift(self.calendar_list, x, -1) + ) + _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]]) + logger.info(f"end of get sp500 history changes.") + return pd.concat(_result, sort=False) + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if "Symbol" in df.columns: + return df.loc[:, ["Symbol"]].copy() + + +class SP400Index(WIKIIndex): + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2000-01-01") + + def get_changes(self) -> pd.DataFrame: + pass + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if "Ticker symbol" in df.columns: + return df.loc[:, ["Ticker 
symbol"]].copy() + + def parse_instruments(self): + logger.warning(f"No suitable data source has been found!") + + +def get_instruments( + qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 +): + """ + + Parameters + ---------- + qlib_dir: str + qlib data dir, default "Path(__file__).parent/qlib_data" + index_name: str + index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"] + method: str + method, value from ["parse_instruments", "save_new_companies"] + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + + Examples + ------- + # parse instruments + $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments + + # parse new companies + $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies + + """ + _cur_module = importlib.import_module("collector") + obj = getattr(_cur_module, f"{index_name.upper()}Index")( + qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep + ) + getattr(obj, method)() + + +if __name__ == "__main__": + fire.Fire(get_instruments) diff --git a/scripts/data_collector/us_index/requirements.txt b/scripts/data_collector/us_index/requirements.txt new file mode 100644 index 000000000..729271038 --- /dev/null +++ b/scripts/data_collector/us_index/requirements.txt @@ -0,0 +1,6 @@ +tqdm +fire +requests +pandas +lxml +loguru