mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix YahooCollector getting 1min data occasionally missing
This commit is contained in:
@@ -22,9 +22,9 @@ class BaseCollector(abc.ABC):
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date()
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date()
|
||||
DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
@@ -35,7 +35,7 @@ class BaseCollector(abc.ABC):
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_workers=1,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
@@ -48,7 +48,7 @@ class BaseCollector(abc.ABC):
|
||||
save_dir: str
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
workers, default 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
@@ -310,7 +310,7 @@ class Normalize:
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -320,7 +320,7 @@ class BaseRun(abc.ABC):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
|
||||
@@ -25,6 +25,7 @@ CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
|
||||
from data_collector.utils import (
|
||||
deco_retry,
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
@@ -92,10 +93,6 @@ class YahooCollector(BaseCollector):
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@@ -140,40 +137,36 @@ class YahooCollector(BaseCollector):
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
@deco_retry(retry_sleep=1)
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
resp = self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
if resp is None or resp.empty:
|
||||
raise ValueError(f"get data error: {symbol}--{start_}--{end_}")
|
||||
return resp
|
||||
|
||||
_result = None
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
_res = []
|
||||
_start = self.start_datetime
|
||||
while _start < self.end_datetime:
|
||||
_tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime)
|
||||
try:
|
||||
_resp = _get_simple(_start, _tmp_end)
|
||||
_res.append(_resp)
|
||||
except ValueError as e:
|
||||
pass
|
||||
_start = _tmp_end
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
@@ -520,6 +513,10 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
if df.empty:
|
||||
@@ -529,9 +526,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
data_1d = self.get_1d_data(symbol, _start, _end)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -579,21 +574,21 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
def calc_paused_num(self, df: pd.DataFrame):
|
||||
_symbol = df.iloc[0][self._symbol_field_name]
|
||||
df = df.copy()
|
||||
df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
# remove data that starts and ends with `np.nan` all day
|
||||
all_data = []
|
||||
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
|
||||
all_nan_nums = 0
|
||||
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
|
||||
not_nan_nums = 0
|
||||
for _date, _df in df.groupby(level="date"):
|
||||
for _date, _df in df.groupby("_tmp_date"):
|
||||
_df["paused"] = 0
|
||||
if not _df.loc[_df["volume"] < 0].empty:
|
||||
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
|
||||
_df.loc[_df["volume"] < 0, "volume"] = np.nan
|
||||
|
||||
check_fields = set(_df.columns) - {
|
||||
"date",
|
||||
"_tmp_date",
|
||||
"paused",
|
||||
"factor",
|
||||
self._date_field_name,
|
||||
@@ -618,7 +613,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
logger.warning(f"data is empty: {_symbol}")
|
||||
df = pd.DataFrame()
|
||||
return df
|
||||
del df["date"]
|
||||
del df["_tmp_date"]
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -690,7 +685,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -700,7 +695,7 @@ class Run(BaseRun):
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
|
||||
Reference in New Issue
Block a user