diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 1d444b104..ddf97fb36 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -84,30 +84,29 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: qqplot_data = _plt_fig.gca().lines fig = go.Figure() - fig.add_trace({ - 'type': 'scatter', - 'x': qqplot_data[0].get_xdata(), - # 'x': [0, 1], - 'y': qqplot_data[0].get_ydata(), - # 'y': [1, 2], - 'mode': 'markers', - 'marker': { - 'color': '#19d3f3' + fig.add_trace( + { + "type": "scatter", + "x": qqplot_data[0].get_xdata(), + # 'x': [0, 1], + "y": qqplot_data[0].get_ydata(), + # 'y': [1, 2], + "mode": "markers", + "marker": {"color": "#19d3f3"}, } - }) + ) - fig.add_trace({ - 'type': 'scatter', - 'x': qqplot_data[1].get_xdata(), - # 'x': [0, 1], - 'y': qqplot_data[1].get_ydata(), - # 'y': [1, 2], - 'mode': 'lines', - 'line': { - 'color': '#636efa' + fig.add_trace( + { + "type": "scatter", + "x": qqplot_data[1].get_xdata(), + # 'x': [0, 1], + "y": qqplot_data[1].get_ydata(), + # 'y': [1, 2], + "mode": "lines", + "line": {"color": "#636efa"}, } - - }) + ) del qqplot_data return fig diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 44cfce7ca..b474d3924 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -478,8 +478,8 @@ class YahooNormalize1min(YahooNormalize, ABC): PM_RANGE = None # type: tuple # eg: ("13:00:00", "14:59:00") # Whether the trading day of 1min data is consistent with 1d - CONSISTENT_1d = False - CALC_PAUSED_NUM = False + CONSISTENT_1d = True + CALC_PAUSED_NUM = True @property def calendar_list_1d(self): @@ -500,7 +500,7 @@ class YahooNormalize1min(YahooNormalize, ABC): Returns ------ data_1d: pd.DataFrame - set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {} + data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] """ data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end) @@ -516,6 +516,7 @@ class YahooNormalize1min(YahooNormalize, ABC): if df.empty: return df df = df.copy() + df = df.sort_values(self._date_field_name) symbol = df.iloc[0][self._symbol_field_name] # get 1d data from yahoo _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) @@ -523,7 +524,7 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: - df["factor"] = 1 + df["factor"] = 1 / df.loc[df["close"].first_valid_index()] # TODO: np.nan or 1 or 0 df["paused"] = np.nan else: @@ -534,9 +535,13 @@ class YahooNormalize1min(YahooNormalize, ABC): data_1d = data_1d.set_index(self._date_field_name) # add factor from 1d data + # NOTE: yahoo 1d data info: + # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. + # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. + # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) df.set_index("date_tmp", inplace=True) - df.loc[:, "factor"] = data_1d["factor"] + df.loc[:, "factor"] = data_1d["close"] / df["close"] df.loc[:, "paused"] = data_1d["paused"] df.reset_index("date_tmp", drop=True, inplace=True) @@ -619,6 +624,61 @@ class YahooNormalize1min(YahooNormalize, ABC): raise NotImplementedError("rewrite _get_1d_calendar_list") +class YahooNormalize1minOffline(YahooNormalize1min): + """Normalised to 1min using local 1d data""" + + def __init__( + self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): + """ + + Parameters + ---------- + qlib_data_1d_dir: str, Path + the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + self.qlib_data_1d_dir = qlib_data_1d_dir + super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) + self._all_1d_data = self._get_all_1d_data() + + def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: + import qlib + from qlib.data import D + + qlib.init(provider_uri=self.qlib_data_1d_dir) + return list(D.calendar(freq="day")) + + def _get_all_1d_data(self): + import qlib + from qlib.data import D + + qlib.init(provider_uri=self.qlib_data_1d_dir) + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + df.reset_index(inplace=True) + df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) + df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) + return df + + def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + """get 1d data + + Returns + ------ + data_1d: pd.DataFrame + data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] + + """ + return self._all_1d_data[ + (self._all_1d_data[self._symbol_field_name] == symbol.upper()) + & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) + & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) + ] + + class YahooNormalizeUS: def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN @@ -629,8 +689,8 @@ class YahooNormalizeUS1d(YahooNormalizeUS, YahooNormalize1d): pass -class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min): - CONSISTENT_1d = False +class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline): + CALC_PAUSED_NUM = False def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: support 1min @@ -657,20 +717,17 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend): pass -class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): +class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") - CONSISTENT_1d = True - CALC_PAUSED_NUM = True - def _get_calendar_list(self) -> Iterable[pd.Timestamp]: return self.generate_1min_from_daily(self.calendar_list_1d) def symbol_to_yahoo(self, symbol): if "." not in symbol: - _exchange = symbol[:2].lower() - _exchange = "ss" if _exchange == "sh" else _exchange + _exchange = symbol[:2] + _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange symbol = symbol[2:] + "." + _exchange return symbol @@ -678,63 +735,6 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): return get_calendar_list("ALL") -class YahooNormalizeCN1minOffline(YahooNormalizeCN1min): - """Normalised to 1min using local 1d data - 1d data usually from: Normalised to 1min using local 1d data - """ - - def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs - ): - """ - - Parameters - ---------- - qlib_data_1d_dir: str, Path - the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data - date_field_name: str - date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalizeCN1minOffline, self).__init__(date_field_name, symbol_field_name) - self._all_1d_data = self._get_all_1d_data() - - def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: - import qlib - from qlib.data import D - - qlib.init(provider_uri=self.qlib_data_1d_dir) - return list(D.calendar(freq="day")) - - def _get_all_1d_data(self): - import qlib - from qlib.data import D - - qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor"], freq="day") - df.reset_index(inplace=True) - df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) - df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) - return df - - def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: - """get 1d data - - Returns - ------ - data_1d: pd.DataFrame - set(data_1d.columns) - set([self._date_field_name, self._symbol_field_name, "paused", "volume", "factor"]) == {} - - """ - return self._all_1d_data[ - (self._all_1d_data[self._symbol_field_name] == symbol.upper()) - & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) - & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) - ] - - class Run(BaseRun): def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): """ @@ -811,7 +811,13 @@ class Run(BaseRun): max_collector_count, delay, start, end, self.interval, check_data_length, limit_nums ) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", end_date: str = None): + def normalize_data( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + end_date: str = None, + qlib_data_1d_dir: str = None, + ): """normalize data Parameters @@ -822,12 +828,29 @@ class Run(BaseRun): symbol field name, default symbol end_date: str if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None + qlib_data_1d_dir: str + if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data; + + qlib_data_1d can be obtained like this: + $ python scripts/get_data.py qlilb_data --target_dir --interval 1d + $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --trading_date 2021-06-01 + or: + download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo Examples --------- $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d + $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ - super(Run, self).normalize_data(date_field_name, symbol_field_name, end_date=end_date) + if self.interval.lower() == "1min": + if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): + # TODO: add reference url + raise ValueError( + "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: " + ) + super(Run, self).normalize_data( + date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + ) def normalize_data_1d_extend( self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" @@ -873,36 +896,6 @@ class Run(BaseRun): ) yc.normalize() - def normalize_data_1min_cn_offline( - self, qlib_data_1d_dir: str, date_field_name: str = "date", symbol_field_name: str = "symbol" - ): - """Normalised to 1min using local 1d data - - Parameters - ---------- - qlib_data_1d_dir: str - the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data - date_field_name: str - date field name, default date - symbol_field_name: str - symbol field name, default symbol - - Examples - --------- - $ python collector.py normalize_data_1min_cn_offline --qlib_data_1d_dir ~/.qlib/qlib_data/cn_1d --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min - """ - _class = getattr(self._cur_module, f"{self.normalize_class_name}Offline") - yc = Normalize( - source_dir=self.source_dir, - target_dir=self.normalize_dir, - normalize_class=_class, - max_workers=self.max_workers, - date_field_name=date_field_name, - symbol_field_name=symbol_field_name, - qlib_data_1d_dir=qlib_data_1d_dir, - ) - yc.normalize() - def download_today_data( self, max_collector_count=2, @@ -987,7 +980,7 @@ class Run(BaseRun): end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") # download qlib 1d data - qlib_data_1d_dir = Path(qlib_data_1d_dir).expanduser().resolve() + qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) @@ -995,7 +988,7 @@ class Run(BaseRun): self.download_data(delay=1, start=trading_date, end=end_date, check_data_length=1) # normalize data - self.normalize_data_1d_extend(str(qlib_data_1d_dir)) + self.normalize_data_1d_extend(qlib_data_1d_dir) # dump bin _dump = DumpDataUpdate(