diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index cb51f9b22..d261f11cd 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -22,9 +22,9 @@ class BaseCollector(abc.ABC): NORMAL_FLAG = "NORMAL" DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01") - DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6)) - DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) - DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date() + DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date() + DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D INTERVAL_1min = "1min" INTERVAL_1d = "1d" @@ -35,7 +35,7 @@ class BaseCollector(abc.ABC): start=None, end=None, interval="1d", - max_workers=4, + max_workers=1, max_collector_count=2, delay=0, check_data_length: bool = False, @@ -48,7 +48,7 @@ class BaseCollector(abc.ABC): save_dir: str instrument save dir max_workers: int - workers, default 4 + Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 max_collector_count: int default 2 delay: float @@ -310,7 +310,7 @@ class Normalize: class BaseRun(abc.ABC): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): """ Parameters @@ -320,7 +320,7 @@ class BaseRun(abc.ABC): normalize_dir: str Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int - Concurrent number, default is 4 + Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 interval: str freq, value from [1min, 1d], default 1d """ 
diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index fcaa9ff92..a48c5f16a 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -25,6 +25,7 @@ CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize from data_collector.utils import ( + deco_retry, get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols, @@ -92,10 +93,6 @@ class YahooCollector(BaseCollector): else: raise ValueError(f"interval error: {self.interval}") - # using for 1min - self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone) - self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone) - self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone) self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone) @@ -140,40 +137,36 @@ class YahooCollector(BaseCollector): def get_data( self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp ) -> pd.DataFrame: + @deco_retry(retry_sleep=1) def _get_simple(start_, end_): self.sleep() _remote_interval = "1m" if interval == self.INTERVAL_1min else interval - return self.get_data_from_remote( + resp = self.get_data_from_remote( symbol, interval=_remote_interval, start=start_, end=end_, ) + if resp is None or resp.empty: + raise ValueError(f"get data error: {symbol}--{start_}--{end_}") + return resp _result = None if interval == self.INTERVAL_1d: _result = _get_simple(start_datetime, end_datetime) elif interval == self.INTERVAL_1min: - if self._next_datetime >= self._latest_datetime: - _result = _get_simple(start_datetime, end_datetime) - else: - _res = [] - - def _get_multi(start_, end_): - _resp = _get_simple(start_, end_) - if _resp is not None and not _resp.empty: - _res.append(_resp) - - 
for _s, _e in ( - (self.start_datetime, self._next_datetime), - (self._latest_datetime, self.end_datetime), - ): - _get_multi(_s, _e) - for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"): - _end = _start + pd.Timedelta(days=1) - _get_multi(_start, _end) - if _res: - _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + _res = [] + _start = self.start_datetime + while _start < self.end_datetime: + _tmp_end = min(_start + pd.Timedelta(days=7), self.end_datetime) + try: + _resp = _get_simple(_start, _tmp_end) + _res.append(_resp) + except ValueError as e: + pass + _start = _tmp_end + if _res: + _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) else: raise ValueError(f"cannot support {self.interval}") return pd.DataFrame() if _result is None else _result @@ -520,6 +513,10 @@ class YahooNormalize1min(YahooNormalize, ABC): calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE ) + def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) + return data_1d + def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor if df.empty: @@ -529,9 +526,7 @@ class YahooNormalize1min(YahooNormalize, ABC): # get 1d data from yahoo _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) - data_1d = YahooCollector.get_data_from_remote( - self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end - ) + data_1d = self.get_1d_data(symbol, _start, _end) if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 @@ -579,21 +574,21 @@ class YahooNormalize1min(YahooNormalize, ABC): def calc_paused_num(self, df: pd.DataFrame): _symbol = df.iloc[0][self._symbol_field_name] df = df.copy() 
- df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) # remove data that starts and ends with `np.nan` all day all_data = [] # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan all_nan_nums = 0 # Record the number of consecutive occurrences of trading days that are not nan throughout the day not_nan_nums = 0 - for _date, _df in df.groupby(level="date"): + for _date, _df in df.groupby("_tmp_date"): _df["paused"] = 0 if not _df.loc[_df["volume"] < 0].empty: logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") _df.loc[_df["volume"] < 0, "volume"] = np.nan check_fields = set(_df.columns) - { - "date", + "_tmp_date", "paused", "factor", self._date_field_name, @@ -618,7 +613,7 @@ class YahooNormalize1min(YahooNormalize, ABC): logger.warning(f"data is empty: {_symbol}") df = pd.DataFrame() return df - del df["date"] + del df["_tmp_date"] return df @abc.abstractmethod @@ -690,7 +685,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ Parameters @@ -700,7 +695,7 @@ class Run(BaseRun): normalize_dir: str Directory for normalize data, default "Path(__file__).parent/normalize" max_workers: int - Concurrent number, default is 4 + Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 interval: str freq, value from [1min, 1d], default 1d region: str