1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

cleaned with black

This commit is contained in:
Gaurav
2021-07-15 11:24:41 +05:30
parent 457dcaa466
commit cfcd9fb1f8
2 changed files with 57 additions and 215 deletions

View File

@@ -69,15 +69,8 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
if bench_code.startswith("US_"):
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(
interval="1d", period="max"
)
calendar = (
df.index.get_level_values(level="date")
.map(pd.Timestamp)
.unique()
.tolist()
)
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
else:
if bench_code.upper() == "ALL":
@@ -85,9 +78,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
def _get_calendar(month):
_cal = []
try:
resp = requests.get(
SZSE_CALENDAR_URL.format(month=month, random=random.random)
).json()
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json()
for _r in resp["data"]:
if int(_r["jybz"]):
_cal.append(pd.Timestamp(_r["jyrq"]))
@@ -95,11 +86,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
raise ValueError(f"{month}-->{e}")
return _cal
month_range = pd.date_range(
start="2000-01",
end=pd.Timestamp.now() + pd.Timedelta(days=31),
freq="M",
)
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
calendar = []
for _m in month_range:
cal = _get_calendar(_m.strftime("%Y-%m"))
@@ -169,9 +156,7 @@ def get_calendar_list_by_ratio(
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {
date: _number_all_funds for date in _dict_count_trade.keys()
} # dict{date:count}
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
for oldest_date in all_oldest_list:
for date in _dict_count_founding.keys():
@@ -181,8 +166,7 @@ def get_calendar_list_by_ratio(
calendar = [
date
for date in _dict_count_trade
if _dict_count_trade[date]
>= max(int(_dict_count_founding[date] * threshold), minimum_count)
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
]
return calendar
@@ -204,9 +188,7 @@ def get_hs_stock_symbols() -> list:
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath(
"//div[@class='result']/ul//li/a/text()"
),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
time.sleep(3)
@@ -250,10 +232,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [
_v["f12"].replace("_", "-P")
for _v in resp.json()["data"]["diff"].values()
]
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
@@ -315,14 +294,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
s_ = s_.strip("*")
return s_
_US_SYMBOLS = sorted(
set(
map(
_format,
filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols),
)
)
)
_US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
return _US_SYMBOLS
@@ -385,10 +357,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
raise ValueError("request error")
try:
_symbols = []
for sub_data in re.findall(
r"[\[](.*?)[\]]",
resp.content.decode().split("= [")[-1].replace("];", ""),
):
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
data = sub_data.replace('"', "").replace("'", "")
# TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
_symbols.append(data.split(",")[0])
@@ -467,9 +436,7 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
return deco_func(retry) if callable(retry) else deco_func
def get_trading_date_by_shift(
trading_list: list, trading_date: pd.Timestamp, shift: int = 1
):
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
"""get trading date by shift
Parameters