1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

cleaned with black

This commit is contained in:
Gaurav
2021-07-15 11:24:41 +05:30
parent 457dcaa466
commit cfcd9fb1f8
2 changed files with 57 additions and 215 deletions

View File

@@ -69,15 +69,8 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
calendar = _CALENDAR_MAP.get(bench_code, None)
if calendar is None:
if bench_code.startswith("US_"):
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(
interval="1d", period="max"
)
calendar = (
df.index.get_level_values(level="date")
.map(pd.Timestamp)
.unique()
.tolist()
)
df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")
calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist()
else:
if bench_code.upper() == "ALL":
@@ -85,9 +78,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
def _get_calendar(month):
_cal = []
try:
resp = requests.get(
SZSE_CALENDAR_URL.format(month=month, random=random.random)
).json()
resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json()
for _r in resp["data"]:
if int(_r["jybz"]):
_cal.append(pd.Timestamp(_r["jyrq"]))
@@ -95,11 +86,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]:
raise ValueError(f"{month}-->{e}")
return _cal
month_range = pd.date_range(
start="2000-01",
end=pd.Timestamp.now() + pd.Timedelta(days=31),
freq="M",
)
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M")
calendar = []
for _m in month_range:
cal = _get_calendar(_m.strftime("%Y-%m"))
@@ -169,9 +156,7 @@ def get_calendar_list_by_ratio(
p_bar.update()
logger.info(f"count how many funds have founded in this day......")
_dict_count_founding = {
date: _number_all_funds for date in _dict_count_trade.keys()
} # dict{date:count}
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
with tqdm(total=_number_all_funds) as p_bar:
for oldest_date in all_oldest_list:
for date in _dict_count_founding.keys():
@@ -181,8 +166,7 @@ def get_calendar_list_by_ratio(
calendar = [
date
for date in _dict_count_trade
if _dict_count_trade[date]
>= max(int(_dict_count_founding[date] * threshold), minimum_count)
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
]
return calendar
@@ -204,9 +188,7 @@ def get_hs_stock_symbols() -> list:
_res |= set(
map(
lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v),
etree.HTML(resp.text).xpath(
"//div[@class='result']/ul//li/a/text()"
),
etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"),
)
)
time.sleep(3)
@@ -250,10 +232,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
if resp.status_code != 200:
raise ValueError("request error")
try:
_symbols = [
_v["f12"].replace("_", "-P")
for _v in resp.json()["data"]["diff"].values()
]
_symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()]
except Exception as e:
logger.warning(f"request error: {e}")
raise
@@ -315,14 +294,7 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
s_ = s_.strip("*")
return s_
_US_SYMBOLS = sorted(
set(
map(
_format,
filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols),
)
)
)
_US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
return _US_SYMBOLS
@@ -385,10 +357,7 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
raise ValueError("request error")
try:
_symbols = []
for sub_data in re.findall(
r"[\[](.*?)[\]]",
resp.content.decode().split("= [")[-1].replace("];", ""),
):
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
data = sub_data.replace('"', "").replace("'", "")
# TODO: do we need other information, such as fund_name, from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']?
_symbols.append(data.split(",")[0])
@@ -467,9 +436,7 @@ def deco_retry(retry: int = 5, retry_sleep: int = 3):
return deco_func(retry) if callable(retry) else deco_func
def get_trading_date_by_shift(
trading_list: list, trading_date: pd.Timestamp, shift: int = 1
):
def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1):
"""get trading date by shift
Parameters

View File

@@ -93,9 +93,7 @@ class YahooCollector(BaseCollector):
def init_datetime(self):
if self.interval == self.INTERVAL_1min:
self.start_datetime = max(
self.start_datetime, self.DEFAULT_START_DATETIME_1MIN
)
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
elif self.interval == self.INTERVAL_1d:
pass
else:
@@ -119,9 +117,7 @@ class YahooCollector(BaseCollector):
raise NotImplementedError("rewrite get_timezone")
@staticmethod
def get_data_from_remote(
symbol, interval, start, end, show_1min_logging: bool = False
):
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
error_msg = f"{symbol}-{interval}-{start}-{end}"
def _show_logging_func():
@@ -130,16 +126,13 @@ class YahooCollector(BaseCollector):
interval = "1m" if interval in ["1m", "1min"] else interval
try:
_resp = Ticker(symbol, asynchronous=False).history(
interval=interval, start=start, end=end
)
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
elif isinstance(_resp, dict):
_temp_data = _resp.get(symbol, {})
if isinstance(_temp_data, str) or (
isinstance(_resp, dict)
and _temp_data.get("indicators", {}).get("quote", None) is None
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
):
_show_logging_func()
else:
@@ -148,11 +141,7 @@ class YahooCollector(BaseCollector):
logger.warning(f"{error_msg}:{e}")
def get_data(
self,
symbol: str,
interval: str,
start_datetime: pd.Timestamp,
end_datetime: pd.Timestamp,
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
) -> pd.DataFrame:
@deco_retry(retry_sleep=self.delay)
def _get_simple(start_, end_):
@@ -225,35 +214,21 @@ class YahooCollectorCN1d(YahooCollectorCN):
_format = "%Y%m%d"
_begin = self.start_datetime.strftime(_format)
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
for _index_name, _index_code in {
"csi300": "000300",
"csi100": "000903",
}.items():
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
logger.info(f"get bench data: {_index_name}({_index_code})......")
try:
df = pd.DataFrame(
map(
lambda x: x.split(","),
requests.get(
INDEX_BENCH_URL.format(
index_code=_index_code, begin=_begin, end=_end
)
).json()["data"]["klines"],
requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
"data"
]["klines"],
)
)
except Exception as e:
logger.warning(f"get {_index_name} error: {e}")
continue
df.columns = [
"date",
"open",
"close",
"high",
"low",
"volume",
"money",
"change",
]
df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
df["date"] = pd.to_datetime(df["date"])
df = df.astype(float, errors="ignore")
df["adjclose"] = df["close"]
@@ -351,18 +326,13 @@ class YahooNormalize(BaseNormalize):
df = df.reindex(
pd.DataFrame(index=calendar_list)
.loc[
pd.Timestamp(df.index.min())
.date() : pd.Timestamp(df.index.max())
.date()
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
+ pd.Timedelta(hours=23, minutes=59)
]
.index
)
df.sort_index(inplace=True)
df.loc[
(df["volume"] <= 0) | np.isnan(df["volume"]),
set(df.columns) - {symbol_field_name},
] = np.nan
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
_tmp_series = df["close"].fillna(method="ffill")
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None:
@@ -377,9 +347,7 @@ class YahooNormalize(BaseNormalize):
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
# normalize
df = self.normalize_yahoo(
df, self._calendar_list, self._date_field_name, self._symbol_field_name
)
df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
# adjusted price
df = self.adjusted_price(df)
return df
@@ -450,11 +418,7 @@ class YahooNormalize1d(YahooNormalize, ABC):
class YahooNormalize1dExtend(YahooNormalize1d):
def __init__(
self,
old_qlib_data_dir: [str, Path],
date_field_name: str = "date",
symbol_field_name: str = "symbol",
**kwargs,
self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
@@ -483,9 +447,7 @@ class YahooNormalize1dExtend(YahooNormalize1d):
return df
def _get_close(self, df: pd.DataFrame, field_name: str):
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][
self._symbol_field_name
].upper()
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_close = _df.loc[_df.last_valid_index()][field_name]
return _close
@@ -505,9 +467,7 @@ class YahooNormalize1dExtend(YahooNormalize1d):
return _close
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][
self._symbol_field_name
].upper()
_symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper()
try:
_df = self.old_qlib_data.loc(axis=0)[_symbol]
_date = _df.index.max()
@@ -535,11 +495,7 @@ class YahooNormalize1dExtend(YahooNormalize1d):
)
# normalize
df = self.normalize_yahoo(
df,
self._calendar_list,
self._date_field_name,
self._symbol_field_name,
last_close=_last_close,
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
)
# adjusted price
df = self.adjusted_price(df)
@@ -577,14 +533,10 @@ class YahooNormalize1min(YahooNormalize, ABC):
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
data_1d = YahooCollector.get_data_from_remote(
self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end
)
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
if not (data_1d is None or data_1d.empty):
_class_name = self.__class__.__name__.replace("min", "d")
_class: type(YahooNormalize) = getattr(
importlib.import_module("collector"), _class_name
)
_class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
data_1d = data_1d_obj.normalize(data_1d)
return data_1d
@@ -597,12 +549,8 @@ class YahooNormalize1min(YahooNormalize, ABC):
df = df.sort_values(self._date_field_name)
symbol = df.iloc[0][self._symbol_field_name]
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(
self.DAILY_FORMAT
)
_end = (
pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)
).strftime(self.DAILY_FORMAT)
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
@@ -613,9 +561,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
# NOTE: volume is np.nan or volume <= 0, paused = 1
# FIXME: find a more accurate data source
data_1d["paused"] = 0
data_1d.loc[
(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"
] = 1
data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
data_1d = data_1d.set_index(self._date_field_name)
# add factor from 1d data
@@ -623,9 +569,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
df["date_tmp"] = df[self._date_field_name].apply(
lambda x: pd.Timestamp(x).date()
)
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
df.set_index("date_tmp", inplace=True)
df.loc[:, "factor"] = data_1d["close"] / df["close"]
df.loc[:, "paused"] = data_1d["paused"]
@@ -636,16 +580,12 @@ class YahooNormalize1min(YahooNormalize, ABC):
df.set_index(self._date_field_name, inplace=True)
df = df.reindex(
self.generate_1min_from_daily(
pd.to_datetime(
data_1d.reset_index()[
self._date_field_name
].drop_duplicates()
)
pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates())
)
)
df[self._symbol_field_name] = df.loc[
df[self._symbol_field_name].first_valid_index()
][self._symbol_field_name]
df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][
self._symbol_field_name
]
df.index.names = [self._date_field_name]
df.reset_index(inplace=True)
for _col in self.COLUMNS:
@@ -663,9 +603,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
def calc_paused_num(self, df: pd.DataFrame):
_symbol = df.iloc[0][self._symbol_field_name]
df = df.copy()
df["_tmp_date"] = df[self._date_field_name].apply(
lambda x: pd.Timestamp(x).date()
)
df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
@@ -685,10 +623,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
self._date_field_name,
self._symbol_field_name,
}
if (
_df.loc[:, check_fields].isna().values.all()
or (_df["volume"] == 0).all()
):
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1
@@ -723,11 +658,7 @@ class YahooNormalize1minOffline(YahooNormalize1min):
"""Normalised to 1min using local 1d data"""
def __init__(
self,
qlib_data_1d_dir: [str, Path],
date_field_name: str = "date",
symbol_field_name: str = "symbol",
**kwargs,
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
@@ -741,9 +672,7 @@ class YahooNormalize1minOffline(YahooNormalize1min):
symbol field name, default is symbol
"""
self.qlib_data_1d_dir = qlib_data_1d_dir
super(YahooNormalize1minOffline, self).__init__(
date_field_name, symbol_field_name
)
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
self._all_1d_data = self._get_all_1d_data()
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
@@ -758,19 +687,9 @@ class YahooNormalize1minOffline(YahooNormalize1min):
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(
D.instruments("all"),
["$paused", "$volume", "$factor", "$close"],
freq="day",
)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
df.reset_index(inplace=True)
df.rename(
columns={
"datetime": self._date_field_name,
"instrument": self._symbol_field_name,
},
inplace=True,
)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
@@ -838,11 +757,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
def symbol_to_yahoo(self, symbol):
if "." not in symbol:
_exchange = symbol[:2]
_exchange = (
("ss" if _exchange.islower() else "SS")
if _exchange.lower() == "sh"
else _exchange
)
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
symbol = symbol[2:] + "." + _exchange
return symbol
@@ -851,14 +766,7 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
class Run(BaseRun):
def __init__(
self,
source_dir=None,
normalize_dir=None,
max_workers=1,
interval="1d",
region=REGION_CN,
):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
Parameters
@@ -930,13 +838,7 @@ class Run(BaseRun):
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
"""
super(Run, self).download_data(
max_collector_count,
delay,
start,
end,
self.interval,
check_data_length,
limit_nums,
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
)
def normalize_data(
@@ -971,25 +873,16 @@ class Run(BaseRun):
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
if self.interval.lower() == "1min":
if (
qlib_data_1d_dir is None
or not Path(qlib_data_1d_dir).expanduser().exists()
):
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
raise ValueError(
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: https://github.com/zhupr/qlib/tree/support_extend_data/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance"
)
super(Run, self).normalize_data(
date_field_name,
symbol_field_name,
end_date=end_date,
qlib_data_1d_dir=qlib_data_1d_dir,
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
)
def normalize_data_1d_extend(
self,
old_qlib_data_dir,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
"""normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
@@ -1120,30 +1013,19 @@ class Run(BaseRun):
# start/end date
if trading_date is None:
trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
logger.warning(
f"trading_date is None, use the current date: {trading_date}"
)
logger.warning(f"trading_date is None, use the current date: {trading_date}")
if end_date is None:
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime(
"%Y-%m-%d"
)
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
# download qlib 1d data
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
if not exists_qlib_data(qlib_data_1d_dir):
GetData().qlib_data(
target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region
)
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
# download data from yahoo
# NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1
self.download_data(
delay=delay,
start=trading_date,
end=end_date,
check_data_length=check_data_length,
)
self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length)
# NOTE: a larger max_workers setting here would be faster
self.max_workers = (
max(multiprocessing.cpu_count() - 2, 1)
@@ -1165,18 +1047,11 @@ class Run(BaseRun):
# parse index
_region = self.region.lower()
if _region not in ["cn", "us"]:
logger.warning(
f"Unsupported region: region={_region}, component downloads will be ignored"
)
logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored")
return
index_list = (
["CSI100", "CSI300"]
if _region == "cn"
else ["SP500", "NASDAQ100", "DJIA", "SP400"]
)
index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"]
get_instruments = getattr(
importlib.import_module(f"data_collector.{_region}_index.collector"),
"get_instruments",
importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments"
)
for _index in index_list:
get_instruments(str(qlib_data_1d_dir), _index)