1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix YahooCollector getting 1min data occasionally missing

This commit is contained in:
zhupr
2021-06-05 16:01:01 +08:00
parent 6f150f3fd6
commit 554b9c7826
2 changed files with 36 additions and 41 deletions

View File

@@ -22,9 +22,9 @@ class BaseCollector(abc.ABC):
NORMAL_FLAG = "NORMAL"
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D
INTERVAL_1min = "1min"
INTERVAL_1d = "1d"
@@ -35,7 +35,7 @@ class BaseCollector(abc.ABC):
start=None,
end=None,
interval="1d",
max_workers=4,
max_workers=1,
max_collector_count=2,
delay=0,
check_data_length: bool = False,
@@ -48,7 +48,7 @@ class BaseCollector(abc.ABC):
save_dir: str
instrument save dir
max_workers: int
workers, default 4
workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
max_collector_count: int
default 2
delay: float
@@ -310,7 +310,7 @@ class Normalize:
class BaseRun(abc.ABC):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
"""
Parameters
@@ -320,7 +320,7 @@ class BaseRun(abc.ABC):
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
interval: str
freq, value from [1min, 1d], default 1d
"""

View File

@@ -25,6 +25,7 @@ CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
from data_collector.utils import (
deco_retry,
get_calendar_list,
get_hs_stock_symbols,
get_us_stock_symbols,
@@ -92,10 +93,6 @@ class YahooCollector(BaseCollector):
else:
raise ValueError(f"interval error: {self.interval}")
# using for 1min
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
@@ -140,40 +137,36 @@ class YahooCollector(BaseCollector):
def get_data(
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
@deco_retry(retry_sleep=1)
def _get_simple(start_, end_):
self.sleep()
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
return self.get_data_from_remote(
resp = self.get_data_from_remote(
symbol,
interval=_remote_interval,
start=start_,
end=end_,
)
if resp is None or resp.empty:
raise ValueError(f"get data error: {symbol}--{start_}--{end_}")
return resp
_result = None
if interval == self.INTERVAL_1d:
_result = _get_simple(start_datetime, end_datetime)
elif interval == self.INTERVAL_1min:
if self._next_datetime >= self._latest_datetime:
_result = _get_simple(start_datetime, end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
if _resp is not None and not _resp.empty:
_res.append(_resp)
for _s, _e in (
(self.start_datetime, self._next_datetime),
(self._latest_datetime, self.end_datetime),
):
_get_multi(_s, _e)
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
_get_multi(_start, _end)
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
_res = []
_start = self.start_datetime
while _start < self.end_datetime:
_tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime)
try:
_resp = _get_simple(_start, _tmp_end)
_res.append(_resp)
except ValueError as e:
pass
_start = _tmp_end
if _res:
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
else:
raise ValueError(f"cannot support {self.interval}")
return pd.DataFrame() if _result is None else _result
@@ -520,6 +513,10 @@ class YahooNormalize1min(YahooNormalize, ABC):
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
return data_1d
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
# TODO: using daily data factor
if df.empty:
@@ -529,9 +526,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d = YahooCollector.get_data_from_remote(
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
)
data_1d = self.get_1d_data(symbol, _start, _end)
if data_1d is None or data_1d.empty:
df["factor"] = 1
# TODO: np.nan or 1 or 0
@@ -579,21 +574,21 @@ class YahooNormalize1min(YahooNormalize, ABC):
def calc_paused_num(self, df: pd.DataFrame):
_symbol = df.iloc[0][self._symbol_field_name]
df = df.copy()
df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
all_nan_nums = 0
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
not_nan_nums = 0
for _date, _df in df.groupby(level="date"):
for _date, _df in df.groupby("_tmp_date"):
_df["paused"] = 0
if not _df.loc[_df["volume"] < 0].empty:
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
_df.loc[_df["volume"] < 0, "volume"] = np.nan
check_fields = set(_df.columns) - {
"date",
"_tmp_date",
"paused",
"factor",
self._date_field_name,
@@ -618,7 +613,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
logger.warning(f"data is empty: {_symbol}")
df = pd.DataFrame()
return df
del df["date"]
del df["_tmp_date"]
return df
@abc.abstractmethod
@@ -690,7 +685,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
Parameters
@@ -700,7 +695,7 @@ class Run(BaseRun):
normalize_dir: str
Directory for normalize data, default "Path(__file__).parent/normalize"
max_workers: int
Concurrent number, default is 4
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
interval: str
freq, value from [1min, 1d], default 1d
region: str