def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
    """get IN stock symbols

    The symbol list is built once and cached in the module-level
    ``_IN_SYMBOLS``; later calls return the cached list unchanged.

    Parameters
    ----------
    qlib_data_path
        optional qlib data directory. When given, the symbols found in
        ``instruments/nifty.txt`` below it are merged into the result.
        Only consulted while the cache is still empty.

    Returns
    -------
        sorted, de-duplicated stock symbols
    """
    global _IN_SYMBOLS

    @deco_retry
    def _get_nifty():
        # Plain literal: the URL has no placeholders to interpolate.
        url = "https://www1.nseindia.com/content/equities/EQUITY_L.csv"
        df = pd.read_csv(url)
        df = df.rename(columns={"SYMBOL": "Symbol"})
        # Every listed symbol gets the ".NS" suffix appended.
        df["Symbol"] = df["Symbol"] + ".NS"
        return df["Symbol"].dropna().unique().tolist()

    if _IN_SYMBOLS is None:
        _all_symbols = _get_nifty()
        if qlib_data_path is not None:
            for _index in ["nifty"]:
                ins_df = pd.read_csv(
                    Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"),
                    sep="\t",
                    names=["symbol", "start_date", "end_date"],
                )
                _all_symbols += ins_df["symbol"].unique().tolist()

        # Symbols are stored as-is: rewriting "." to "-" (as a ticker
        # formatter would) must not be applied here, because it would
        # destroy the ".NS" suffix added above.
        _IN_SYMBOLS = sorted(set(_all_symbols))

    return _IN_SYMBOLS
class YahooCollectorIN(YahooCollector, ABC):
    """Yahoo collector base for the symbols returned by ``get_in_stock_symbols``."""

    @property
    def _timezone(self):
        """Timezone name used by this collector."""
        return "Asia/Kolkata"

    def get_instrument_list(self):
        """Return the symbol list from ``get_in_stock_symbols``, logging its size."""
        logger.info("get INDIA stock symbols......")
        in_symbols = get_in_stock_symbols()
        logger.info(f"get {len(in_symbols)} symbols.")
        return in_symbols

    def download_index_data(self):
        """Do nothing; this collector downloads no index data."""
        return None

    def normalize_symbol(self, symbol):
        """Return the upper-cased result of ``code_to_fname(symbol)``."""
        file_name = code_to_fname(symbol)
        return file_name.upper()


class YahooCollectorIN1d(YahooCollectorIN):
    """Concrete subclass of ``YahooCollectorIN``; adds no behavior of its own."""