1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

modify the YahooNormalize1min factor calculation

This commit is contained in:
zhupr
2021-06-22 11:15:09 +08:00
parent 99fb49650a
commit 46714adf4c
2 changed files with 117 additions and 125 deletions

View File

@@ -84,30 +84,29 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
qqplot_data = _plt_fig.gca().lines
fig = go.Figure()
fig.add_trace({
'type': 'scatter',
'x': qqplot_data[0].get_xdata(),
# 'x': [0, 1],
'y': qqplot_data[0].get_ydata(),
# 'y': [1, 2],
'mode': 'markers',
'marker': {
'color': '#19d3f3'
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[0].get_xdata(),
# 'x': [0, 1],
"y": qqplot_data[0].get_ydata(),
# 'y': [1, 2],
"mode": "markers",
"marker": {"color": "#19d3f3"},
}
})
)
fig.add_trace({
'type': 'scatter',
'x': qqplot_data[1].get_xdata(),
# 'x': [0, 1],
'y': qqplot_data[1].get_ydata(),
# 'y': [1, 2],
'mode': 'lines',
'line': {
'color': '#636efa'
fig.add_trace(
{
"type": "scatter",
"x": qqplot_data[1].get_xdata(),
# 'x': [0, 1],
"y": qqplot_data[1].get_ydata(),
# 'y': [1, 2],
"mode": "lines",
"line": {"color": "#636efa"},
}
})
)
del qqplot_data
return fig

View File

@@ -478,8 +478,8 @@ class YahooNormalize1min(YahooNormalize, ABC):
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
# Whether the trading day of 1min data is consistent with 1d
CONSISTENT_1d = False
CALC_PAUSED_NUM = False
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
@property
def calendar_list_1d(self):
@@ -500,7 +500,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
Returns
------
data_1d: pd.DataFrame
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
@@ -516,6 +516,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
if df.empty:
return df
df = df.copy()
df = df.sort_values(self._date_field_name)
symbol = df.iloc[0][self._symbol_field_name]
# get 1d data from yahoo
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
@@ -523,7 +524,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
data_1d = data_1d.copy()
if data_1d is None or data_1d.empty:
df["factor"] = 1
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]
# TODO: np.nan or 1 or 0
df["paused"] = np.nan
else:
@@ -534,9 +535,13 @@ class YahooNormalize1min(YahooNormalize, ABC):
data_1d = data_1d.set_index(self._date_field_name)
# add factor from 1d data
# NOTE: yahoo 1d data info:
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
df.set_index("date_tmp", inplace=True)
df.loc[:, "factor"] = data_1d["factor"]
df.loc[:, "factor"] = data_1d["close"] / df["close"]
df.loc[:, "paused"] = data_1d["paused"]
df.reset_index("date_tmp", drop=True, inplace=True)
@@ -619,6 +624,61 @@ class YahooNormalize1min(YahooNormalize, ABC):
raise NotImplementedError("rewrite _get_1d_calendar_list")
class YahooNormalize1minOffline(YahooNormalize1min):
"""Normalised to 1min using local 1d data"""
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self.qlib_data_1d_dir = qlib_data_1d_dir
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
self._all_1d_data = self._get_all_1d_data()
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
return list(D.calendar(freq="day"))
def _get_all_1d_data(self):
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
df.reset_index(inplace=True)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
"""
return self._all_1d_data[
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
]
class YahooNormalizeUS:
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: from MSN
@@ -629,8 +689,8 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
pass
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
CONSISTENT_1d = False
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
CALC_PAUSED_NUM = False
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
# TODO: support 1min
@@ -657,20 +717,17 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
pass
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
AM_RANGE = ("09:30:00", "11:29:00")
PM_RANGE = ("13:00:00", "14:59:00")
CONSISTENT_1d = True
CALC_PAUSED_NUM = True
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
return self.generate_1min_from_daily(self.calendar_list_1d)
def symbol_to_yahoo(self, symbol):
if "." not in symbol:
_exchange = symbol[:2].lower()
_exchange = "ss" if _exchange == "sh" else _exchange
_exchange = symbol[:2]
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
symbol = symbol[2:] + "." + _exchange
return symbol
@@ -678,63 +735,6 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
return get_calendar_list("ALL")
class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
"""Normalised to 1min using local 1d data
1d data usually from: Normalised to 1min using local 1d data
"""
def __init__(
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
):
"""
Parameters
----------
qlib_data_1d_dir: str, Path
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
date_field_name: str
date field name, default is date
symbol_field_name: str
symbol field name, default is symbol
"""
self.qlib_data_1d_dir = qlib_data_1d_dir
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
self._all_1d_data = self._get_all_1d_data()
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
return list(D.calendar(freq="day"))
def _get_all_1d_data(self):
import qlib
from qlib.data import D
qlib.init(provider_uri=self.qlib_data_1d_dir)
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day")
df.reset_index(inplace=True)
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
return df
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
"""get 1d data
Returns
------
data_1d: pd.DataFrame
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
"""
return self._all_1d_data[
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
]
class Run(BaseRun):
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
"""
@@ -811,7 +811,13 @@ class Run(BaseRun):
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
)
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None):
def normalize_data(
self,
date_field_name: str = "date",
symbol_field_name: str = "symbol",
end_date: str = None,
qlib_data_1d_dir: str = None,
):
"""normalize data
Parameters
@@ -822,12 +828,29 @@ class Run(BaseRun):
symbol field name, default symbol
end_date: str
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
qlib_data_1d_dir: str
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
qlib_data_1d can be obtained like this:
$ python scripts/get_data.py qlilb_data --target_dir <qlib_data_1d_dir> --interval 1d
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
or:
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
Examples
---------
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date)
if self.interval.lower() == "1min":
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
# TODO: add reference url
raise ValueError(
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: "
)
super(Run, self).normalize_data(
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
)
def normalize_data_1d_extend(
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
@@ -873,36 +896,6 @@ class Run(BaseRun):
)
yc.normalize()
def normalize_data_1min_cn_offline(
self, qlib_data_1d_dir: str, date_field_name: str = "date", symbol_field_name: str = "symbol"
):
"""Normalised to 1min using local 1d data
Parameters
----------
qlib_data_1d_dir: str
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
date_field_name: str
date field name, default date
symbol_field_name: str
symbol field name, default symbol
Examples
---------
$ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
"""
_class = getattr(self._cur_module, f"{self.normalize_class_name}Offline")
yc = Normalize(
source_dir=self.source_dir,
target_dir=self.normalize_dir,
normalize_class=_class,
max_workers=self.max_workers,
date_field_name=date_field_name,
symbol_field_name=symbol_field_name,
qlib_data_1d_dir=qlib_data_1d_dir,
)
yc.normalize()
def download_today_data(
self,
max_collector_count=2,
@@ -987,7 +980,7 @@ class Run(BaseRun):
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
# download qlib 1d data
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
if not exists_qlib_data(qlib_data_1d_dir):
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
@@ -995,7 +988,7 @@ class Run(BaseRun):
self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1)
# normalize data
self.normalize_data_1d_extend(str(qlib_data_1d_dir))
self.normalize_data_1d_extend(qlib_data_1d_dir)
# dump bin
_dump = DumpDataUpdate(