diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index aed25834b..e5970c256 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -345,7 +345,7 @@ class CSI100(CSIIndex): class CSI500(CSIIndex): @property - def index_code(self): + def index_code(self) -> str: return "000905" @property @@ -353,22 +353,41 @@ class CSI500(CSIIndex): return pd.Timestamp("2007-01-15") @property - def html_table_index(self): + def html_table_index(self) -> int: return 0 - def get_changes(self): + def get_changes(self) -> pd.DataFrame: + """get companies changes + + Return + -------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + """ return self.get_changes_with_history_companies(self.get_history_companies()) - def get_history_companies(self): + def get_history_companies(self) -> pd.DataFrame: """ - Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1 - Avoid a large number of parallel data acquisition, - such as 1000 times of concurrent data acquisition, because IP will be blocked + Returns ------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] """ - lg = bs.login() + bs.login() today = pd.datetime.now() date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date ret_list = [] @@ -380,10 +399,64 @@ class CSI500(CSIIndex): zz500_stocks.append(rs.get_row_data()) result = pd.DataFrame(zz500_stocks, columns=col) result["symbol"] = result["symbol"].apply(lambda x: x.replace(".", "").upper()) + result = self.get_data_from_baostock(date) ret_list.append(result[["date", "symbol"]]) bs.logout() return pd.concat(ret_list, sort=False) + def 
def get_data_from_baostock(self, date) -> pd.DataFrame:
    """Fetch the CSI500 constituent list that baostock reports for ``date``.

    Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
    Avoid a large number of parallel requests (such as 1000 concurrent
    queries), because the IP will be blocked.

    This method only issues the query; callers are expected to wrap it in
    ``bs.login()`` / ``bs.logout()``.

    Parameters
    ----------
    date:
        Query date. It is passed to baostock as ``str(date)``; the callers in
        this class pass ``datetime.date`` objects.

    Returns
    -------
    pd.DataFrame:
              date    symbol code_name
        2007-01-15  SH600039      四川路桥
        2020-01-15  SH600051      宁波联合

        columns:
            date: value as returned by baostock (presumably a date string;
                callers convert it with ``pd.to_datetime`` -- TODO confirm)
            symbol: str, the baostock code with "." removed and upper-cased
            code_name: str

        The frame is empty when the query fails or yields no rows.
    """
    columns = ["date", "symbol", "code_name"]
    rs = bs.query_zz500_stocks(date=str(date))
    rows = []
    # `and` short-circuits, so the cursor is never advanced once baostock has
    # reported an error (the bitwise `&` would always call rs.next()).
    while rs.error_code == "0" and rs.next():
        rows.append(rs.get_row_data())
    if rs.error_code != "0":
        # Keep the best-effort behaviour (return what was collected), but make
        # the failure visible instead of silently returning a short frame.
        logger.warning(
            f"baostock query_zz500_stocks failed for {date}: "
            f"error_code={rs.error_code}, error_msg={getattr(rs, 'error_msg', '')}"
        )
    result = pd.DataFrame(rows, columns=columns)
    result["symbol"] = result["symbol"].str.replace(".", "", regex=False).str.upper()
    return result


def get_new_companies(self) -> pd.DataFrame:
    """Build the instrument table for the constituents baostock reports today.

    Logs in to baostock, queries the constituent list for today's date and
    logs out again (also when the query raises).

    Returns
    -------
    pd.DataFrame:

          end_date    symbol  start_date
        2020-01-15  SH600000  2007-01-15

        The column labels are ``self.END_DATE_FIELD``,
        ``self.SYMBOL_FIELD_NAME`` and ``self.START_DATE_FIELD``, in that
        order (the header above is only illustrative).

        dtypes:
            end date: pd.Timestamp, parsed from the ``date`` column returned
                by ``get_data_from_baostock``
            symbol: str
            start date: ``self.bench_start_date`` for every row
    """
    logger.info("get new companies......")
    today = datetime.date.today()
    bs.login()
    try:
        result = self.get_data_from_baostock(today)
    finally:
        # Always release the baostock session, even if the query fails.
        bs.logout()
    # rename() returns a new frame, so the assignments below never write into
    # a slice of `result` (no SettingWithCopyWarning / lost writes), and the
    # mapping is explicit instead of relying on column position.
    df = result[["date", "symbol"]].rename(
        columns={"date": self.END_DATE_FIELD, "symbol": self.SYMBOL_FIELD_NAME}
    )
    df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
    df[self.START_DATE_FIELD] = self.bench_start_date
    logger.info("end of get new companies.")
    return df