mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
263 lines
10 KiB
Python
263 lines
10 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import re
|
|
import sys
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import List, Iterable, Optional, Union
|
|
|
|
import fire
|
|
import pandas as pd
|
|
import baostock as bs
|
|
from loguru import logger
|
|
|
|
# Absolute directory that contains this file; also returned by Run.default_base_dir.
BASE_DIR = Path(__file__).resolve().parent
|
|
# Put the directory two levels above this file on sys.path; presumably this is
# what lets the `data_collector.*` imports that follow resolve when the file is
# run as a script -- must stay ahead of those imports.
sys.path.append(str(BASE_DIR.parent.parent))
|
|
|
|
from data_collector.base import BaseCollector, BaseRun, BaseNormalize
|
|
from data_collector.utils import get_hs_stock_symbols, get_calendar_list
|
|
|
|
|
|
class PitCollector(BaseCollector):
    """Collect point-in-time (PIT) financial indicators for CN stocks via baostock.

    For every symbol, four baostock queries are merged into one long-format
    frame with the columns ``date`` (publication date), ``period`` (statistics
    date), ``value`` and ``field``:

    * performance express report -> ``roeWa`` (reported value divided by 100)
    * quarterly profit data      -> ``roeWa``
    * forecast report            -> ``YOYNI`` (mean of the two forecast bounds, divided by 100)
    * quarterly growth data      -> ``YOYNI``
    """

    DEFAULT_START_DATETIME_QUARTERLY = pd.Timestamp("2000-01-01")
    DEFAULT_START_DATETIME_ANNUAL = pd.Timestamp("2000-01-01")
    # NOTE: "tomorrow" is computed once, when the class body is executed.
    DEFAULT_END_DATETIME_QUARTERLY = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))
    DEFAULT_END_DATETIME_ANNUAL = pd.Timestamp(datetime.now() + pd.Timedelta(days=1))

    INTERVAL_QUARTERLY = "quarterly"
    INTERVAL_ANNUAL = "annual"

    def __init__(
        self,
        save_dir: Union[str, Path],
        start: Optional[str] = None,
        end: Optional[str] = None,
        interval: str = "quarterly",
        max_workers: int = 1,
        max_collector_count: int = 1,
        delay: int = 0,
        check_data_length: bool = False,
        limit_nums: Optional[int] = None,
        symbol_regex: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        save_dir: str
            instrument save dir
        start: str
            start datetime, default None
        end: str
            end datetime, default None
        interval: str
            freq, default "quarterly"; ``get_data`` only supports "quarterly"
        max_workers: int
            concurrent number, default 1; when collecting data, it is recommended that max_workers be set to 1
        max_collector_count: int
            maximum number of fetches per symbol, default 1
        delay: float
            time.sleep(delay), default 0
        check_data_length: int
            check data length, default False; if greater than 0, each symbol will be considered complete if its data
            length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of
            fetches being (max_collector_count).
        limit_nums: int
            using for debug, by default None
        symbol_regex: str
            symbol regular expression, by default None.
        """
        # Must be set before the base initializer runs.
        # NOTE(review): presumably BaseCollector.__init__ calls get_instrument_list(),
        # which reads this attribute -- confirm against the base class.
        self.symbol_regex = symbol_regex
        super().__init__(
            save_dir=save_dir,
            start=start,
            end=end,
            interval=interval,
            max_workers=max_workers,
            max_collector_count=max_collector_count,
            delay=delay,
            check_data_length=check_data_length,
            limit_nums=limit_nums,
        )

    def get_instrument_list(self) -> List[str]:
        """Return the CN stock symbols to collect, filtered by ``symbol_regex`` when it is set.

        The regex is applied with ``re.match``, i.e. anchored at the start of each symbol.
        """
        logger.info("get cn stock symbols......")
        symbols = get_hs_stock_symbols()
        if self.symbol_regex is not None:
            pattern = re.compile(self.symbol_regex)
            symbols = [symbol for symbol in symbols if pattern.match(symbol)]
        logger.info(f"get {len(symbols)} symbols.")
        return symbols

    @staticmethod
    def _split_symbol(symbol: str) -> tuple:
        """Split ``"<number>.<suffix>"`` into ``(exchange, number)``.

        The ``ss`` suffix maps to ``sh``; every other suffix maps to ``sz``.
        """
        number, suffix = symbol.split(".")
        return ("sh" if suffix == "ss" else "sz"), number

    def normalize_symbol(self, symbol: str) -> str:
        """Convert ``"<number>.<suffix>"`` into ``"<exchange><number>"`` (e.g. ``600000.ss`` -> ``sh600000``)."""
        exchange, number = self._split_symbol(symbol)
        return f"{exchange}{number}"

    @staticmethod
    def _fetch_rows(resp) -> list:
        """Drain a baostock result set into a list of row lists."""
        rows = []
        while (resp.error_code == "0") and resp.next():
            rows.append(resp.get_row_data())
        return rows

    @staticmethod
    def _fetch_quarterly_rows(query_func, code: str, start_date: str, end_date: str) -> tuple:
        """Run a per-(year, quarter) baostock query and keep rows published inside the window.

        Parameters
        ----------
        query_func
            baostock query taking ``code``, ``year`` and ``quarter`` keyword arguments
        code: str
            baostock security code
        start_date: str
            inclusive lower bound of the publication date, ``%Y-%m-%d``
        end_date: str
            inclusive upper bound of the publication date, ``%Y-%m-%d``

        Returns
        -------
        tuple
            ``(rows, fields)``; ``fields`` are the column names of the first
            response that carried a ``pubDate`` column, or None if there was none.
        """
        window_start = datetime.strptime(start_date, "%Y-%m-%d")
        window_end = datetime.strptime(end_date, "%Y-%m-%d")
        rows = []
        fields = None
        # Quarter-major iteration keeps the row order of the previous implementation.
        # The year range starts one year early, presumably because a report for
        # year Y-1 can be published in year Y.
        for quarter in range(1, 5):
            for year in range(window_start.year - 1, window_end.year + 1):
                resp = query_func(code=code, year=year, quarter=quarter)
                if "pubDate" not in resp.fields:
                    continue
                if fields is None:
                    # Column names come from the real response instead of a
                    # separate hard-coded probe request.
                    fields = list(resp.fields)
                pub_date_idx = resp.fields.index("pubDate")
                while (resp.error_code == "0") and resp.next():
                    row_data = resp.get_row_data()
                    if not row_data:
                        continue
                    pub_date = pd.Timestamp(row_data[pub_date_idx])
                    if window_start <= pub_date <= window_end:
                        rows.append(row_data)
        return rows, fields

    @staticmethod
    def _to_pit_frame(raw_df: pd.DataFrame, column_mapping: dict, field_name: str, divisor=None) -> pd.DataFrame:
        """Reduce a raw baostock frame to the ``date`` / ``period`` / ``value`` / ``field`` layout.

        Returns an empty frame when any source column of ``column_mapping`` is missing.
        """
        try:
            selected = raw_df[list(column_mapping)]
        except KeyError:
            return pd.DataFrame()
        # rename() returns a new frame, so the assignments below do not write into a slice of raw_df.
        pit_df = selected.rename(columns=column_mapping)
        pit_df["field"] = field_name
        # "coerce" turns unparseable cells into NaN; the deprecated "ignore" left
        # the whole column as strings and broke the arithmetic that follows.
        values = pd.to_numeric(pit_df["value"], errors="coerce")
        pit_df["value"] = values if divisor is None else values / divisor
        return pit_df

    @staticmethod
    def get_performance_express_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch the performance express report of ``code`` as ``roeWa`` rows (value divided by 100)."""
        column_mapping = {
            "performanceExpPubDate": "date",
            "performanceExpStatDate": "period",
            "performanceExpressROEWa": "value",
        }

        resp = bs.query_performance_express_report(code=code, start_date=start_date, end_date=end_date)
        report_list = PitCollector._fetch_rows(resp)
        report_df = pd.DataFrame(report_list, columns=resp.fields)
        return PitCollector._to_pit_frame(report_df, column_mapping, "roeWa", divisor=100.0)

    @staticmethod
    def get_profit_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch quarterly profit data of ``code`` published in the window as ``roeWa`` rows."""
        column_mapping = {"pubDate": "date", "statDate": "period", "roeAvg": "value"}
        profit_list, fields = PitCollector._fetch_quarterly_rows(bs.query_profit_data, code, start_date, end_date)
        profit_df = pd.DataFrame(profit_list, columns=fields)
        return PitCollector._to_pit_frame(profit_df, column_mapping, "roeWa")

    @staticmethod
    def get_forecast_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch the forecast report of ``code`` as ``YOYNI`` rows.

        The value is ``(profitForcastChgPctUp + profitForcastChgPctDwn) / 200``.
        """
        column_mapping = {
            "profitForcastExpPubDate": "date",
            "profitForcastExpStatDate": "period",
            "value": "value",
        }
        resp = bs.query_forecast_report(code=code, start_date=start_date, end_date=end_date)
        forecast_list = PitCollector._fetch_rows(resp)
        forecast_df = pd.DataFrame(forecast_list, columns=resp.fields)
        numeric_fields = ["profitForcastChgPctUp", "profitForcastChgPctDwn"]
        try:
            bounds = forecast_df[numeric_fields].apply(pd.to_numeric, errors="coerce")
        except KeyError:
            return pd.DataFrame()
        forecast_df["value"] = (bounds["profitForcastChgPctUp"] + bounds["profitForcastChgPctDwn"]) / 200
        return PitCollector._to_pit_frame(forecast_df, column_mapping, "YOYNI")

    @staticmethod
    def get_growth_df(code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Fetch quarterly growth data of ``code`` published in the window as ``YOYNI`` rows."""
        column_mapping = {"pubDate": "date", "statDate": "period", "YOYNI": "value"}
        growth_list, fields = PitCollector._fetch_quarterly_rows(bs.query_growth_data, code, start_date, end_date)
        growth_df = pd.DataFrame(growth_list, columns=fields)
        return PitCollector._to_pit_frame(growth_df, column_mapping, "YOYNI")

    def get_data(
        self,
        symbol: str,
        interval: str,
        start_datetime: pd.Timestamp,
        end_datetime: pd.Timestamp,
    ) -> pd.DataFrame:
        """Collect all PIT rows of one symbol between ``start_datetime`` and ``end_datetime``.

        Parameters
        ----------
        symbol: str
            symbol in ``"<number>.<suffix>"`` form
        interval: str
            must be ``INTERVAL_QUARTERLY``
        start_datetime: pd.Timestamp
            start of the publication-date window
        end_datetime: pd.Timestamp
            end of the publication-date window

        Returns
        -------
        pd.DataFrame
            row-wise concatenation of the four per-source frames

        Raises
        ------
        ValueError
            if ``interval`` is not ``INTERVAL_QUARTERLY``
        """
        if interval != self.INTERVAL_QUARTERLY:
            raise ValueError(f"cannot support {interval}")
        exchange, number = self._split_symbol(symbol)
        code = f"{exchange}.{number}"
        start_date = start_datetime.strftime("%Y-%m-%d")
        end_date = end_datetime.strftime("%Y-%m-%d")

        performance_express_report_df = self.get_performance_express_report_df(code, start_date, end_date)
        profit_df = self.get_profit_df(code, start_date, end_date)
        forecast_report_df = self.get_forecast_report_df(code, start_date, end_date)
        growth_df = self.get_growth_df(code, start_date, end_date)

        df = pd.concat(
            [performance_express_report_df, profit_df, forecast_report_df, growth_df],
            axis=0,
        )
        return df
|
|
|
|
|
|
class PitNormalize(BaseNormalize):
    """Normalize long-format PIT frames that carry ``date`` and ``period`` columns."""

    def __init__(self, interval: str = "quarterly", *args, **kwargs):
        """Forward ``args``/``kwargs`` to the base normalizer and remember ``interval``."""
        super().__init__(*args, **kwargs)
        self.interval = interval

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fill missing publication dates and encode ``period`` as an integer.

        * A missing ``date`` is replaced by ``period`` plus 45 days for the
          quarterly interval, or plus 90 days for any other interval.
        * ``period`` becomes ``YYYY`` for the annual interval and ``YYYYQQ``
          (year * 100 + quarter number) otherwise.

        The frame is modified in place and also returned.
        """
        annual = self.interval == PitCollector.INTERVAL_ANNUAL
        lag_days = 45 if self.interval == PitCollector.INTERVAL_QUARTERLY else 90
        publish_lag = pd.DateOffset(days=lag_days)

        def estimated_publish_date(period):
            # Fallback publication date derived from the statistics period.
            return (pd.to_datetime(period) + publish_lag).date()

        def encode_period(stamp):
            if annual:
                return stamp.year
            return stamp.year * 100 + (stamp.month - 1) // 3 + 1

        fallback_dates = df["period"].apply(estimated_publish_date)
        df["date"] = df["date"].fillna(fallback_dates.astype(str))

        df["period"] = pd.to_datetime(df["period"])
        df["period"] = df["period"].apply(encode_period)
        return df

    def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
        """Return the calendar provided by ``get_calendar_list``."""
        return get_calendar_list()
|
|
|
|
|
|
class Run(BaseRun):
    """Command-line facade for the PIT collector; exposed through ``fire.Fire(Run)``.

    The properties below name the collector and normalizer classes defined in
    this module -- presumably BaseRun looks them up by name (confirm against
    the base class).
    """

    @property
    def collector_class_name(self) -> str:
        """Name of the collector class to use."""
        return "PitCollector"

    @property
    def normalize_class_name(self) -> str:
        """Name of the normalizer class to use."""
        return "PitNormalize"

    @property
    def default_base_dir(self) -> Union[Path, str]:
        """Default base directory: the directory containing this module."""
        return BASE_DIR
|
|
|
|
|
|
if __name__ == "__main__":
    bs.login()
    try:
        fire.Fire(Run)
    finally:
        # Always close the baostock session, even when the command raises or
        # fire terminates the process through SystemExit.
        bs.logout()
|