mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
modify the YahooNormalize1min factor calculation
This commit is contained in:
@@ -84,30 +84,29 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
qqplot_data = _plt_fig.gca().lines
|
||||
fig = go.Figure()
|
||||
|
||||
fig.add_trace({
|
||||
'type': 'scatter',
|
||||
'x': qqplot_data[0].get_xdata(),
|
||||
# 'x': [0, 1],
|
||||
'y': qqplot_data[0].get_ydata(),
|
||||
# 'y': [1, 2],
|
||||
'mode': 'markers',
|
||||
'marker': {
|
||||
'color': '#19d3f3'
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[0].get_xdata(),
|
||||
# 'x': [0, 1],
|
||||
"y": qqplot_data[0].get_ydata(),
|
||||
# 'y': [1, 2],
|
||||
"mode": "markers",
|
||||
"marker": {"color": "#19d3f3"},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
fig.add_trace({
|
||||
'type': 'scatter',
|
||||
'x': qqplot_data[1].get_xdata(),
|
||||
# 'x': [0, 1],
|
||||
'y': qqplot_data[1].get_ydata(),
|
||||
# 'y': [1, 2],
|
||||
'mode': 'lines',
|
||||
'line': {
|
||||
'color': '#636efa'
|
||||
fig.add_trace(
|
||||
{
|
||||
"type": "scatter",
|
||||
"x": qqplot_data[1].get_xdata(),
|
||||
# 'x': [0, 1],
|
||||
"y": qqplot_data[1].get_ydata(),
|
||||
# 'y': [1, 2],
|
||||
"mode": "lines",
|
||||
"line": {"color": "#636efa"},
|
||||
}
|
||||
|
||||
})
|
||||
)
|
||||
del qqplot_data
|
||||
return fig
|
||||
|
||||
|
||||
@@ -478,8 +478,8 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00")
|
||||
|
||||
# Whether the trading day of 1min data is consistent with 1d
|
||||
CONSISTENT_1d = False
|
||||
CALC_PAUSED_NUM = False
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
@property
|
||||
def calendar_list_1d(self):
|
||||
@@ -500,7 +500,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
|
||||
@@ -516,6 +516,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df = df.sort_values(self._date_field_name)
|
||||
symbol = df.iloc[0][self._symbol_field_name]
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
@@ -523,7 +524,7 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
|
||||
data_1d = data_1d.copy()
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
df["factor"] = 1 / df.loc[df["close"].first_valid_index()]
|
||||
# TODO: np.nan or 1 or 0
|
||||
df["paused"] = np.nan
|
||||
else:
|
||||
@@ -534,9 +535,13 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
data_1d = data_1d.set_index(self._date_field_name)
|
||||
|
||||
# add factor from 1d data
|
||||
# NOTE: yahoo 1d data info:
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
df.set_index("date_tmp", inplace=True)
|
||||
df.loc[:, "factor"] = data_1d["factor"]
|
||||
df.loc[:, "factor"] = data_1d["close"] / df["close"]
|
||||
df.loc[:, "paused"] = data_1d["paused"]
|
||||
df.reset_index("date_tmp", drop=True, inplace=True)
|
||||
|
||||
@@ -619,6 +624,61 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
raise NotImplementedError("rewrite _get_1d_calendar_list")
|
||||
|
||||
|
||||
class YahooNormalize1minOffline(YahooNormalize1min):
|
||||
"""Normalised to 1min using local 1d data"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
class YahooNormalizeUS:
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: from MSN
|
||||
@@ -629,8 +689,8 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min):
|
||||
CONSISTENT_1d = False
|
||||
class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline):
|
||||
CALC_PAUSED_NUM = False
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
# TODO: support 1min
|
||||
@@ -657,20 +717,17 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend):
|
||||
pass
|
||||
|
||||
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline):
|
||||
AM_RANGE = ("09:30:00", "11:29:00")
|
||||
PM_RANGE = ("13:00:00", "14:59:00")
|
||||
|
||||
CONSISTENT_1d = True
|
||||
CALC_PAUSED_NUM = True
|
||||
|
||||
def _get_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
return self.generate_1min_from_daily(self.calendar_list_1d)
|
||||
|
||||
def symbol_to_yahoo(self, symbol):
|
||||
if "." not in symbol:
|
||||
_exchange = symbol[:2].lower()
|
||||
_exchange = "ss" if _exchange == "sh" else _exchange
|
||||
_exchange = symbol[:2]
|
||||
_exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange
|
||||
symbol = symbol[2:] + "." + _exchange
|
||||
return symbol
|
||||
|
||||
@@ -678,63 +735,6 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class YahooNormalizeCN1minOffline(YahooNormalizeCN1min):
|
||||
"""Normalised to 1min using local 1d data
|
||||
1d data usually from: Normalised to 1min using local 1d data
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str, Path
|
||||
the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self.qlib_data_1d_dir = qlib_data_1d_dir
|
||||
super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name)
|
||||
self._all_1d_data = self._get_all_1d_data()
|
||||
|
||||
def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
return list(D.calendar(freq="day"))
|
||||
|
||||
def _get_all_1d_data(self):
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri=self.qlib_data_1d_dir)
|
||||
df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day")
|
||||
df.reset_index(inplace=True)
|
||||
df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True)
|
||||
df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns))
|
||||
return df
|
||||
|
||||
def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
|
||||
"""get 1d data
|
||||
|
||||
Returns
|
||||
------
|
||||
data_1d: pd.DataFrame
|
||||
set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {}
|
||||
|
||||
"""
|
||||
return self._all_1d_data[
|
||||
(self._all_1d_data[self._symbol_field_name] == symbol.upper())
|
||||
& (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start))
|
||||
& (self._all_1d_data[self._date_field_name] < pd.Timestamp(end))
|
||||
]
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
@@ -811,7 +811,13 @@ class Run(BaseRun):
|
||||
max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums
|
||||
)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None):
|
||||
def normalize_data(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
end_date: str = None,
|
||||
qlib_data_1d_dir: str = None,
|
||||
):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
@@ -822,12 +828,29 @@ class Run(BaseRun):
|
||||
symbol field name, default symbol
|
||||
end_date: str
|
||||
if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None
|
||||
qlib_data_1d_dir: str
|
||||
if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data;
|
||||
|
||||
qlib_data_1d can be obtained like this:
|
||||
$ python scripts/get_data.py qlilb_data --target_dir <qlib_data_1d_dir> --interval 1d
|
||||
$ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
|
||||
or:
|
||||
download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d
|
||||
$ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date)
|
||||
if self.interval.lower() == "1min":
|
||||
if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists():
|
||||
# TODO: add reference url
|
||||
raise ValueError(
|
||||
"If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir <user qlib 1d data >, Reference: "
|
||||
)
|
||||
super(Run, self).normalize_data(
|
||||
date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir
|
||||
)
|
||||
|
||||
def normalize_data_1d_extend(
|
||||
self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol"
|
||||
@@ -873,36 +896,6 @@ class Run(BaseRun):
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
def normalize_data_1min_cn_offline(
|
||||
self, qlib_data_1d_dir: str, date_field_name: str = "date", symbol_field_name: str = "symbol"
|
||||
):
|
||||
"""Normalised to 1min using local 1d data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_data_1d_dir: str
|
||||
the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"{self.normalize_class_name}Offline")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
qlib_data_1d_dir=qlib_data_1d_dir,
|
||||
)
|
||||
yc.normalize()
|
||||
|
||||
def download_today_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
@@ -987,7 +980,7 @@ class Run(BaseRun):
|
||||
end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
# download qlib 1d data
|
||||
qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve()
|
||||
qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
|
||||
if not exists_qlib_data(qlib_data_1d_dir):
|
||||
GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region)
|
||||
|
||||
@@ -995,7 +988,7 @@ class Run(BaseRun):
|
||||
self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1)
|
||||
|
||||
# normalize data
|
||||
self.normalize_data_1d_extend(str(qlib_data_1d_dir))
|
||||
self.normalize_data_1d_extend(qlib_data_1d_dir)
|
||||
|
||||
# dump bin
|
||||
_dump = DumpDataUpdate(
|
||||
|
||||
Reference in New Issue
Block a user