1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add YahooCollector support for extend data

This commit is contained in:
zhupr
2021-06-04 22:28:42 +08:00
parent b2fe2385d5
commit 6f150f3fd6
3 changed files with 211 additions and 18 deletions

View File

@@ -226,11 +226,7 @@ class BaseCollector(abc.ABC):
class BaseNormalize(abc.ABC):
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
"""
Parameters
@@ -265,6 +261,7 @@ class Normalize:
max_workers: int = 16,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
**kwargs,
):
"""
@@ -291,7 +288,9 @@ class Normalize:
self._max_workers = max_workers
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
self._normalize_obj = normalize_class(
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
)
def _executor(self, file_path: Path):
file_path = Path(file_path)
@@ -404,7 +403,7 @@ class BaseRun(abc.ABC):
limit_nums=limit_nums,
).collector_data()
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
"""normalize data
Parameters
@@ -426,5 +425,6 @@ class BaseRun(abc.ABC):
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
**kwargs,
)
yc.normalize()

View File

@@ -24,7 +24,10 @@ from data_collector.utils import get_calendar_list, get_trading_date_by_shift
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
# 2020-11-27 Announcement title change
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
class CSIIndex(IndexBase):

View File

@@ -23,7 +23,7 @@ from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
from data_collector.utils import (
get_calendar_list,
get_hs_stock_symbols,
@@ -297,6 +297,7 @@ class YahooNormalize(BaseNormalize):
calendar_list: list = None,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
last_close: float = None,
):
if df.empty:
return df
@@ -318,7 +319,10 @@ class YahooNormalize(BaseNormalize):
df.sort_index(inplace=True)
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
_tmp_series = df["close"].fillna(method="ffill")
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
_tmp_shift_series = _tmp_series.shift(1)
if last_close is not None and isinstance(last_close, (int, float)):
_tmp_shift_series.iloc[0] = last_close
df["change"] = _tmp_series / _tmp_shift_series - 1
columns += ["change"]
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
@@ -367,6 +371,17 @@ class YahooNormalize1d(YahooNormalize, ABC):
df = self._manual_adj_data(df)
return df
def _get_first_close(self, df: pd.DataFrame) -> float:
"""get first close value
Notes
-----
For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data
"""
df = df.loc[df["close"].first_valid_index() :]
_close = df["close"].iloc[0]
return _close
def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""manual adjust data: All fields (except change) are standardized according to the close of the first day"""
if df.empty:
@@ -374,8 +389,7 @@ class YahooNormalize1d(YahooNormalize, ABC):
df = df.copy()
df.sort_values(self._date_field_name, inplace=True)
df = df.set_index(self._date_field_name)
df = df.loc[df["close"].first_valid_index() :]
_close = df["close"].iloc[0]
_close = self._get_first_close(df)
for _col in df.columns:
if _col == self._symbol_field_name:
continue
@@ -388,18 +402,97 @@ class YahooNormalize1d(YahooNormalize, ABC):
return df.reset_index()
class YahooNormalize1dExtend(YahooNormalize1d):
def __init__(
self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
old_qlib_data_dir: str, Path
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir)
self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT)
def _get_old_data(self, qlib_data_dir: [str, Path]):
import qlib
from qlib.data import D
qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
df = D.features(D.instruments("all"), ["$close/$factor"])
df.columns = ["close"]
return D.calendar()[-1], df
def _get_first_close(self, df: pd.DataFrame) -> float:
_symbol = df.iloc[0][self._symbol_field_name]
try:
_df = self._old_close.loc(axis=0)[_symbol.upper()]
_close = _df.loc[_df.first_valid_index()]["close"]
except KeyError:
_close = super(YahooNormalize1dExtend, self)._get_first_close(df)
return _close
def _get_last_close(self, df: pd.DataFrame) -> float:
_symbol = df.iloc[0][self._symbol_field_name]
try:
_df = self._old_close.loc(axis=0)[_symbol.upper()]
_close = _df.loc[_df.last_valid_index()]["close"]
except KeyError:
_close = None
return _close
def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
_symbol = df.iloc[0][self._symbol_field_name]
try:
_df = self._old_close.loc(axis=0)[_symbol.upper()]
_date = _df.index.max()
except KeyError:
_date = None
return _date
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
_last_close = self._get_last_close(df)
# reindex
_last_date = self._get_last_date(df)
if _last_date is not None:
df = df.set_index(self._date_field_name)
df.index = pd.to_datetime(df.index)
df = df[~df.index.duplicated(keep="first")]
_max_date = df.index.max()
df = df.reindex(self._calendar_list).loc[:_max_date].reset_index()
df = df[df[self._date_field_name] > _last_date]
_si = df["close"].first_valid_index()
if _si > df.index[0]:
logger.warning(
f"{df.iloc[0][self._symbol_field_name]} missing data: {df.loc[:_si][self._date_field_name]}"
)
# normalize
df = self.normalize_yahoo(
df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
)
# adjusted price
df = self.adjusted_price(df)
df = self._manual_adj_data(df)
return df
class YahooNormalize1min(YahooNormalize, ABC):
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
# Whether the trading day of 1min data is consistent with 1d
CONSISTENT_1d = False
CALC_PAUSED_NUM = False
def __init__(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
):
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
"""
Parameters
@@ -478,6 +571,54 @@ class YahooNormalize1min(YahooNormalize, ABC):
df[_col] = df[_col] / df["factor"]
else:
df[_col] = df[_col] * df["factor"]
if self.CALC_PAUSED_NUM:
df = self.calc_paused_num(df)
return df
def calc_paused_num(self, df: pd.DataFrame):
_symbol = df.iloc[0][self._symbol_field_name]
df = df.copy()
df["date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
# remove data that starts and ends with `np.nan` all day
all_data = []
# Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
all_nan_nums = 0
# Record the number of consecutive occurrences of trading days that are not nan throughout the day
not_nan_nums = 0
for _date, _df in df.groupby(level="date"):
_df["paused"] = 0
if not _df.loc[_df["volume"] < 0].empty:
logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
_df.loc[_df["volume"] < 0, "volume"] = np.nan
check_fields = set(_df.columns) - {
"date",
"paused",
"factor",
self._date_field_name,
self._symbol_field_name,
}
if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
all_nan_nums += 1
not_nan_nums = 0
_df["paused"] = 1
if all_data:
_df["paused_num"] = not_nan_nums
all_data.append(_df)
else:
all_nan_nums = 0
not_nan_nums += 1
_df["paused_num"] = not_nan_nums
all_data.append(_df)
all_data = all_data[: len(all_data) - all_nan_nums]
if all_data:
df = pd.concat(all_data, sort=False)
else:
logger.warning(f"data is empty: {_symbol}")
df = pd.DataFrame()
return df
del df["date"]
return df
@abc.abstractmethod
@@ -523,11 +664,16 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):
pass
class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
pass
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
def _get_calendar_list(self):
return self.generate_1min_from_daily(self.calendar_list_1d)
@@ -624,10 +770,54 @@ class Run(BaseRun):
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
"""
super(Run, self).normalize_data(date_field_name, symbol_field_name)
def normalize_data_1d_extend(
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
"""normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
Notes
-----
Steps to extend yahoo qlib data:
1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>
2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>
3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d
4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments
Parameters
----------
old_qlib_data_dir: str
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data_1d_extend --old_qlib_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
"""
_class = getattr(self._cur_module, f"{self.normalize_class_name}Extend")
yc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
old_qlib_data_dir=old_qlib_data_dir,
)
yc.normalize()
if __name__ == "__main__":
fire.Fire(Run)