mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Fix yahoo collector
This commit is contained in:
@@ -44,6 +44,7 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1m_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -67,10 +68,13 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._delay = delay
|
||||
self._show_1m_logging = show_1m_logging
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
@@ -83,7 +87,7 @@ class YahooCollector:
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
|
||||
self._end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == "1m":
|
||||
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == "1d":
|
||||
@@ -91,8 +95,12 @@ class YahooCollector:
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1m
|
||||
self._next_datetime = self.convert_datetime(self._start_datetime.date() + pd.Timedelta(days=1))
|
||||
self._latest_datetime = self.convert_datetime(self._end_datetime.date())
|
||||
|
||||
self._start_datetime = self.convert_datetime(self._start_datetime)
|
||||
self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
|
||||
self._end_datetime = self.convert_datetime(self._end_datetime)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -100,20 +108,24 @@ class YahooCollector:
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewirte min_numbers_trading")
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewirte get_stock_list")
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
|
||||
@property
|
||||
@abc.abstractclassmethod
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def convert_datetime(self, dt: pd.Timestamp):
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
return pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
def convert_datetime(self, dt: [pd.Timestamp, datetime.date, str]):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
@@ -136,7 +148,7 @@ class YahooCollector:
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
with stock_path.open("a") as fp:
|
||||
df.to_csv(fp, index=False, header=None)
|
||||
df.to_csv(fp, index=False, header=False)
|
||||
else:
|
||||
with stock_path.open("w") as fp:
|
||||
df.to_csv(fp, index=False)
|
||||
@@ -155,34 +167,47 @@ class YahooCollector:
|
||||
def _get_from_remote(self, symbol):
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
error_msg = f"{symbol}-{self._interval}-{start_}-{end_}"
|
||||
|
||||
def _show_logging_func():
|
||||
if self._interval == "1m" and self._show_1m_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
_result = None
|
||||
if self._interval == "1d":
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
elif self._interval == "1m":
|
||||
_start_date = self._start_datetime.date() + pd.Timedelta(days=1)
|
||||
_end_date = self._end_datetime.date()
|
||||
if _start_date >= _end_date:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self._start_datetime, self._end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None:
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
|
||||
for _s, _e in (
|
||||
(self._start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self._end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(_start_date, _end_date, closed="left"):
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
self._sleep()
|
||||
_get_multi(_start, _end)
|
||||
@@ -472,6 +497,7 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1m_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -491,6 +517,9 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
@@ -510,6 +539,7 @@ class Run:
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1m_logging=show_1m_logging,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self):
|
||||
@@ -531,6 +561,7 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1m_logging=False,
|
||||
):
|
||||
"""download -> normalize
|
||||
|
||||
@@ -550,6 +581,9 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1m_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
-------
|
||||
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
@@ -562,6 +596,7 @@ class Run:
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1m_logging=show_1m_logging,
|
||||
)
|
||||
self.normalize_data()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user