mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 09:31:18 +08:00
283 lines
9.9 KiB
Python
283 lines
9.9 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import abc
import importlib
import sys
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from typing import List, Union

import fire
import pandas as pd
import requests
from loguru import logger
from tqdm import tqdm
|
|
|
|
|
|
# Absolute path of the directory that holds this script.
CUR_DIR = Path(__file__).resolve().parent
# Put the directory two levels above this script on the import path;
# presumably that is where the ``data_collector`` package imported below lives.
sys.path.append(str(CUR_DIR.parents[1]))
|
|
|
|
from data_collector.index import IndexBase
|
|
from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift
|
|
|
|
|
|
# Base URL that every scraped Wikipedia article path is appended to.
WIKI_URL = "https://en.wikipedia.org/wiki"

# Upper-case index name -> Wikipedia article path (already percent-encoded).
WIKI_INDEX_NAME_MAP = dict(
    NASDAQ100="NASDAQ-100",
    SP500="List_of_S%26P_500_companies",
    SP400="List_of_S%26P_400_companies",
    DJIA="Dow_Jones_Industrial_Average",
)
|
|
|
|
|
|
class WIKIIndex(IndexBase):
    """Base collector for indexes whose constituents are listed on Wikipedia.

    Subclasses provide ``bench_start_date``, ``get_changes`` and ``filter_df``;
    this class downloads the article for the index and extracts the current
    constituents from the first table that ``filter_df`` accepts.
    """

    # NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
    # https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
    INST_PREFIX = "_"

    # Seconds to wait on the Wikipedia server before a request attempt fails.
    # Without a timeout ``requests`` may block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def __init__(
        self,
        index_name: str,
        qlib_dir: Union[str, Path, None] = None,
        request_retry: int = 5,
        retry_sleep: int = 3,
    ):
        """
        Parameters
        ----------
        index_name: str
            index name; its upper-case form must be a key of ``WIKI_INDEX_NAME_MAP``
        qlib_dir: str or Path, optional
            qlib data dir, forwarded to ``IndexBase``
        request_retry: int
            request retry, by default 5
        retry_sleep: int
            request sleep, by default 3

        Raises
        ------
        KeyError
            if the upper-cased index name is not in ``WIKI_INDEX_NAME_MAP``
        """
        super().__init__(
            index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
        )

        # Full URL of the Wikipedia article that lists the current constituents.
        self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"

    @property
    @abc.abstractmethod
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        raise NotImplementedError("rewrite bench_start_date")

    @abc.abstractmethod
    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Returns
        -------
            pd.DataFrame:
                symbol      date        type
                SH600000  2019-11-11    add
                SH600000  2020-11-10    remove
            dtypes:
                symbol: str
                date: pd.Timestamp
                type: str, value from ["add", "remove"]
        """
        raise NotImplementedError("rewrite get_changes")

    @property
    def calendar_list(self) -> List[pd.Timestamp]:
        """get history trading date

        The list is built from ``get_calendar_list("US_ALL")``, restricted to
        dates on or after ``bench_start_date``, and cached on the instance.

        Returns
        -------
            calendar list
        """
        cached = getattr(self, "_calendar_list", None)
        if cached is None:
            start_date = self.bench_start_date
            cached = [_date for _date in get_calendar_list("US_ALL") if _date >= start_date]
            self._calendar_list = cached
        return cached

    def _request_new_companies(self) -> requests.Response:
        """Download the Wikipedia article for this index.

        Returns
        -------
            requests.Response with status code 200

        Raises
        ------
        ValueError
            if the server answers with a status code other than 200
        """
        resp = requests.get(self._target_url, timeout=self.REQUEST_TIMEOUT)
        if resp.status_code != 200:
            raise ValueError(f"request error: {self._target_url}")

        return resp

    def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame:
        """Attach the default start/end dates to a table of symbols.

        Works on a copy of ``df``: strips whitespace from the symbol column,
        sets the start date to ``bench_start_date`` and the end date to
        ``DEFAULT_END_DATE``.

        Returns
        -------
            pd.DataFrame restricted to ``INSTRUMENTS_COLUMNS``
        """
        _df = df.copy()
        _df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip()
        _df[self.START_DATE_FIELD] = self.bench_start_date
        _df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE
        return _df.loc[:, self.INSTRUMENTS_COLUMNS]

    def get_new_companies(self) -> Union[pd.DataFrame, None]:
        """Fetch the current constituents from the Wikipedia article.

        The download is retried through ``deco_retry`` using the instance's
        retry settings. Every HTML table on the page is offered to
        ``filter_df``; the first non-empty result is used.

        Returns
        -------
            pd.DataFrame produced by ``set_default_date_range``, or None when
            no table on the page is accepted by ``filter_df``
        """
        logger.info(f"get new companies {self.index_name} ......")
        fetch = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)
        resp = fetch()
        # Newer pandas deprecates passing literal HTML text to read_html;
        # a file-like wrapper is accepted by old and new versions alike.
        for table in pd.read_html(StringIO(resp.text)):
            symbols = self.filter_df(table)
            if symbols is None or symbols.empty:
                continue
            symbols.columns = [self.SYMBOL_FIELD_NAME]
            result = self.set_default_date_range(symbols)
            logger.info(f"end of get new companies {self.index_name} ......")
            return result
        # No table matched: signal it explicitly instead of falling off the end.
        return None

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Pick the symbol column out of one table of the article.

        Subclasses return a single-column DataFrame when ``df`` is the
        constituents table, and None (or an empty frame) otherwise.
        """
        raise NotImplementedError("rewrite filter_df")
|
|
|
|
|
|
class NASDAQ100Index(WIKIIndex):
    """NASDAQ-100 collector.

    Current constituents come from the Wikipedia article; historical
    constituents are requested per trading date from ``HISTORY_COMPANIES_URL``.
    """

    HISTORY_COMPANIES_URL = (
        "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD"
    )
    # Number of threads used to request the per-date history concurrently.
    MAX_WORKERS = 16
    # Seconds to wait on the history endpoint before a request attempt fails.
    HISTORY_REQUEST_TIMEOUT = 30

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the ``Ticker`` column when ``df`` is the constituents table.

        A table qualifies when it has a ``Ticker`` column and no columns
        other than ``Company`` and ``Ticker``; otherwise None is returned.
        """
        columns = set(df.columns)
        # Requiring "Ticker" avoids a KeyError on e.g. a lone "Company" column.
        if "Ticker" in columns and columns <= {"Company", "Ticker"}:
            return df.loc[:, ["Ticker"]].copy()
        return None

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        return pd.Timestamp("2003-01-02")

    @deco_retry
    def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame:
        """Get the index constituents for one trading date.

        Parameters
        ----------
        trade_date: pd.Timestamp
            trading date to query
        use_cache: bool
            read the pickle cached under ``cache_dir`` when it exists, by default True

        Returns
        -------
            pd.DataFrame; may be empty when the endpoint has no data for the date

        Raises
        ------
        ValueError
            if the server answers with a status code other than 200
        """
        date_str = trade_date.strftime("%Y-%m-%d")
        cache_path = self.cache_dir.joinpath(f"{date_str}_history_companies.pkl")
        if cache_path.exists() and use_cache:
            # The cache file is written by this method itself (see below).
            df = pd.read_pickle(cache_path)
        else:
            url = self.HISTORY_COMPANIES_URL.format(trade_date=date_str)
            resp = requests.post(url, timeout=self.HISTORY_REQUEST_TIMEOUT)
            if resp.status_code != 200:
                raise ValueError(f"request error: {url}")
            df = pd.DataFrame(resp.json()["aaData"])
            df[self.DATE_FIELD_NAME] = date_str
            df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True)
            # Only cache real data so that empty answers are requested again next time.
            if not df.empty:
                df.to_pickle(cache_path)
        return df

    def get_history_companies(self) -> pd.DataFrame:
        """Get the constituents for every date in ``calendar_list``.

        Dates for which the result is empty are collected and logged as errors.

        Returns
        -------
            pd.DataFrame: concatenation of all non-empty per-date results

        Raises
        ------
        ValueError
            if no date produced any data
        """
        logger.info(f"start get history companies......")
        calendar = self.calendar_list
        all_history = []
        error_list = []
        with tqdm(total=len(calendar)) as p_bar:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # executor.map yields results in input order, so zip pairs each date with its frame.
                results = executor.map(self._request_history_companies, calendar)
                for _trading_date, _df in zip(calendar, results):
                    if _df.empty:
                        error_list.append(_trading_date)
                    else:
                        all_history.append(_df)
                    p_bar.update()

        if error_list:
            logger.warning(f"get error: {error_list}")
        logger.info(f"total {len(calendar)}, error {len(error_list)}")
        logger.info(f"end of get history companies.")
        if not all_history:
            # pd.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError("no history companies were fetched for any trading date")
        return pd.concat(all_history, sort=False)

    def get_changes(self) -> pd.DataFrame:
        """get companies changes, derived from the per-date history

        Returns
        -------
            result of ``get_changes_with_history_companies`` applied to the full history
        """
        return self.get_changes_with_history_companies(self.get_history_companies())
|
|
|
|
|
|
class DJIAIndex(WIKIIndex):
    """Dow Jones Industrial Average collector (current constituents only)."""

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        return pd.Timestamp("2000-01-01")

    def get_changes(self) -> pd.DataFrame:
        """No history source is implemented for this index; always returns None.

        The method exists because the parent class declares it abstract.
        """
        return None

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the ``Symbol`` column when ``df`` has one, otherwise None.

        Only the text after the last ":" of each symbol is kept (presumably
        to drop an exchange prefix); missing cells are left as missing.
        """
        if "Symbol" not in df.columns:
            return None
        _df = df.loc[:, ["Symbol"]].copy()
        # The .str accessor skips NaN cells, where calling x.split(":") per
        # cell would raise AttributeError on a float.
        _df["Symbol"] = _df["Symbol"].str.split(":").str[-1]
        return _df

    def parse_instruments(self):
        """Only log a warning: there is no data source for historical changes."""
        logger.warning(f"No suitable data source has been found!")
|
|
|
|
|
|
class SP500Index(WIKIIndex):
    """S&P 500 collector; both constituents and changes come from Wikipedia."""

    WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies"

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        return pd.Timestamp("1999-01-01")

    def get_changes(self) -> pd.DataFrame:
        """get companies changes

        Reads the last table of ``WIKISP500_CHANGES_URL`` and uses its columns
        0, 1 and 3 as date, added symbol and removed symbol. Each date is passed
        through ``get_trading_date_by_shift`` with shift 0 for additions and
        shift -1 for removals.

        Returns
        -------
            pd.DataFrame with the date, change type and symbol columns;
            one row per added or removed symbol
        """
        logger.info(f"get sp500 history changes......")
        # NOTE: may update the index of the table
        changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1]
        # Copy the slice so the column assignments below act on an independent
        # frame instead of a view of the parsed table (SettingWithCopy hazard).
        changes_df = changes_df.iloc[:, [0, 1, 3]].copy()
        changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE]
        changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME])

        calendar = self.calendar_list
        _result = []
        for _type, _shift in ((self.ADD, 0), (self.REMOVE, -1)):
            _df = changes_df.copy()
            _df[self.CHANGE_TYPE_FIELD] = _type
            _df[self.SYMBOL_FIELD_NAME] = _df[_type]
            # Rows without a symbol for this change type carry no information.
            _df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True)
            _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply(
                lambda x, _shift=_shift: get_trading_date_by_shift(calendar, x, _shift)
            )
            _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]])
        logger.info(f"end of get sp500 history changes.")
        return pd.concat(_result, sort=False)

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the ``Symbol`` column when ``df`` has one, otherwise None."""
        if "Symbol" in df.columns:
            return df.loc[:, ["Symbol"]].copy()
        return None
|
|
|
|
|
|
class SP400Index(WIKIIndex):
    """S&P 400 collector (current constituents only)."""

    @property
    def bench_start_date(self) -> pd.Timestamp:
        """
        Returns
        -------
            index start date
        """
        start_date = pd.Timestamp("2000-01-01")
        return start_date

    def get_changes(self) -> pd.DataFrame:
        """No history source is implemented for this index; always returns None."""
        return None

    def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the ``Ticker symbol`` column when ``df`` has one, otherwise None."""
        wanted = "Ticker symbol"
        if wanted not in df.columns:
            return None
        return df.loc[:, [wanted]].copy()

    def parse_instruments(self):
        """Only log a warning: there is no data source for historical changes."""
        logger.warning(f"No suitable data source has been found!")
|
|
|
|
|
|
def get_instruments(
    qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
):
    """Build the collector for ``index_name`` and call one of its methods.

    The class ``<INDEX_NAME>Index`` (index name upper-cased) is looked up in
    this module, instantiated with the given arguments, and ``method`` is
    invoked on the instance without arguments.

    Parameters
    ----------
    qlib_dir: str
        qlib data dir, forwarded to the index class
    index_name: str
        index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
    method: str
        method, value from ["parse_instruments", "save_new_companies"]
    request_retry: int
        request retry, by default 5
    retry_sleep: int
        request sleep, by default 3

    Raises
    ------
    AttributeError
        if this module has no ``<INDEX_NAME>Index`` class or the instance has no ``method``

    Examples
    --------
        # parse instruments
        $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments

        # parse new companies
        $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies

    """
    # Resolve the class in the module that is actually running. Importing a
    # module named "collector" depends on sys.path order and may pick up a
    # different collector.py, or load this file a second time.
    index_cls = getattr(sys.modules[__name__], f"{index_name.upper()}Index")
    obj = index_cls(qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep)
    getattr(obj, method)()
|
|
|
|
|
|
# When run as a script, expose ``get_instruments`` as a command-line
# interface through ``fire`` (parameters become ``--flag value`` options,
# as shown in the ``get_instruments`` docstring examples).
if __name__ == "__main__":
    fire.Fire(get_instruments)
|