From 40dd84857c1c27152d2e960d48c3e90af8b72922 Mon Sep 17 00:00:00 2001 From: "Linlang Lv (iSoftStone)" Date: Mon, 28 Feb 2022 03:48:07 +0800 Subject: [PATCH] update-csi500 --- scripts/data_collector/cn_index/collector.py | 100 +++++++++---------- scripts/data_collector/utils.py | 93 +++++++++-------- 2 files changed, 96 insertions(+), 97 deletions(-) diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index 7c23b2394..aed25834b 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -5,7 +5,6 @@ import re import abc import sys import importlib -from tqdm import tqdm from io import BytesIO from typing import List, Iterable from pathlib import Path @@ -14,7 +13,7 @@ import fire import requests import pandas as pd import baostock as bs -from lxml import etree +from tqdm import tqdm from loguru import logger CUR_DIR = Path(__file__).resolve().parent @@ -47,7 +46,6 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None): class CSIIndex(IndexBase): - @property def calendar_list(self) -> List[pd.Timestamp]: """get history trading date @@ -74,20 +72,20 @@ class CSIIndex(IndexBase): @abc.abstractmethod def bench_start_date(self) -> pd.Timestamp: """ - Returns - ------- - index start date - """ + Returns + ------- + index start date + """ raise NotImplementedError("rewrite bench_start_date") @property @abc.abstractmethod def index_code(self) -> str: """ - Returns - ------- - index code - """ + Returns + ------- + index code + """ raise NotImplementedError("rewrite index_code") @property @@ -95,10 +93,10 @@ class CSIIndex(IndexBase): def html_table_index(self) -> int: """Which table of changes in html - CSI300: 0 - CSI100: 1 - :return: - """ + CSI300: 0 + CSI100: 1 + :return: + """ raise NotImplementedError() def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame: @@ -149,15 +147,15 @@ class CSIIndex(IndexBase): def normalize_symbol(symbol: str) -> str: """ - Parameters - ---------- - symbol: str - symbol + Parameters + ---------- + symbol: str + symbol - Returns - ------- - symbol - """ + Returns + ------- + symbol + """ symbol = f"{int(symbol):06}" return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}" @@ -214,10 +212,10 @@ class CSIIndex(IndexBase): def _read_change_from_url(self, url: str) -> pd.DataFrame: """read change from url - Parameters - ---------- - url : str - change url + Parameters + ---------- + url : str + change url Returns ------- @@ -288,12 +286,12 @@ class CSIIndex(IndexBase): def get_new_companies(self) -> pd.DataFrame: """ - Returns - ------- - pd.DataFrame: + Returns + ------- + pd.DataFrame: - symbol start_date end_date - SH600000 2000-01-01 2099-12-31 + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 dtypes: symbol: str @@ -318,7 +316,6 @@ class CSIIndex(IndexBase): class CSI300(CSIIndex): - @property def index_code(self): return "000300" @@ -329,11 +326,10 @@ class CSI300(CSIIndex): @property def html_table_index(self): - return 0 + return 1 class CSI100(CSIIndex): - @property def index_code(self): return "000903" @@ -344,11 +340,10 @@ class CSI100(CSIIndex): @property def html_table_index(self): - return 1 + return 2 class CSI500(CSIIndex): - @property def index_code(self): return "000905" @@ -366,13 +361,13 @@ class CSI500(CSIIndex): def get_history_companies(self): """ - Data source:http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1 - Avoid a large number of parallel data acquisition, - such as 1000 times of concurrent data acquisition, because IP will be blocked - Returns - ------- + Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1 + Avoid a large number of parallel data acquisition, + such as 1000 times of concurrent data acquisition, because IP will be blocked + Returns + ------- - """ + """ lg = bs.login() today = pd.datetime.now() date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date @@ -391,7 +386,12 @@ class CSI500(CSIIndex): def get_instruments( - qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 + qlib_dir: str, + index_name: str, + method: str = "parse_instruments", + freq: str = "day", + request_retry: int = 5, + retry_sleep: int = 3, ): """ @@ -410,13 +410,13 @@ def get_instruments( retry_sleep: int request sleep, by default 3 - Examples - ------- - # parse instruments - $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + Examples + ------- + # parse instruments + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments - # parse new companies - $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + # parse new companies + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies """ _cur_module = importlib.import_module("data_collector.cn_index.collector") diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index eb3c8da17..d522994d6 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -34,6 +34,7 @@ CALENDAR_BENCH_URL_MAP = { "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks "US_ALL": "^GSPC", + "IN_ALL": "^NSEI", } _BENCH_CALENDAR_LIST = None @@ -52,15 +53,15 @@ MINIMUM_SYMBOLS_NUM = 3900 def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: """get SH/SZ history calendar list - Parameters - ---------- - bench_code: str - value from ["CSI300", "CSI500", "ALL", "US_ALL"] + Parameters + ---------- + bench_code: str + value from ["CSI300", "CSI500", "ALL", "US_ALL"] - Returns - ------- - history calendar list - """ + Returns + ------- + history calendar list + """ logger.info(f"get calendar list: {bench_code}......") @@ -177,10 +178,10 @@ def get_calendar_list_by_ratio( def get_hs_stock_symbols() -> list: """get SH/SZ stock symbols - Returns - ------- - stock symbols - """ + Returns + ------- + stock symbols + """ global _HS_SYMBOLS def _get_symbol(): @@ -221,10 +222,10 @@ def get_hs_stock_symbols() -> list: def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: """get US stock symbols - Returns - ------- - stock symbols - """ + Returns + ------- + stock symbols + """ global _US_SYMBOLS @deco_retry @@ -416,16 +417,16 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list: def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix - Parameters - ---------- - symbol: str - symbol - capital : bool - by default True - Returns - ------- + Parameters + ---------- + symbol: str + symbol + capital : bool + by default True + Returns + ------- - """ + """ code, exchange = symbol.split(".") if exchange.lower() in ["sh", "ss"]: res = f"sh{code}" @@ -437,24 +438,22 @@ def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str: """symbol prefix to sufix - Parameters - ---------- - symbol: str - symbol - capital : bool - by default True - Returns - ------- + Parameters + ---------- + symbol: str + symbol + capital : bool + by default True + Returns + ------- - """ + """ res = f"{symbol[:-2]}.{symbol[-2:]}" return res.upper() if capital else res.lower() def deco_retry(retry: int = 5, retry_sleep: int = 3): - def deco_func(func): - @functools.wraps(func) def wrapper(*args, **kwargs): _retry = 5 if callable(retry) else retry @@ -480,19 +479,19 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3): def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): """get trading date by shift - Parameters - ---------- - trading_list: list - trading calendar list - shift : int - shift, default is 1 + Parameters + ---------- + trading_list: list + trading calendar list + shift : int + shift, default is 1 - trading_date : pd.Timestamp - trading date - Returns - ------- + trading_date : pd.Timestamp + trading date + Returns + ------- - """ + """ trading_date = pd.Timestamp(trading_date) left_index = bisect.bisect_left(trading_list, trading_date) try: