mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix data collector
This commit is contained in:
@@ -1,5 +1,13 @@
|
||||
# Collect Data From Yahoo Finance
|
||||
|
||||
> *Please pay **ATTENTION**: the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and might not be perfect. We recommend that users prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
|
||||
> **Examples of abnormal data**
|
||||
|
||||
- [SH600000](https://finance.yahoo.com/quote/600000.SS/history?period1=1147046400&period2=1147478400&interval=1d&filter=history&frequency=1d)
|
||||
- [SH600018](https://finance.yahoo.com/quote/600018.SS/history?period1=1158883200&period2=1161907200&interval=1d&filter=history&frequency=1d)
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
@@ -19,17 +20,20 @@ from dump_bin import DumpData
|
||||
from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols
|
||||
|
||||
CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101"
|
||||
MIN_NUMBERS_TRADING = 252 / 4
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=True, max_collector_count=3):
|
||||
def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0):
|
||||
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self._stock_list = None
|
||||
self.max_workers = max_workers
|
||||
self._asynchronous = asynchronous
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
|
||||
@property
|
||||
def stock_list(self):
|
||||
@@ -37,6 +41,9 @@ class YahooCollector:
|
||||
self._stock_list = get_hs_stock_symbols()
|
||||
return self._stock_list
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
|
||||
@@ -56,6 +63,15 @@ class YahooCollector:
|
||||
df["symbol"] = symbol
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def _temp_save_small_data(self, symbol, df):
    """Track symbols whose downloaded frame has too few rows.

    Parameters
    ----------
    symbol:
        key under which the frame is remembered in ``self._mini_symbol_map``
    df: pd.DataFrame
        data downloaded for ``symbol``

    When ``len(df)`` exceeds ``MIN_NUMBERS_TRADING`` any earlier entry for
    ``symbol`` is discarded; otherwise a warning is logged and a copy of
    ``df`` is appended to the list kept for ``symbol``.
    """
    if len(df) > MIN_NUMBERS_TRADING:
        # Enough rows this time: forget any short frames remembered earlier.
        self._mini_symbol_map.pop(symbol, None)
        return

    logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!")
    # Keep a private copy so later changes to ``df`` do not affect the stored frame.
    self._mini_symbol_map.setdefault(symbol, []).append(df.copy())
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
@@ -63,12 +79,14 @@ class YahooCollector:
|
||||
futures = {}
|
||||
p_bar = tqdm(total=len(stock_list))
|
||||
for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]:
|
||||
self._sleep()
|
||||
resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history(
|
||||
period="max"
|
||||
)
|
||||
if isinstance(resp, dict):
|
||||
for symbol, df in resp.items():
|
||||
if isinstance(df, pd.DataFrame):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[
|
||||
worker.submit(
|
||||
self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"})
|
||||
@@ -78,6 +96,7 @@ class YahooCollector:
|
||||
error_symbol.append(symbol)
|
||||
else:
|
||||
for symbol, df in resp.reset_index().groupby("symbol"):
|
||||
self._temp_save_small_data(self, df)
|
||||
futures[worker.submit(self.save_stock, symbol, df)] = symbol
|
||||
p_bar.update(self.max_workers)
|
||||
p_bar.close()
|
||||
@@ -93,6 +112,7 @@ class YahooCollector:
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return error_symbol
|
||||
|
||||
def collector_data(self):
|
||||
@@ -107,7 +127,14 @@ class YahooCollector:
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, max(_df_list, key=len))
|
||||
|
||||
logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
|
||||
self.download_csi300_data()
|
||||
|
||||
def download_csi300_data(self):
|
||||
# TODO: from MSN
|
||||
logger.info(f"get bench data: csi300(SH000300)......")
|
||||
df = pd.DataFrame(map(lambda x: x.split(","), requests.get(CSI300_BENCH_URL).json()["data"]["klines"]))
|
||||
@@ -164,6 +191,7 @@ class Run:
|
||||
df = pd.read_csv(file_path)
|
||||
df.set_index("date", inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
|
||||
# using China stock market data calendar
|
||||
df = df.reindex(pd.Index(get_calendar_list()))
|
||||
@@ -232,7 +260,7 @@ class Run:
|
||||
include_fields="close,open,high,low,volume,change,factor"
|
||||
)
|
||||
|
||||
def download_data(self):
|
||||
def download_data(self, asynchronous=False, max_collector_count=5, delay=0):
|
||||
"""download data from Internet
|
||||
|
||||
Examples
|
||||
@@ -240,7 +268,20 @@ class Run:
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source
|
||||
|
||||
"""
|
||||
YahooCollector(self.source_dir, max_workers=self.max_workers).collector_data()
|
||||
YahooCollector(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
asynchronous=asynchronous,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
).collector_data()
|
||||
|
||||
def download_csi300_data(self):
    """Delegate to ``YahooCollector.download_csi300_data`` using ``self.source_dir`` as the save directory."""
    bench_collector = YahooCollector(self.source_dir)
    bench_collector.download_csi300_data()
|
||||
|
||||
def download_bench_data(self):
|
||||
"""download bench stock data(SH000300)
|
||||
"""
|
||||
|
||||
def collector_data(self):
|
||||
"""download -> normalize -> dump data
|
||||
|
||||
@@ -53,7 +53,7 @@ class GetData:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="v1"):
|
||||
def qlib_data_cn(self, target_dir="~/.qlib/qlib_data/cn_data", version="latest"):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
@@ -61,7 +61,7 @@ class GetData:
|
||||
target_dir: str
|
||||
data save directory
|
||||
version: str
|
||||
data version, value from [v0, v1], by default v1
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
|
||||
Examples
|
||||
---------
|
||||
|
||||
Reference in New Issue
Block a user