1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Fix yahoo collector

This commit is contained in:
zhupr
2020-11-28 00:36:23 +08:00
parent 47cbfdc50c
commit 6fc4ff0a62

View File

@@ -44,6 +44,7 @@ class YahooCollector:
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
show_1m_logging: bool = False,
):
"""
@@ -67,10 +68,13 @@ class YahooCollector:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1m_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
self._show_1m_logging = show_1m_logging
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
@@ -83,7 +87,7 @@ class YahooCollector:
self._interval = interval
self._check_small_data = check_data_length
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
self._end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval == "1m":
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == "1d":
@@ -91,8 +95,12 @@ class YahooCollector:
else:
raise ValueError(f"interval error: {self._interval}")
# using for 1m
self._next_datetime = self.convert_datetime(self._start_datetime.date() + pd.Timedelta(days=1))
self._latest_datetime = self.convert_datetime(self._end_datetime.date())
self._start_datetime = self.convert_datetime(self._start_datetime)
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
self._end_datetime = self.convert_datetime(self._end_datetime)
@property
@abc.abstractmethod
@@ -100,20 +108,24 @@ class YahooCollector:
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
raise NotImplementedError("rewirte min_numbers_trading")
raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
raise NotImplementedError("rewirte get_stock_list")
raise NotImplementedError("rewrite get_stock_list")
@property
@abc.abstractclassmethod
@abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
def convert_datetime(self, dt: pd.Timestamp):
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
def convert_datetime(self, dt: [pd.Timestamp, datetime.date, str]):
    """Re-express ``dt`` from the collector's timezone in the local timezone.

    ``dt`` is localized to ``self._timezone``, reduced to a POSIX timestamp
    and rebuilt as a ``pd.Timestamp`` in the machine-local timezone
    (``tzlocal()``).

    Parameters
    ----------
    dt : pd.Timestamp, datetime.date or str
        value to convert

    Returns
    -------
    pd.Timestamp
        the same instant expressed in the local timezone; when pandas rejects
        the conversion with ``ValueError`` the input ``dt`` is returned
        unchanged
    """
    # NOTE(review): the list literal above is not a valid type hint; replace
    # it with typing.Union once the file's import block is in view.
    try:
        epoch_seconds = pd.Timestamp(dt, tz=self._timezone).timestamp()
        return pd.Timestamp(epoch_seconds, tz=tzlocal(), unit="s")
    except ValueError:
        # Deliberate best effort: values pandas cannot localize (presumably
        # inputs that are already timezone-aware) are handed back untouched,
        # never as a half-converted epoch float.
        return dt
def _sleep(self):
time.sleep(self._delay)
@@ -136,7 +148,7 @@ class YahooCollector:
df["symbol"] = symbol
if stock_path.exists():
with stock_path.open("a") as fp:
df.to_csv(fp, index=False, header=None)
df.to_csv(fp, index=False, header=False)
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
@@ -155,34 +167,47 @@ class YahooCollector:
def _get_from_remote(self, symbol):
def _get_simple(start_, end_):
self._sleep()
error_msg = f"{symbol}-{self._interval}-{start_}-{end_}"
def _show_logging_func():
if self._interval == "1m" and self._show_1m_logging:
logger.warning(f"{error_msg}:{_resp}")
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
elif isinstance(_resp, dict):
_temp_data = _resp.get(symbol, {})
if isinstance(_temp_data, str) or (
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
):
_show_logging_func()
else:
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
_show_logging_func()
except Exception as e:
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
logger.warning(f"{error_msg}:{e}")
_result = None
if self._interval == "1d":
_result = _get_simple(self._start_datetime, self._end_datetime)
elif self._interval == "1m":
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
_end_date = self._end_datetime.date()
if _start_date >= _end_date:
if self._next_datetime >= self._latest_datetime:
_result = _get_simple(self._start_datetime, self._end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None:
if _resp is not None and not _resp.empty:
_res.append(_resp)
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
for _s, _e in (
(self._start_datetime, self._next_datetime),
(self._latest_datetime, self._end_datetime),
):
_get_multi(_s, _e)
for _start in pd.date_range(_start_date, _end_date, closed="left"):
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
self._sleep()
_get_multi(_start, _end)
@@ -472,6 +497,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
show_1m_logging=False,
):
"""download data from Internet
@@ -491,6 +517,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1m_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
Examples
---------
# get daily data
@@ -510,6 +539,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
show_1m_logging=show_1m_logging,
).collector_data()
def normalize_data(self):
@@ -531,6 +561,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
show_1m_logging=False,
):
"""download -> normalize
@@ -550,6 +581,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
show_1m_logging: bool
show 1m logging, by default False; if True, there may be many warning logs
Examples
-------
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
@@ -562,6 +596,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
show_1m_logging=show_1m_logging,
)
self.normalize_data()