mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
refactor yahoo_collector && support US and 1m
This commit is contained in:
@@ -3,56 +3,69 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
|
||||
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
|
||||
|
||||
CALENDAR_BENCH_URL_MAP = {
|
||||
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
|
||||
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
|
||||
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
|
||||
# NOTE: Use the time series of SH600000 as the sequence of all stocks
|
||||
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
}
|
||||
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
MINIMUM_SYMBOLS_NUM = 3900
|
||||
|
||||
|
||||
def get_hs_calendar_list(bench_code="CSI300") -> list:
|
||||
def get_calendar_list(bench_code="CSI300") -> list:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL"]
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
def _get_calendar(url):
|
||||
_value_list = requests.get(url).json()["data"]["klines"]
|
||||
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
|
||||
|
||||
calendar = _CALENDAR_MAP.get(bench_code, None)
|
||||
if calendar is None:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
if bench_code.startswith("US_"):
|
||||
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
|
||||
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
|
||||
else:
|
||||
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
|
||||
_CALENDAR_MAP[bench_code] = calendar
|
||||
logger.info(f"end of get calendar list: {bench_code}.")
|
||||
return calendar
|
||||
|
||||
|
||||
@@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list:
|
||||
def _get_symbol():
|
||||
_res = set()
|
||||
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
|
||||
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
|
||||
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
|
||||
_res |= set(
|
||||
map(
|
||||
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
|
||||
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
|
||||
)
|
||||
)
|
||||
time.sleep(3)
|
||||
return _res
|
||||
|
||||
if _HS_SYMBOLS is None:
|
||||
@@ -99,6 +113,84 @@ def get_hs_stock_symbols() -> list:
|
||||
return _HS_SYMBOLS
|
||||
|
||||
|
||||
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get US stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
return _symbols
|
||||
|
||||
@deco_retry
|
||||
def _get_nasdaq():
|
||||
_res_symbols = []
|
||||
for _name in ["otherlisted", "nasdaqtraded"]:
|
||||
url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt"
|
||||
df = pd.read_csv(url, sep="|")
|
||||
df = df.rename(columns={"ACT Symbol": "Symbol"})
|
||||
_symbols = df["Symbol"].dropna()
|
||||
_symbols = _symbols.str.replace("$", "-P", regex=False)
|
||||
_symbols = _symbols.str.replace(".W", "-WT", regex=False)
|
||||
_symbols = _symbols.str.replace(".U", "-UN", regex=False)
|
||||
_symbols = _symbols.str.replace(".R", "-RI", regex=False)
|
||||
_symbols = _symbols.str.replace(".", "-", regex=False)
|
||||
_res_symbols += _symbols.unique().tolist()
|
||||
return _res_symbols
|
||||
|
||||
@deco_retry
|
||||
def _get_nyse():
|
||||
url = "https://www.nyse.com/api/quotes/filter"
|
||||
_parms = {
|
||||
"instrumentType": "EQUITY",
|
||||
"pageNumber": 1,
|
||||
"sortColumn": "NORMALIZED_TICKER",
|
||||
"sortOrder": "ASC",
|
||||
"maxResultsPerPage": 10000,
|
||||
"filterToken": "",
|
||||
}
|
||||
resp = requests.post(url, json=_parms)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
_symbols = []
|
||||
return _symbols
|
||||
|
||||
if _US_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()
|
||||
if qlib_data_path is not None:
|
||||
for _index in ["nasdaq100", "sp500"]:
|
||||
ins_df = pd.read_csv(
|
||||
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
|
||||
sep="\t",
|
||||
names=["symbol", "start_date", "end_date"],
|
||||
)
|
||||
_all_symbols += ins_df["symbol"].unique().tolist()
|
||||
_US_SYMBOLS = sorted(
|
||||
set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))
|
||||
)
|
||||
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
@@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
def deco_func(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_retry = 5 if callable(retry) else retry
|
||||
_result = None
|
||||
for _i in range(1, _retry + 1):
|
||||
try:
|
||||
_result = func(*args, **kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{func.__name__}: {_i} :{e}")
|
||||
if _i == _retry:
|
||||
raise
|
||||
time.sleep(retry_sleep)
|
||||
return _result
|
||||
|
||||
return wrapper
|
||||
|
||||
return deco_func(retry) if callable(retry) else deco_func
|
||||
|
||||
|
||||
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
trading_date = pd.Timestamp(trading_date)
|
||||
left_index = bisect.bisect_left(trading_list, trading_date)
|
||||
try:
|
||||
res = trading_list[left_index + shift]
|
||||
except IndexError:
|
||||
res = trading_date
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -18,31 +18,29 @@ pip install -r requirements.txt
|
||||
|
||||
## Collector Data
|
||||
|
||||
### Download data -> Normalize data -> Dump data
|
||||
### Download data and Normalize data
|
||||
```bash
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
```
|
||||
|
||||
### Download Data From Yahoo Finance
|
||||
### Download Data
|
||||
|
||||
```bash
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
```
|
||||
|
||||
### Normalize Yahoo Finance Data
|
||||
### Normalize Data
|
||||
|
||||
```bash
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
|
||||
```
|
||||
|
||||
### Manual Ajust Yahoo Finance Data
|
||||
|
||||
### Help
|
||||
```bash
|
||||
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
pythono collector.py collector_data --help
|
||||
```
|
||||
|
||||
### Dump Yahoo Finance Data
|
||||
## Parameters
|
||||
|
||||
```bash
|
||||
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
```
|
||||
- interval: 1m or 1d
|
||||
- region: CN or US
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
@@ -16,30 +20,81 @@ from yahooquery import Ticker
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from dump_bin import DumpData
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
MIN_NUMBERS_TRADING = 252 / 4
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0):
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=5,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 5
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1m, 1d], default 1m
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self._stock_list = None
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
self.max_workers = max_workers
|
||||
self._asynchronous = asynchronous
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
self._start_datetime = pd.Timestamp(start) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(end) if end else self.END_DATETIME
|
||||
if self._interval == "1m":
|
||||
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == "1d":
|
||||
self._start_datetime = max(self._start_datetime, self.START_DATETIME)
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
self._end_datetime = min(self._end_datetime, self.END_DATETIME)
|
||||
|
||||
@property
|
||||
def stock_list(self):
|
||||
if self._stock_list is None:
|
||||
self._stock_list = get_hs_stock_symbols()
|
||||
return self._stock_list
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("")
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
@@ -57,63 +112,85 @@ class YahooCollector:
|
||||
if df.empty:
|
||||
raise ValueError("df is empty")
|
||||
|
||||
symbol_s = symbol.split(".")
|
||||
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
|
||||
symbol = self.normailze_symbol(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def _temp_save_small_data(self, symbol, df):
|
||||
if len(df) <= MIN_NUMBERS_TRADING:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!")
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return symbol
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return None
|
||||
|
||||
def _get_from_remote(self, symbol):
|
||||
if self._interval == "1d":
|
||||
self._sleep()
|
||||
try:
|
||||
resp = Ticker(symbol, asynchronous=False).history(
|
||||
interval=self._interval, start=self._start_datetime, end=self._end_datetime
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{self._start_datetime}-{self._end_datetime}:{e}")
|
||||
resp = None
|
||||
yield resp
|
||||
elif self._interval == "1m":
|
||||
_res = []
|
||||
for _start in pd.date_range(self._start_datetime, self._end_datetime + pd.Timedelta(days=-1)):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
try:
|
||||
resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=_start, end=_end)
|
||||
if isinstance(resp, pd.DataFrame):
|
||||
_res.append(resp)
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{_start}-{_end}:{e}")
|
||||
if _res:
|
||||
yield pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
yield None
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
futures = {}
|
||||
p_bar = tqdm(total=len(stock_list))
|
||||
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
|
||||
self._sleep()
|
||||
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
|
||||
period="max"
|
||||
)
|
||||
if isinstance(resp, dict):
|
||||
for symbol, df in resp.items():
|
||||
if isinstance(df, pd.DataFrame):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[
|
||||
worker.submit(
|
||||
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
|
||||
)
|
||||
] = symbol
|
||||
else:
|
||||
error_symbol.append(symbol)
|
||||
else:
|
||||
for symbol, df in resp.reset_index().groupby("symbol"):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[worker.submit(self.save_stock, symbol, df)] = symbol
|
||||
p_bar.update(self.max_workers)
|
||||
p_bar.close()
|
||||
|
||||
with tqdm(total=len(futures.values())) as p_bar:
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
p_bar.update()
|
||||
for _symbol in tqdm(stock_list):
|
||||
for _resp in self._get_from_remote(_symbol):
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
df = _resp.reset_index()
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(_symbol, df) is not None:
|
||||
error_symbol.append(_symbol)
|
||||
futures[worker.submit(self.save_stock, _symbol, df)] = _symbol
|
||||
elif isinstance(_resp, dict):
|
||||
if "timestamp" in _resp[_symbol]:
|
||||
logger.warning(_resp[_symbol])
|
||||
error_symbol.append(_symbol)
|
||||
elif _resp is None:
|
||||
error_symbol.append(_symbol)
|
||||
else:
|
||||
if not (("1m data not available for" in _resp) or ("Data doesn't exist for" in _resp)):
|
||||
error_symbol.append(_symbol)
|
||||
logger.info("save stock data......")
|
||||
for future in tqdm(as_completed(futures)):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
error_symbol.append(futures[future])
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return error_symbol
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
@@ -126,20 +203,51 @@ class YahooCollector:
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, max(_df_list, key=len))
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
|
||||
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
def download_index_data(self):
|
||||
"""download index data"""
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normailze_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
if self._interval == "1m":
|
||||
return 60 * 4 * 5
|
||||
elif self._interval == "1d":
|
||||
return 252 / 4
|
||||
|
||||
def get_stock_list(self):
|
||||
logger.info("get HS stock symbos......")
|
||||
symbols = get_hs_stock_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
# FIXME: 1m
|
||||
_format = "%Y%m%d"
|
||||
_begin = self._start_datetime.strftime(_format)
|
||||
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
df = pd.DataFrame(
|
||||
map(
|
||||
lambda x: x.split(","),
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
|
||||
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()["data"][
|
||||
"klines"
|
||||
],
|
||||
)
|
||||
)
|
||||
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
|
||||
@@ -148,59 +256,71 @@ class YahooCollector:
|
||||
df["adjclose"] = df["close"]
|
||||
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
|
||||
|
||||
def normailze_symbol(self, symbol):
|
||||
symbol_s = symbol.split(".")
|
||||
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
|
||||
return symbol
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
|
||||
|
||||
class YahooCollectorUS(YahooCollector):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
if self._interval == "1m":
|
||||
return 60 * 6.5 * 5
|
||||
elif self._interval == "1d":
|
||||
return 252 / 4
|
||||
|
||||
def get_stock_list(self):
|
||||
logger.info("get US stock symbols......")
|
||||
symbols = get_us_stock_symbols(qlib_data_path="/data1/data/yahoo_staock_data/backup/us_1d_qlib") + [
|
||||
"^GSPC",
|
||||
"^NDX",
|
||||
"^DJI",
|
||||
]
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def download_index_data(self):
|
||||
pass
|
||||
|
||||
def normailze_symbol(self, symbol):
|
||||
return symbol.upper()
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
|
||||
def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
qlib_dir: str
|
||||
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 16
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if qlib_dir is None:
|
||||
qlib_dir = CUR_DIR.joinpath("qlib_data")
|
||||
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
self.qlib_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_workers = max_workers
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._max_workers = max_workers
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
def normalize_data(self):
|
||||
"""normalize data
|
||||
logger.info("normalize data......")
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
|
||||
def _normalize(file_path: Path):
|
||||
columns = ["open", "close", "high", "low", "volume"]
|
||||
df = pd.read_csv(file_path)
|
||||
def _normalize(source_path: Path):
|
||||
columns = copy.deepcopy(self.COLUMNS)
|
||||
df = pd.read_csv(source_path)
|
||||
df.set_index("date", inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
|
||||
# using China stock market data calendar
|
||||
df = df.reindex(pd.Index(get_calendar_list("ALL")))
|
||||
if self._calendar_list is not None:
|
||||
df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
|
||||
df["factor"] = df["adjclose"] / df["close"]
|
||||
for _col in columns:
|
||||
@@ -213,22 +333,17 @@ class Run:
|
||||
columns += ["change", "factor"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
df.index.names = ["date"]
|
||||
df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name))
|
||||
df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.source_dir.glob("*.csv"))
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_normalize, file_list):
|
||||
p_bar.update()
|
||||
|
||||
def manual_adj_data(self):
|
||||
"""manual adjust data
|
||||
|
||||
Examples
|
||||
--------
|
||||
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
|
||||
|
||||
"""
|
||||
"""adjust data"""
|
||||
logger.info("manual adjust data......")
|
||||
|
||||
def _adj(file_path: Path):
|
||||
df = pd.read_csv(file_path)
|
||||
@@ -244,59 +359,148 @@ class Run:
|
||||
df[_col] = df[_col] / _close
|
||||
else:
|
||||
pass
|
||||
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
|
||||
df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
|
||||
file_list = list(self.normalize_dir.glob("*.csv"))
|
||||
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._target_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(_adj, file_list):
|
||||
p_bar.update()
|
||||
|
||||
def dump_data(self):
|
||||
"""dump yahoo data
|
||||
def normalize(self):
|
||||
self.normalize_data()
|
||||
self.manual_adj_data()
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class YahooNormalizeUS(YahooNormalize):
|
||||
def _get_calendar_list(self):
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("US_ALL")
|
||||
|
||||
|
||||
class YahooNormalizeCN(YahooNormalize):
|
||||
def _get_calendar_list(self):
|
||||
# TODO: from MSN
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
"""
|
||||
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
|
||||
include_fields="close,open,high,low,volume,change,factor"
|
||||
)
|
||||
|
||||
def download_data(self, asynchronous=False, max_collector_count=5, delay=0):
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
region: str
|
||||
region, value from ["CN", "US"], default "CN"
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
self.region = region
|
||||
|
||||
def download_data(
|
||||
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=True
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 5
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1m, 1d], default 1m
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default True
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
YahooCollector(
|
||||
|
||||
_class = getattr(self._cur_module, f"YahooCollector{self.region.upper()}")
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
asynchronous=asynchronous,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
).collector_data()
|
||||
|
||||
def download_index_data(self):
|
||||
YahooCollector(self.source_dir).download_index_data()
|
||||
def normalize_data(self):
|
||||
"""normalize data
|
||||
|
||||
def download_bench_data(self):
|
||||
"""download bench stock data(SH000300)"""
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}")
|
||||
_class(self.source_dir, self.normalize_dir, self.max_workers).normalize()
|
||||
|
||||
def collector_data(self):
|
||||
"""download -> normalize -> dump data
|
||||
def collector_data(
|
||||
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=False
|
||||
):
|
||||
"""download -> normalize
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 5
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1m, 1d], default 1m
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
|
||||
Examples
|
||||
-------
|
||||
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
"""
|
||||
self.download_data()
|
||||
self.download_data(
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
)
|
||||
self.normalize_data()
|
||||
self.manual_adj_data()
|
||||
self.dump_data()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user