mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix YahooNormalizeCN1minOffline bugs
This commit is contained in:
@@ -28,10 +28,9 @@ def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name:
|
||||
_dates = pd.to_datetime(_result[date_field_name])
|
||||
|
||||
_tmp_min = _dates.min()
|
||||
min_date = min_date(min_date, _tmp_min) if min_date is not None else _tmp_min
|
||||
|
||||
min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min
|
||||
_tmp_max = _dates.max()
|
||||
max_date = min_date(max_date, _tmp_max) if max_date is not None else _tmp_max
|
||||
max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max
|
||||
p_bar.update()
|
||||
return min_date, max_date
|
||||
|
||||
@@ -81,7 +80,7 @@ def fill_1min_using_1d(
|
||||
tmp_df = pd.read_csv(list(data_1min_dir.glob("*.csv"))[0])
|
||||
columns = tmp_df.columns
|
||||
_si = tmp_df[symbol_field_name].first_valid_index()
|
||||
is_lower = tmp_df.loc[tmp_df][symbol_field_name].islower()
|
||||
is_lower = tmp_df.loc[_si][symbol_field_name].islower()
|
||||
for symbol in tqdm(miss_symbols):
|
||||
if is_lower:
|
||||
symbol = symbol.lower()
|
||||
@@ -89,8 +88,11 @@ def fill_1min_using_1d(
|
||||
index_1min = generate_minutes_calendar_from_daily(index_1d)
|
||||
index_1min.name = date_field_name
|
||||
_df = pd.DataFrame(columns=columns, index=index_1min)
|
||||
if date_field_name in _df.columns:
|
||||
del _df[date_field_name]
|
||||
_df.reset_index(inplace=True)
|
||||
_df[symbol_field_name] = symbol
|
||||
_df["paused_num"] = 0
|
||||
_df.to_csv(data_1min_dir.joinpath(f"{symbol}.csv"), index=False)
|
||||
|
||||
|
||||
|
||||
@@ -473,21 +473,6 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
CONSISTENT_1d = False
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class = getattr(importlib.import_module("collector"), _class_name) # type: Type[YahooNormalize]
|
||||
self.data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
calendar_list_1d = getattr(self, "_calendar_list_1d", None)
|
||||
@@ -512,7 +497,10 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
if not (data_1d is None or data_1d.empty):
|
||||
data_1d = self.data_1d_obj.normalize(data_1d)
|
||||
_class_name = self.__class__.__name__.replace("min", "d")
|
||||
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
|
||||
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
|
||||
data_1d = data_1d_obj.normalize(data_1d)
|
||||
return data_1d
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -525,6 +513,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -700,8 +689,8 @@ class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
|
||||
Reference in New Issue
Block a user