1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

refactor yahoo_collector && support US and 1m

This commit is contained in:
zhupr
2020-11-16 16:21:59 +08:00
committed by you-n-g
parent 77bfeadb65
commit ae300592a0
3 changed files with 493 additions and 152 deletions

View File

@@ -3,56 +3,69 @@
import re
import time
import bisect
import pickle
import requests
import functools
from pathlib import Path
import pandas as pd
from lxml import etree
from loguru import logger
from yahooquery import Ticker
SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
SH600000_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231"
CALENDAR_BENCH_URL_MAP = {
"CSI300": CALENDAR_URL_BASE.format(bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(bench_code="000903"),
"CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"),
"CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"),
# NOTE: Use the time series of SH600000 as the sequence of all stocks
"ALL": CALENDAR_URL_BASE.format(bench_code="600000"),
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
"US_ALL": "^GSPC",
}
_BENCH_CALENDAR_LIST = None
_ALL_CALENDAR_LIST = None
_HS_SYMBOLS = None
_US_SYMBOLS = None
_CALENDAR_MAP = {}
# NOTE: Until 2020-10-20 20:00:00
MINIMUM_SYMBOLS_NUM = 3900
def get_hs_calendar_list(bench_code="CSI300") -> list:
def get_calendar_list(bench_code="CSI300") -> list:
"""get SH/SZ history calendar list
Parameters
----------
bench_code: str
value from ["CSI300", "CSI500", "ALL"]
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
Returns
-------
history calendar list
"""
logger.info(f"get calendar list: {bench_code}......")
def _get_calendar(url):
_value_list = requests.get(url).json()["data"]["klines"]
return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list))
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
if bench_code.startswith("US_"):
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
else:
calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code])
_CALENDAR_MAP[bench_code] = calendar
logger.info(f"end of get calendar list: {bench_code}.")
return calendar
@@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list:
def _get_symbol():
_res = set()
for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")):
resp = requests.get(SYMBOLS_URL.format(s_type=_k))
resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k))
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
time.sleep(3)
return _res
if _HS_SYMBOLS is None:
@@ -99,6 +113,84 @@ def get_hs_stock_symbols() -> list:
return _HS_SYMBOLS
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
"""get US stock symbols
Returns
-------
stock symbols
"""
global _US_SYMBOLS
@deco_retry
def _get_eastmoney():
url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12"
resp = requests.get(url)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
if len(_symbols) < 8000:
raise ValueError("request error")
return _symbols
@deco_retry
def _get_nasdaq():
_res_symbols = []
for _name in ["otherlisted", "nasdaqtraded"]:
url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt"
df = pd.read_csv(url, sep="|")
df = df.rename(columns={"ACT Symbol": "Symbol"})
_symbols = df["Symbol"].dropna()
_symbols = _symbols.str.replace("$", "-P", regex=False)
_symbols = _symbols.str.replace(".W", "-WT", regex=False)
_symbols = _symbols.str.replace(".U", "-UN", regex=False)
_symbols = _symbols.str.replace(".R", "-RI", regex=False)
_symbols = _symbols.str.replace(".", "-", regex=False)
_res_symbols += _symbols.unique().tolist()
return _res_symbols
@deco_retry
def _get_nyse():
url = "https://www.nyse.com/api/quotes/filter"
_parms = {
"instrumentType": "EQUITY",
"pageNumber": 1,
"sortColumn": "NORMALIZED_TICKER",
"sortOrder": "ASC",
"maxResultsPerPage": 10000,
"filterToken": "",
}
resp = requests.post(url, json=_parms)
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()]
except Exception as e:
logger.warning(f"request error: {e}")
_symbols = []
return _symbols
if _US_SYMBOLS is None:
_all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse()
if qlib_data_path is not None:
for _index in ["nasdaq100", "sp500"]:
ins_df = pd.read_csv(
Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
sep="\t",
names=["symbol", "start_date", "end_date"],
)
_all_symbols += ins_df["symbol"].unique().tolist()
_US_SYMBOLS = sorted(
set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))
)
return _US_SYMBOLS
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
"""symbol suffix to prefix
@@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
return res.upper() if capital else res.lower()
if __name__ == '__main__':
def deco_retry(retry: int = 5, retry_sleep: int = 3):
def deco_func(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
_retry = 5 if callable(retry) else retry
_result = None
for _i in range(1, _retry + 1):
try:
_result = func(*args, **kwargs)
break
except Exception as e:
logger.warning(f"{func.__name__}: {_i} :{e}")
if _i == _retry:
raise
time.sleep(retry_sleep)
return _result
return wrapper
return deco_func(retry) if callable(retry) else deco_func
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
"""get trading date by shift
Parameters
----------
trading_list: list
trading calendar list
shift : int
shift, default is 1
trading_date : pd.Timestamp
trading date
Returns
-------
"""
trading_date = pd.Timestamp(trading_date)
left_index = bisect.bisect_left(trading_list, trading_date)
try:
res = trading_list[left_index + shift]
except IndexError:
res = trading_date
return res
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -18,31 +18,29 @@ pip install -r requirements.txt
## Collector Data
### Download data -> Normalize data -> Dump data
### Download data and Normalize data
```bash
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
```
### Download Data From Yahoo Finance
### Download Data
```bash
python collector.py download_data --source_dir ~/.qlib/stock_data/source
python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
```
### Normalize Yahoo Finance Data
### Normalize Data
```bash
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
```
### Manual Ajust Yahoo Finance Data
### Help
```bash
python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
pythono collector.py collector_data --help
```
### Dump Yahoo Finance Data
## Parameters
```bash
python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
```
- interval: 1m or 1d
- region: CN or US

View File

@@ -1,8 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import sys
import copy
import time
import datetime
import importlib
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -16,30 +20,81 @@ from yahooquery import Ticker
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from dump_bin import DumpData
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
MIN_NUMBERS_TRADING = 252 / 4
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
REGION_CN = "CN"
REGION_US = "US"
class YahooCollector:
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0):
START_DATETIME = pd.Timestamp("2000-01-01")
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5))
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
def __init__(
self,
save_dir: [str, Path],
start=None,
end=None,
interval="1d",
max_workers=4,
max_collector_count=5,
delay=0,
check_data_length: bool = False,
):
"""
Parameters
----------
save_dir: str
stock save dir
max_workers: int
workers, default 4
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default None
end: str
end datetime, default None
check_data_length: bool
check data length, by default False
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
self._stock_list = None
self.stock_list = sorted(set(self.get_stock_list()))
self.max_workers = max_workers
self._asynchronous = asynchronous
self._max_collector_count = max_collector_count
self._mini_symbol_map = {}
self._interval = interval
self._check_small_data = check_data_length
self._start_datetime = pd.Timestamp(start) if start else self.START_DATETIME
self._end_datetime = pd.Timestamp(end) if end else self.END_DATETIME
if self._interval == "1m":
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == "1d":
self._start_datetime = max(self._start_datetime, self.START_DATETIME)
else:
raise ValueError(f"interval error: {self._interval}")
self._end_datetime = min(self._end_datetime, self.END_DATETIME)
@property
def stock_list(self):
if self._stock_list is None:
self._stock_list = get_hs_stock_symbols()
return self._stock_list
@abc.abstractmethod
def min_numbers_trading(self):
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("")
def _sleep(self):
time.sleep(self._delay)
@@ -57,63 +112,85 @@ class YahooCollector:
if df.empty:
raise ValueError("df is empty")
symbol_s = symbol.split(".")
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
symbol = self.normailze_symbol(symbol)
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
df["symbol"] = symbol
df.to_csv(stock_path, index=False)
def _temp_save_small_data(self, symbol, df):
if len(df) <= MIN_NUMBERS_TRADING:
logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!")
def _save_small_data(self, symbol, df):
if len(df) <= self.min_numbers_trading:
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
_temp = self._mini_symbol_map.setdefault(symbol, [])
_temp.append(df.copy())
return symbol
else:
if symbol in self._mini_symbol_map:
self._mini_symbol_map.pop(symbol)
return None
def _get_from_remote(self, symbol):
if self._interval == "1d":
self._sleep()
try:
resp = Ticker(symbol, asynchronous=False).history(
interval=self._interval, start=self._start_datetime, end=self._end_datetime
)
except Exception as e:
logger.warning(f"{symbol}-{self._interval}-{self._start_datetime}-{self._end_datetime}:{e}")
resp = None
yield resp
elif self._interval == "1m":
_res = []
for _start in pd.date_range(self._start_datetime, self._end_datetime + pd.Timedelta(days=-1)):
_end = _start + pd.Timedelta(days=1)
self._sleep()
try:
resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=_start, end=_end)
if isinstance(resp, pd.DataFrame):
_res.append(resp)
except Exception as e:
logger.warning(f"{symbol}-{self._interval}-{_start}-{_end}:{e}")
if _res:
yield pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
yield None
else:
raise ValueError(f"cannot support {self._interval}")
def _collector(self, stock_list):
error_symbol = []
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
futures = {}
p_bar = tqdm(total=len(stock_list))
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
self._sleep()
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
period="max"
)
if isinstance(resp, dict):
for symbol, df in resp.items():
if isinstance(df, pd.DataFrame):
self._temp_save_small_data(self, df)
futures[
worker.submit(
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
)
] = symbol
else:
error_symbol.append(symbol)
else:
for symbol, df in resp.reset_index().groupby("symbol"):
self._temp_save_small_data(self, df)
futures[worker.submit(self.save_stock, symbol, df)] = symbol
p_bar.update(self.max_workers)
p_bar.close()
with tqdm(total=len(futures.values())) as p_bar:
for future in as_completed(futures):
try:
future.result()
except Exception as e:
logger.error(e)
error_symbol.append(futures[future])
p_bar.update()
for _symbol in tqdm(stock_list):
for _resp in self._get_from_remote(_symbol):
if isinstance(_resp, pd.DataFrame):
df = _resp.reset_index()
if self._check_small_data:
if self._save_small_data(_symbol, df) is not None:
error_symbol.append(_symbol)
futures[worker.submit(self.save_stock, _symbol, df)] = _symbol
elif isinstance(_resp, dict):
if "timestamp" in _resp[_symbol]:
logger.warning(_resp[_symbol])
error_symbol.append(_symbol)
elif _resp is None:
error_symbol.append(_symbol)
else:
if not (("1m data not available for" in _resp) or ("Data doesn't exist for" in _resp)):
error_symbol.append(_symbol)
logger.info("save stock data......")
for future in tqdm(as_completed(futures)):
try:
future.result()
except Exception as e:
logger.error(e)
error_symbol.append(futures[future])
print(error_symbol)
logger.info(f"error symbol nums: {len(error_symbol)}")
logger.info(f"current get symbol nums: {len(stock_list)}")
error_symbol.extend(self._mini_symbol_map.keys())
return error_symbol
return sorted(set(error_symbol))
def collector_data(self):
"""collector data"""
@@ -126,20 +203,51 @@ class YahooCollector:
stock_list = self._collector(stock_list)
logger.info(f"{i+1} finish.")
for _symbol, _df_list in self._mini_symbol_map.items():
self.save_stock(_symbol, max(_df_list, key=len))
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
self.download_index_data()
@abc.abstractmethod
def download_index_data(self):
"""download index data"""
raise NotImplementedError("rewrite download_index_data")
@abc.abstractmethod
def normailze_symbol(self, symbol: str):
"""normalize symbol"""
raise NotImplementedError("rewrite normalize_symbol")
class YahooCollectorCN(YahooCollector):
@property
def min_numbers_trading(self):
if self._interval == "1m":
return 60 * 4 * 5
elif self._interval == "1d":
return 252 / 4
def get_stock_list(self):
logger.info("get HS stock symbos......")
symbols = get_hs_stock_symbols()
logger.info(f"get {len(symbols)} symbols.")
return symbols
def download_index_data(self):
# TODO: from MSN
# FIXME: 1m
_format = "%Y%m%d"
_begin = self._start_datetime.strftime(_format)
_end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()["data"][
"klines"
],
)
)
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
@@ -148,59 +256,71 @@ class YahooCollector:
df["adjclose"] = df["close"]
df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
def normailze_symbol(self, symbol):
symbol_s = symbol.split(".")
symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
return symbol
class Run:
def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
class YahooCollectorUS(YahooCollector):
@property
def min_numbers_trading(self):
if self._interval == "1m":
return 60 * 6.5 * 5
elif self._interval == "1d":
return 252 / 4
def get_stock_list(self):
logger.info("get US stock symbols......")
symbols = get_us_stock_symbols(qlib_data_path="/data1/data/yahoo_staock_data/backup/us_1d_qlib") + [
"^GSPC",
"^NDX",
"^DJI",
]
logger.info(f"get {len(symbols)} symbols.")
return symbols
def download_index_data(self):
pass
def normailze_symbol(self, symbol):
return symbol.upper()
class YahooNormalize:
COLUMNS = ["open", "close", "high", "low", "volume"]
def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16):
"""
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
qlib_dir: str
qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
source_dir: str or Path
The directory where the raw data collected from the Internet is saved
target_dir: str or Path
Directory for normalize data
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 16
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
if qlib_dir is None:
qlib_dir = CUR_DIR.joinpath("qlib_data")
self.qlib_dir = Path(qlib_dir).expanduser().resolve()
self.qlib_dir.mkdir(parents=True, exist_ok=True)
self.max_workers = max_workers
if not (source_dir and target_dir):
raise ValueError("source_dir and target_dir cannot be None")
self._source_dir = Path(source_dir).expanduser()
self._target_dir = Path(target_dir).expanduser()
self._max_workers = max_workers
self._calendar_list = self._get_calendar_list()
def normalize_data(self):
"""normalize data
logger.info("normalize data......")
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize
"""
def _normalize(file_path: Path):
columns = ["open", "close", "high", "low", "volume"]
df = pd.read_csv(file_path)
def _normalize(source_path: Path):
columns = copy.deepcopy(self.COLUMNS)
df = pd.read_csv(source_path)
df.set_index("date", inplace=True)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
# using China stock market data calendar
df = df.reindex(pd.Index(get_calendar_list("ALL")))
if self._calendar_list is not None:
df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index)
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan
df["factor"] = df["adjclose"] / df["close"]
for _col in columns:
@@ -213,22 +333,17 @@ class Run:
columns += ["change", "factor"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
df.index.names = ["date"]
df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name))
df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name))
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.source_dir.glob("*.csv"))
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._source_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_normalize, file_list):
p_bar.update()
def manual_adj_data(self):
"""manual adjust data
Examples
--------
$ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize
"""
"""adjust data"""
logger.info("manual adjust data......")
def _adj(file_path: Path):
df = pd.read_csv(file_path)
@@ -244,59 +359,148 @@ class Run:
df[_col] = df[_col] / _close
else:
pass
df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False)
df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False)
with ThreadPoolExecutor(max_workers=self.max_workers) as worker:
file_list = list(self.normalize_dir.glob("*.csv"))
with ThreadPoolExecutor(max_workers=self._max_workers) as worker:
file_list = list(self._target_dir.glob("*.csv"))
with tqdm(total=len(file_list)) as p_bar:
for _ in worker.map(_adj, file_list):
p_bar.update()
def dump_data(self):
"""dump yahoo data
def normalize(self):
self.normalize_data()
self.manual_adj_data()
Examples
---------
$ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
@abc.abstractmethod
def _get_calendar_list(self):
"""Get benchmark calendar"""
raise NotImplementedError("")
class YahooNormalizeUS(YahooNormalize):
def _get_calendar_list(self):
# TODO: from MSN
return get_calendar_list("US_ALL")
class YahooNormalizeCN(YahooNormalize):
def _get_calendar_list(self):
# TODO: from MSN
return get_calendar_list("ALL")
class Run:
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
"""
DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump(
include_fields="close,open,high,low,volume,change,factor"
)
def download_data(self, asynchronous=False, max_collector_count=5, delay=0):
Parameters
----------
source_dir: str
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
region: str
region, value from ["CN", "US"], default "CN"
"""
if source_dir is None:
source_dir = CUR_DIR.joinpath("source")
self.source_dir = Path(source_dir).expanduser().resolve()
self.source_dir.mkdir(parents=True, exist_ok=True)
if normalize_dir is None:
normalize_dir = CUR_DIR.joinpath("normalize")
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
self.normalize_dir.mkdir(parents=True, exist_ok=True)
self._cur_module = importlib.import_module("collector")
self.max_workers = max_workers
self.region = region
def download_data(
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=True
):
"""download data from Internet
Parameters
----------
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default True
Examples
---------
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
# get daily data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
# get 1m data
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
YahooCollector(
_class = getattr(self._cur_module, f"YahooCollector{self.region.upper()}")
_class(
self.source_dir,
max_workers=self.max_workers,
asynchronous=asynchronous,
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
).collector_data()
def download_index_data(self):
YahooCollector(self.source_dir).download_index_data()
def normalize_data(self):
"""normalize data
def download_bench_data(self):
"""download bench stock data(SH000300)"""
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN
"""
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}")
_class(self.source_dir, self.normalize_dir, self.max_workers).normalize()
def collector_data(self):
"""download -> normalize -> dump data
def collector_data(
self, max_collector_count=5, delay=0, start=None, end=None, interval="1d", check_data_length=False
):
"""download -> normalize
Parameters
----------
max_collector_count: int
default 5
delay: float
time.sleep(delay), default 0
interval: str
freq, value from [1m, 1d], default 1m
start: str
start datetime, default "2000-01-01"
end: str
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
check_data_length: bool
check data length, by default False
Examples
-------
$ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
"""
self.download_data()
self.download_data(
max_collector_count=max_collector_count,
delay=delay,
start=start,
end=end,
interval=interval,
check_data_length=check_data_length,
)
self.normalize_data()
self.manual_adj_data()
self.dump_data()
if __name__ == "__main__":