mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update-csi500
This commit is contained in:
@@ -5,7 +5,6 @@ import re
|
||||
import abc
|
||||
import sys
|
||||
import importlib
|
||||
from tqdm import tqdm
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from pathlib import Path
|
||||
@@ -14,7 +13,7 @@ import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
import baostock as bs
|
||||
from lxml import etree
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
@@ -47,7 +46,6 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None):
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
@@ -74,20 +72,20 @@ class CSIIndex(IndexBase):
|
||||
@abc.abstractmethod
|
||||
def bench_start_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index start date
|
||||
"""
|
||||
raise NotImplementedError("rewrite bench_start_date")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def index_code(self) -> str:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
index code
|
||||
"""
|
||||
raise NotImplementedError("rewrite index_code")
|
||||
|
||||
@property
|
||||
@@ -95,10 +93,10 @@ class CSIIndex(IndexBase):
|
||||
def html_table_index(self) -> int:
|
||||
"""Which table of changes in html
|
||||
|
||||
CSI300: 0
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
CSI300: 0
|
||||
CSI100: 1
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
@@ -149,15 +147,15 @@ class CSIIndex(IndexBase):
|
||||
def normalize_symbol(symbol: str) -> str:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
|
||||
Returns
|
||||
-------
|
||||
symbol
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
symbol
|
||||
"""
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
@@ -214,10 +212,10 @@ class CSIIndex(IndexBase):
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
change url
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -288,12 +286,12 @@ class CSIIndex(IndexBase):
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
symbol start_date end_date
|
||||
SH600000 2000-01-01 2099-12-31
|
||||
|
||||
dtypes:
|
||||
symbol: str
|
||||
@@ -318,7 +316,6 @@ class CSIIndex(IndexBase):
|
||||
|
||||
|
||||
class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000300"
|
||||
@@ -329,11 +326,10 @@ class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000903"
|
||||
@@ -344,11 +340,10 @@ class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 1
|
||||
return 2
|
||||
|
||||
|
||||
class CSI500(CSIIndex):
|
||||
|
||||
@property
|
||||
def index_code(self):
|
||||
return "000905"
|
||||
@@ -366,13 +361,13 @@ class CSI500(CSIIndex):
|
||||
|
||||
def get_history_companies(self):
|
||||
"""
|
||||
Data source:http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
|
||||
Avoid a large number of parallel data acquisition,
|
||||
such as 1000 times of concurrent data acquisition, because IP will be blocked
|
||||
Returns
|
||||
-------
|
||||
Data source: http://baostock.com/baostock/index.php/%E4%B8%AD%E8%AF%81500%E6%88%90%E5%88%86%E8%82%A1
|
||||
Avoid issuing a large number of parallel data requests,
|
||||
e.g. 1000 concurrent requests, because the IP address will be blocked
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
lg = bs.login()
|
||||
today = pd.datetime.now()
|
||||
date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date
|
||||
@@ -391,7 +386,12 @@ class CSI500(CSIIndex):
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
method: str = "parse_instruments",
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -410,13 +410,13 @@ def get_instruments(
|
||||
retry_sleep: int
|
||||
request sleep, by default 3
|
||||
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
# parse new companies
|
||||
$ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("data_collector.cn_index.collector")
|
||||
|
||||
@@ -34,6 +34,7 @@ CALENDAR_BENCH_URL_MAP = {
|
||||
"ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"),
|
||||
# NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks
|
||||
"US_ALL": "^GSPC",
|
||||
"IN_ALL": "^NSEI",
|
||||
}
|
||||
|
||||
_BENCH_CALENDAR_LIST = None
|
||||
@@ -52,15 +53,15 @@ MINIMUM_SYMBOLS_NUM = 3900
|
||||
def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
|
||||
"""get SH/SZ history calendar list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
Parameters
|
||||
----------
|
||||
bench_code: str
|
||||
value from ["CSI300", "CSI500", "ALL", "US_ALL"]
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
|
||||
logger.info(f"get calendar list: {bench_code}......")
|
||||
|
||||
@@ -177,10 +178,10 @@ def get_calendar_list_by_ratio(
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _HS_SYMBOLS
|
||||
|
||||
def _get_symbol():
|
||||
@@ -221,10 +222,10 @@ def get_hs_stock_symbols() -> list:
|
||||
def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get US stock symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
stock symbols
|
||||
"""
|
||||
global _US_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
@@ -416,16 +417,16 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
code, exchange = symbol.split(".")
|
||||
if exchange.lower() in ["sh", "ss"]:
|
||||
res = f"sh{code}"
|
||||
@@ -437,24 +438,22 @@ def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol prefix to suffix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
symbol
|
||||
capital : bool
|
||||
by default True
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
res = f"{symbol[:-2]}.{symbol[-2:]}"
|
||||
return res.upper() if capital else res.lower()
|
||||
|
||||
|
||||
def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
|
||||
def deco_func(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
_retry = 5 if callable(retry) else retry
|
||||
@@ -480,19 +479,19 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
|
||||
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
|
||||
"""get trading date by shift
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
Parameters
|
||||
----------
|
||||
trading_list: list
|
||||
trading calendar list
|
||||
shift : int
|
||||
shift, default is 1
|
||||
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
trading_date : pd.Timestamp
|
||||
trading date
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
"""
|
||||
trading_date = pd.Timestamp(trading_date)
|
||||
left_index = bisect.bisect_left(trading_list, trading_date)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user