From 9d8a8c6f136a0773afb4f50b9e54be4ffb154a87 Mon Sep 17 00:00:00 2001 From: Hyeongmin Moon Date: Mon, 5 Dec 2022 15:50:28 +0900 Subject: [PATCH] Resolve issues while running Automatic update of daily frequency data (from yahoo finance) for US region (#1358) * Update YahooNormalizeUS1dExtend(#1196) * Prevent pandas read_csv errors while running update_data_to_bin for US region * Fix parse_index error while running update_data_to_bin for US region * prevent pandas.read_csv error on specific symbol names * Reordering parameters for better rendering * removes prefix during feature_dir existence checking * add explanation comments --- qlib/utils/__init__.py | 2 +- scripts/data_collector/base.py | 15 ++++++++++++++- scripts/data_collector/yahoo/collector.py | 6 +++++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index ad4feffc0..3bfacc288 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -749,7 +749,7 @@ def exists_qlib_data(qlib_dir): return False # check instruments - code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir())) + code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir())) _instrument = instruments_dir.joinpath("all.txt") miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names) if miss_code and any(map(lambda x: "sht" not in x, miss_code)): diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index b06a3e292..e3cf1fcac 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -289,7 +289,20 @@ class Normalize: def _executor(self, file_path: Path): file_path = Path(file_path) - df = pd.read_csv(file_path) + + # Some symbol_field values, such as TRUE and NA, are decoded as True (bool) and NaN (float) by pandas' default CSV parsing. + # Manually define the dtype and na_values of the symbol_field so that such symbols are kept as strings.
+ default_na = pd._libs.parsers.STR_NA_VALUES + symbol_na = default_na.copy() + symbol_na.remove("NA") + columns = pd.read_csv(file_path, nrows=0).columns + df = pd.read_csv( + file_path, + dtype={self._symbol_field_name: str}, + keep_default_na=False, + na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns}, + ) + df = self._normalize_obj.normalize(df) if df is not None and not df.empty: if self._end_date is not None: diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 61f801c95..95a286b45 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -817,6 +817,10 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d): pass +class YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend): + pass + + class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline): CALC_PAUSED_NUM = False @@ -1196,7 +1200,7 @@ class Run(BaseRun): importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" ) for _index in index_list: - get_instruments(str(qlib_data_1d_dir), _index) + get_instruments(str(qlib_data_1d_dir), _index, market_index=f"{_region}_index") if __name__ == "__main__":