diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py index 4abca3361..0a721298d 100644 --- a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py +++ b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py @@ -28,10 +28,9 @@ def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: _dates = pd.to_datetime(_result[date_field_name]) _tmp_min = _dates.min() - min_date = min_date(min_date, _tmp_min) if min_date is not None else _tmp_min - + min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min _tmp_max = _dates.max() - max_date = min_date(max_date, _tmp_max) if max_date is not None else _tmp_max + max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max p_bar.update() return min_date, max_date @@ -81,7 +80,7 @@ def fill_1min_using_1d( tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0]) columns = tmp_df.columns _si = tmp_df[symbol_field_name].first_valid_index() - is_lower = tmp_df.loc[tmp_df][symbol_field_name].islower() + is_lower = tmp_df.loc[_si][symbol_field_name].islower() for symbol in tqdm(miss_symbols): if is_lower: symbol = symbol.lower() @@ -89,8 +88,11 @@ def fill_1min_using_1d( index_1min = generate_minutes_calendar_from_daily(index_1d) index_1min.name = date_field_name _df = pd.DataFrame(columns=columns, index=index_1min) + if date_field_name in _df.columns: + del _df[date_field_name] _df.reset_index(inplace=True) _df[symbol_field_name] = symbol + _df["paused_num"] = 0 _df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 16b0a32ba..58e1d3009 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -473,21 +473,6 @@ class YahooNormalize1min(YahooNormalize, ABC): CONSISTENT_1d = False CALC_PAUSED_NUM = False - def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): - """ - - Parameters - ---------- - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) - _class_name = self.__class__.__name__.replace("min", "d") - _class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize] - self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name) - @property def calendar_list_1d(self): calendar_list_1d = getattr(self, "_calendar_list_1d", None) @@ -512,7 +497,10 @@ class YahooNormalize1min(YahooNormalize, ABC): """ data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) if not (data_1d is None or data_1d.empty): - data_1d = self.data_1d_obj.normalize(data_1d) + _class_name = self.__class__.__name__.replace("min", "d") + _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name) + data_1d_obj = _class(self._date_field_name, self._symbol_field_name) + data_1d = data_1d_obj.normalize(data_1d) return data_1d def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: @@ -525,6 +513,7 @@ class YahooNormalize1min(YahooNormalize, ABC): _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) + data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 # TODO: np.nan or 1 or 0 @@ -700,8 +689,8 @@ class YahooNormalizeCN1minOffline(YahooNormalizeCN1min): symbol_field_name: str symbol field name, default is symbol """ - super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) self.qlib_data_1d_dir = qlib_data_1d_dir + super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: