mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add YahooCollector support for extend data
This commit is contained in:
@@ -226,11 +226,7 @@ class BaseCollector(abc.ABC):
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -265,6 +261,7 @@ class Normalize:
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -291,7 +288,9 @@ class Normalize:
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
self._normalize_obj = normalize_class(
|
||||
date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs
|
||||
)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
@@ -404,7 +403,7 @@ class BaseRun(abc.ABC):
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -426,5 +425,6 @@ class BaseRun(abc.ABC):
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
**kwargs,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
@@ -24,7 +24,10 @@ from data_collector.utils import get_calendar_list, get_trading_date_by_shift
|
||||
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
|
||||
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
# 2020-11-27 Announcement title change
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
|
||||
|
||||
|
||||
class CSIIndex(IndexBase):
|
||||
|
||||
@@ -23,7 +23,7 @@ from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize
|
||||
from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
@@ -297,6 +297,7 @@ class YahooNormalize(BaseNormalize):
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
last_close: float = None,
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
@@ -318,7 +319,10 @@ class YahooNormalize(BaseNormalize):
|
||||
df.sort_index(inplace=True)
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {symbol_field_name}] = np.nan
|
||||
_tmp_series = df["close"].fillna(method="ffill")
|
||||
df["change"] = _tmp_series / _tmp_series.shift(1) - 1
|
||||
_tmp_shift_series = _tmp_series.shift(1)
|
||||
if last_close is not None and isinstance(last_close, (int, float)):
|
||||
_tmp_shift_series.iloc[0] = last_close
|
||||
df["change"] = _tmp_series / _tmp_shift_series - 1
|
||||
columns += ["change"]
|
||||
df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan
|
||||
|
||||
@@ -367,6 +371,17 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = self._manual_adj_data(df)
|
||||
return df
|
||||
|
||||
def _get_first_close(self, df: pd.DataFrame) -> float:
|
||||
"""get first close value
|
||||
|
||||
Notes
|
||||
-----
|
||||
For incremental updates(append) to Yahoo 1D data, user need to use a close that is not 0 on the first trading day of the existing data
|
||||
"""
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
return _close
|
||||
|
||||
def _manual_adj_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""manual adjust data: All fields (except change) are standardized according to the close of the first day"""
|
||||
if df.empty:
|
||||
@@ -374,8 +389,7 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
df = df.copy()
|
||||
df.sort_values(self._date_field_name, inplace=True)
|
||||
df = df.set_index(self._date_field_name)
|
||||
df = df.loc[df["close"].first_valid_index() :]
|
||||
_close = df["close"].iloc[0]
|
||||
_close = self._get_first_close(df)
|
||||
for _col in df.columns:
|
||||
if _col == self._symbol_field_name:
|
||||
continue
|
||||
@@ -388,18 +402,97 @@ class YahooNormalize1d(YahooNormalize, ABC):
|
||||
return df.reset_index()
|
||||
|
||||
|
||||
class YahooNormalize1dExtend(YahooNormalize1d):
    """Normalize newly collected Yahoo 1d data so it continues an existing qlib dataset.

    The existing dataset is loaded once in ``__init__``; for every symbol the
    stored close series supplies the reference closes used while normalizing
    the new rows.
    """

    def __init__(
        self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
    ):
        """

        Parameters
        ----------
        old_qlib_data_dir: str, Path
            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
        date_field_name: str
            date field name, default is date
        symbol_field_name: str
            symbol field name, default is symbol
        """
        super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name)
        self._end_date, self._old_close = self._get_old_data(old_qlib_data_dir)
        self._end_date = pd.Timestamp(self._end_date).strftime(self.DAILY_FORMAT)

    def _get_old_data(self, qlib_data_dir: [str, Path]):
        """Load the existing qlib data.

        Returns
        -------
        tuple
            ``(last calendar entry, DataFrame)`` where the frame holds the
            ``$close/$factor`` feature of all instruments in a single column
            renamed to ``close``.
        """
        # Imported lazily so qlib is only initialised when this class is used.
        import qlib
        from qlib.data import D

        qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve())
        qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None)
        df = D.features(D.instruments("all"), ["$close/$factor"])
        df.columns = ["close"]
        return D.calendar()[-1], df

    def _get_symbol(self, df: pd.DataFrame):
        """Return the symbol of ``df``: the first non-null value of the symbol column.

        Rows created by reindexing to the calendar may carry a NaN symbol, so
        the first row alone is not a reliable source.
        """
        _symbols = df[self._symbol_field_name]
        _valid = _symbols.dropna()
        return (_symbols if _valid.empty else _valid).iloc[0]

    def _get_old_symbol_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return the old close data of ``df``'s symbol; raises KeyError if the symbol is unknown."""
        return self._old_close.loc(axis=0)[self._get_symbol(df).upper()]

    def _get_first_close(self, df: pd.DataFrame) -> float:
        """Return the first valid close of the symbol in the old data.

        Falls back to the first valid close of ``df`` itself when the symbol
        (or a valid close for it) is not found in the old data.
        """
        try:
            _df = self._get_old_symbol_data(df)
            _close = _df.loc[_df.first_valid_index()]["close"]
        except KeyError:
            _close = super(YahooNormalize1dExtend, self)._get_first_close(df)
        return _close

    def _get_last_close(self, df: pd.DataFrame) -> float:
        """Return the last valid close of the symbol in the old data, or None if unavailable."""
        try:
            _df = self._get_old_symbol_data(df)
            _close = _df.loc[_df.last_valid_index()]["close"]
        except KeyError:
            _close = None
        return _close

    def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp:
        """Return the latest date of the symbol in the old data, or None if the symbol is unknown."""
        try:
            _df = self._get_old_symbol_data(df)
            _date = _df.index.max()
        except KeyError:
            _date = None
        return _date

    def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize the rows of ``df`` that are newer than the symbol's old data.

        Parameters
        ----------
        df: pd.DataFrame
            raw source data of a single symbol

        Returns
        -------
        pd.DataFrame
            normalized data; an empty frame when ``df`` is empty or contains
            nothing newer than the old data
        """
        if df.empty:
            return df
        # Read the symbol before reindexing: reindexed rows may hold NaN symbols.
        _symbol = self._get_symbol(df)
        _last_close = self._get_last_close(df)
        # reindex
        _last_date = self._get_last_date(df)
        if _last_date is not None:
            df = df.set_index(self._date_field_name)
            df.index = pd.to_datetime(df.index)
            df = df[~df.index.duplicated(keep="first")]
            _max_date = df.index.max()
            df = df.reindex(self._calendar_list).loc[:_max_date]
            # Make sure the index keeps the date field name through the reindex.
            df = df.rename_axis(self._date_field_name).reset_index()
            df = df[df[self._date_field_name] > _last_date]
            if df.empty:
                # Nothing to append for this symbol.
                logger.warning(f"{_symbol} has no new data after {_last_date}")
                return df
            _si = df["close"].first_valid_index()
            # `_si is None` means every new close is NaN, i.e. all new rows are missing.
            if _si is None or _si > df.index[0]:
                logger.warning(f"{_symbol} missing data: {df.loc[:_si][self._date_field_name]}")
        # normalize
        df = self.normalize_yahoo(
            df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close
        )
        # adjusted price
        df = self.adjusted_price(df)
        df = self._manual_adj_data(df)
        return df
|
||||
|
||||
|
||||
class YahooNormalize1min(YahooNormalize, ABC):
|
||||
AM_RANGE = None # type: tuple # eg: ("09:30:00", "11:29:00")
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
# Whether the trading day of 1min data is consistent with 1d
|
||||
CONSISTENT_1d = False
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -478,6 +571,54 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
df[_col] = df[_col] / df["factor"]
|
||||
else:
|
||||
df[_col] = df[_col] * df["factor"]
|
||||
|
||||
if self.CALC_PAUSED_NUM:
|
||||
df = self.calc_paused_num(df)
|
||||
return df
|
||||
|
||||
def calc_paused_num(self, df: pd.DataFrame):
    """Add ``paused`` and ``paused_num`` columns, processing the data one calendar day at a time.

    A day counts as paused when all checked fields are NaN for the whole
    day, or when every volume of the day is 0. Paused days before the first
    traded day and after the last traded day are dropped.

    Parameters
    ----------
    df: pd.DataFrame
        data of a single symbol; must contain the date field, the symbol
        field and a ``volume`` column

    Returns
    -------
    pd.DataFrame
        the remaining rows with ``paused`` (0/1) and ``paused_num`` (running
        count of consecutive traded days, 0 on paused days); an empty
        DataFrame when no day remains
    """
    if df.empty:
        return df
    _symbol = df.iloc[0][self._symbol_field_name]
    df = df.copy()
    # Use a private helper column: the date field itself is named "date" by
    # default and must not be overwritten or deleted.
    _day_field = "_tmp_date"
    df[_day_field] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
    _ignored_fields = {_day_field, "paused", "factor", self._date_field_name, self._symbol_field_name}
    # remove data that starts and ends with `np.nan` all day
    all_data = []
    # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan
    all_nan_nums = 0
    # Record the number of consecutive occurrences of trading days that are not nan throughout the day
    not_nan_nums = 0
    for _date, _df in df.groupby(_day_field):
        # Work on a copy so the assignments below never write into a view of `df`.
        _df = _df.copy()
        _df["paused"] = 0
        _negative_volume = _df["volume"] < 0
        if _negative_volume.any():
            logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}")
            _df.loc[_negative_volume, "volume"] = np.nan

        # A list (not a set) is required as a column indexer by recent pandas.
        check_fields = [_col for _col in _df.columns if _col not in _ignored_fields]
        if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all():
            all_nan_nums += 1
            not_nan_nums = 0
            _df["paused"] = 1
            # Paused days before the first traded day are skipped.
            if all_data:
                _df["paused_num"] = not_nan_nums
                all_data.append(_df)
        else:
            all_nan_nums = 0
            not_nan_nums += 1
            _df["paused_num"] = not_nan_nums
            all_data.append(_df)
    # Trim the trailing run of paused days.
    all_data = all_data[: len(all_data) - all_nan_nums]
    if not all_data:
        logger.warning(f"data is empty: {_symbol}")
        return pd.DataFrame()
    df = pd.concat(all_data, sort=False)
    del df[_day_field]
    return df
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -523,11 +664,16 @@ class YahooNormalizeCN1d(YahooNormalizeCN, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
    """Combination of ``YahooNormalizeCN`` and ``YahooNormalize1dExtend``; adds no members of its own."""
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
def _get_calendar_list(self):
|
||||
return self.generate_1min_from_daily(self.calendar_list_1d)
|
||||
@@ -624,10 +770,54 @@ class Run(BaseRun):
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
def normalize_data_1d_extend(
    self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
    """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)

    Notes
    -----
        Steps to extend yahoo qlib data:

            1. download qlib data: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data; save to <dir1>

            2. collector source data: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#collector-data; save to <dir2>

            3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_data_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d

            4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date

            5. update instrument(eg. csi300): python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments

    Parameters
    ----------
    old_qlib_data_dir: str
        the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
    date_field_name: str
        date field name, default date
    symbol_field_name: str
        symbol field name, default symbol

    Raises
    ------
    AttributeError
        if the current module defines no ``<normalize_class_name>Extend`` class

    Examples
    ---------
        $ python collector.py normalize_data_1d_extend --old_qlib_data_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
    """
    # The extend normalizer is looked up by naming convention: "<normalize class>Extend".
    _class_name = f"{self.normalize_class_name}Extend"
    _class = getattr(self._cur_module, _class_name, None)
    if _class is None:
        raise AttributeError(
            f"{_class_name} is not defined; extending data is not supported for this region/interval"
        )
    yc = Normalize(
        source_dir=self.source_dir,
        target_dir=self.normalize_dir,
        normalize_class=_class,
        max_workers=self.max_workers,
        date_field_name=date_field_name,
        symbol_field_name=symbol_field_name,
        old_qlib_data_dir=old_qlib_data_dir,
    )
    yc.normalize()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
|
||||
Reference in New Issue
Block a user