From 8cf7bb3aaf2783a56f13e656b189cb6baf874e0a Mon Sep 17 00:00:00 2001 From: Linlang Date: Tue, 5 Mar 2024 17:24:03 +0800 Subject: [PATCH] fix CI error --- examples/benchmarks/TRA/src/model.py | 3 --- qlib/backtest/__init__.py | 16 +++++++++------- qlib/backtest/report.py | 8 +++++--- qlib/contrib/eva/alpha.py | 1 + qlib/contrib/model/pytorch_tra.py | 3 --- qlib/contrib/strategy/signal_strategy.py | 1 - qlib/data/dataset/utils.py | 8 ++------ qlib/model/ens/ensemble.py | 2 -- qlib/model/riskmodel/shrink.py | 4 +--- qlib/workflow/online/strategy.py | 1 - scripts/dump_bin.py | 4 +--- scripts/dump_pit.py | 8 +++++--- 12 files changed, 24 insertions(+), 35 deletions(-) diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index affb115a1..ebafd6a52 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -324,7 +324,6 @@ class TRAModel(Model): class LSTM(nn.Module): - """LSTM Model Args: @@ -414,7 +413,6 @@ class PositionalEncoding(nn.Module): class Transformer(nn.Module): - """Transformer Model Args: @@ -475,7 +473,6 @@ class Transformer(nn.Module): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index d784aed57..9daba9115 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -162,13 +162,15 @@ def create_account_instance( init_cash=init_cash, position_dict=position_dict, pos_type=pos_type, - benchmark_config={} - if benchmark is None - else { - "benchmark": benchmark, - "start_time": start_time, - "end_time": end_time, - }, + benchmark_config=( + {} + if benchmark is None + else { + "benchmark": benchmark, + "start_time": start_time, + "end_time": end_time, + } + ), ) diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 8e7440ba9..e7c6041ef 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -622,9 +622,11 @@ class Indicator: print( "[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format( freq, - trade_start_time - if isinstance(trade_start_time, str) - else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"), + ( + trade_start_time + if isinstance(trade_start_time, str) + else trade_start_time.strftime("%Y-%m-%d %H:%M:%S") + ), fulfill_rate, price_advantage, positive_rate, diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 95ec9b91e..86d366d20 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -3,6 +3,7 @@ Here is a batch of evaluation functions. The interface should be redesigned carefully in the future. """ + import pandas as pd from typing import Tuple from qlib import get_module_logger diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index 964febf11..bc9a6aa97 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -511,7 +511,6 @@ class TRAModel(Model): class RNN(nn.Module): - """RNN Model Args: @@ -601,7 +600,6 @@ class PositionalEncoding(nn.Module): class Transformer(nn.Module): - """Transformer Model Args: @@ -649,7 +647,6 @@ class Transformer(nn.Module): class TRA(nn.Module): - """Temporal Routing Adaptor (TRA) TRA takes historical prediction errors & latent representation as inputs, diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index 9ba960eeb..bad19ddfd 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -373,7 +373,6 @@ class WeightStrategyBase(BaseSignalStrategy): class EnhancedIndexingStrategy(WeightStrategyBase): - """Enhanced Indexing Strategy Enhanced indexing combines the arts of active management and passive management, diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 76f3ed404..f19dfe08f 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -71,15 +71,11 @@ def fetch_df_by_index( if fetch_orig: for slc in idx_slc: if slc != slice(None, None): - return df.loc[ - pd.IndexSlice[idx_slc], - ] # noqa: E231 + return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231 else: # pylint: disable=W0120 return df else: - return df.loc[ - pd.IndexSlice[idx_slc], - ] # noqa: E231 + return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231 def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame: diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index ede1f8e3a..1ebb16f18 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -30,7 +30,6 @@ class Ensemble: class SingleKeyEnsemble(Ensemble): - """ Extract the object if there is only one key and value in the dict. Make the result more readable. {Only key: Only value} -> Only value @@ -64,7 +63,6 @@ class SingleKeyEnsemble(Ensemble): class RollingEnsemble(Ensemble): - """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index b2594f707..c3c0e48ef 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -247,9 +247,7 @@ class ShrinkCovEstimator(RiskModel): v1 = y.T.dot(z) / t - cov_mkt[:, None] * S roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt v3 = z.T.dot(z) / t - var_mkt * S - roff3 = ( - np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 - ) + roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index f2988d843..d545e4bc9 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -90,7 +90,6 @@ class OnlineStrategy: class RollingStrategy(OnlineStrategy): - """ This example strategy always uses the latest rolling model sas online models. """ diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 92abc8bee..a65b1f58e 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -146,9 +146,7 @@ class DumpDataBase: return ( self._include_fields if self._include_fields - else set(df_columns) - set(self._exclude_fields) - if self._exclude_fields - else df_columns + else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns ) @staticmethod diff --git a/scripts/dump_pit.py b/scripts/dump_pit.py index 34d304ed7..1ca9cfc94 100644 --- a/scripts/dump_pit.py +++ b/scripts/dump_pit.py @@ -132,9 +132,11 @@ class DumpPitData: return ( set(self._include_fields) if self._include_fields - else set(df[self.field_column_name]) - set(self._exclude_fields) - if self._exclude_fields - else set(df[self.field_column_name]) + else ( + set(df[self.field_column_name]) - set(self._exclude_fields) + if self._exclude_fields + else set(df[self.field_column_name]) + ) ) def get_filenames(self, symbol, field, interval):