diff --git a/docs/conf.py b/docs/conf.py index 61fe784e7..6e52b0e34 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -191,7 +191,15 @@ man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)] # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, "QLib", u"QLib Documentation", author, "QLib", "One line description of project.", "Miscellaneous",), + ( + master_doc, + "QLib", + u"QLib Documentation", + author, + "QLib", + "One line description of project.", + "Miscellaneous", + ), ] diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py index f40a1aece..b39f17825 100644 --- a/examples/benchmarks/TFT/libs/tft_model.py +++ b/examples/benchmarks/TFT/libs/tft_model.py @@ -721,7 +721,12 @@ class TemporalFusionTransformer: encoder_steps = self.num_encoder_steps # Inputs. - all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,)) + all_inputs = tf.keras.layers.Input( + shape=( + time_steps, + combined_input_size, + ) + ) unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs) @@ -861,7 +866,10 @@ class TemporalFusionTransformer: """Returns LSTM cell initialized with default parameters.""" if self.use_cudnn: lstm = tf.keras.layers.CuDNNLSTM( - self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False, + self.hidden_layer_size, + return_sequences=True, + return_state=return_state, + stateful=False, ) else: lstm = tf.keras.layers.LSTM( diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py index 2fc411ab6..d35650514 100644 --- a/examples/highfreq/highfreq_handler.py +++ b/examples/highfreq/highfreq_handler.py @@ -20,7 +20,10 @@ class HighFreqHandler(DataHandlerLP): new_l = [] for p in proc_l: p["kwargs"].update( - {"fit_start_time": fit_start_time, "fit_end_time": fit_end_time,} + { + "fit_start_time": fit_start_time, + "fit_end_time": fit_end_time, + } ) new_l.append(p) return new_l @@ -30,7 +33,11 @@ class HighFreqHandler(DataHandlerLP): data_loader = { "class": "QlibDataLoader", - "kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",}, + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + "freq": "1min", + }, } super().__init__( instruments=instruments, @@ -61,7 +68,8 @@ class HighFreqHandler(DataHandlerLP): feature_ops = template_norm.format( template_if.format( - template_fillnan.format(template_paused.format("$close")), template_paused.format(price_field), + template_fillnan.format(template_paused.format("$close")), + template_paused.format(price_field), ), template_fillnan.format(template_paused.format("$close")), ) @@ -111,14 +119,24 @@ class HighFreqHandler(DataHandlerLP): class HighFreqBacktestHandler(DataHandler): def __init__( - self, instruments="csi300", start_time=None, end_time=None, + self, + instruments="csi300", + start_time=None, + end_time=None, ): data_loader = { "class": "QlibDataLoader", - "kwargs": {"config": self.get_feature_config(), "swap_level": False, "freq": "1min",}, + "kwargs": { + "config": self.get_feature_config(), + "swap_level": False, + "freq": "1min", + }, } super().__init__( - instruments=instruments, start_time=start_time, end_time=end_time, data_loader=data_loader, + instruments=instruments, + start_time=start_time, + end_time=end_time, + data_loader=data_loader, ) def get_feature_config(self): @@ -137,7 +155,8 @@ class HighFreqBacktestHandler(DataHandler): fields += [ "Cut({0}, 240, None)".format( template_if.format( - template_fillnan.format(template_paused.format("$close")), template_paused.format(simpson_vwap), + template_fillnan.format(template_paused.format("$close")), + template_paused.format(simpson_vwap), ) ) ] diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index 73510ef06..f0ab0dec2 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -65,6 +65,8 @@ class HighFreqNorm(Processor): feat = df_values[:, [0, 1, 2, 3, 4, 10]].reshape(-1, 6 * 240) feat_1 = df_values[:, [5, 6, 7, 8, 9, 11]].reshape(-1, 6 * 240) df_new_features = pd.DataFrame( - data=np.concatenate((feat, feat_1), axis=1), index=idx, columns=["FEATURE_%d" % i for i in range(12 * 240)], + data=np.concatenate((feat, feat_1), axis=1), + index=idx, + columns=["FEATURE_%d" % i for i in range(12 * 240)], ).sort_index() return df_new_features diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 0bfd0c2a0..01de59c0e 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -63,7 +63,13 @@ class HighfreqWorkflow(object): "module_path": "highfreq_handler", "kwargs": DATA_HANDLER_CONFIG0, }, - "segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),}, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, }, }, "dataset_backtest": { @@ -75,7 +81,13 @@ class HighfreqWorkflow(object): "module_path": "highfreq_handler", "kwargs": DATA_HANDLER_CONFIG1, }, - "segments": {"train": (start_time, train_end_time), "test": (test_start_time, end_time,),}, + "segments": { + "train": (start_time, train_end_time), + "test": ( + test_start_time, + end_time, + ), + }, }, }, } @@ -140,11 +152,24 @@ class HighfreqWorkflow(object): "start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00", }, - segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),}, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, ) dataset_backtest.init( - handler_kwargs={"start_time": "2021-01-19 00:00:00", "end_time": "2021-01-25 16:00:00",}, - segment_kwargs={"test": ("2021-01-19 00:00:00", "2021-01-25 16:00:00",),}, + handler_kwargs={ + "start_time": "2021-01-19 00:00:00", + "end_time": "2021-01-25 16:00:00", + }, + segment_kwargs={ + "test": ( + "2021-01-19 00:00:00", + "2021-01-25 16:00:00", + ), + }, ) ##=============get data============= diff --git a/examples/run_all_model.py b/examples/run_all_model.py index d356b4128..d587eff15 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -34,7 +34,10 @@ exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name) exp_manager = { "class": "MLflowExpManager", "module_path": "qlib.workflow.expm", - "kwargs": {"uri": "file:" + exp_path, "default_exp_name": "Experiment",}, + "kwargs": { + "uri": "file:" + exp_path, + "default_exp_name": "Experiment", + }, } if not exists_qlib_data(provider_uri): print(f"Qlib data is not found in {provider_uri}") diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 6f5c11dc0..d5dab8917 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -81,7 +81,10 @@ if __name__ == "__main__": "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.strategy", - "kwargs": {"topk": 50, "n_drop": 5,}, + "kwargs": { + "topk": 50, + "n_drop": 5, + }, }, "backtest": { "verbose": False, diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 24526e332..743f89462 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -39,7 +39,13 @@ class YahooData: INTERVAL_1d = "1d" def __init__( - self, timezone: str = None, start=None, end=None, interval="1d", delay=0, show_1min_logging: bool = False, + self, + timezone: str = None, + start=None, + end=None, + interval="1d", + delay=0, + show_1min_logging: bool = False, ): """ @@ -119,7 +125,11 @@ class YahooData: self._sleep() _remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval return self.get_data_from_remote( - symbol, interval=_remote_interval, start=start_, end=end_, show_1min_logging=self._show_1min_logging, + symbol, + interval=_remote_interval, + start=start_, + end=end_, + show_1min_logging=self._show_1min_logging, ) _result = None @@ -428,7 +438,9 @@ class YahooNormalize: DAILY_FORMAT = "%Y-%m-%d" def __init__( - self, date_field_name: str = "date", symbol_field_name: str = "symbol", + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): """ @@ -446,7 +458,10 @@ class YahooNormalize: @staticmethod def normalize_yahoo( - df: pd.DataFrame, calendar_list: list = None, date_field_name: str = "date", symbol_field_name: str = "symbol", + df: pd.DataFrame, + calendar_list: list = None, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): if df.empty: return df @@ -551,7 +566,9 @@ class YahooNormalize1min(YahooNormalize, ABC): CONSISTENT_1d = False def __init__( - self, date_field_name: str = "date", symbol_field_name: str = "symbol", + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): """ diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index ab24fa9ca..4811fd486 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -153,13 +153,22 @@ class DumpDataBase: @staticmethod def _read_calendars(calendar_path: Path) -> List[pd.Timestamp]: - return sorted(map(pd.Timestamp, pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(),)) + return sorted( + map( + pd.Timestamp, + pd.read_csv(calendar_path, header=None).loc[:, 0].tolist(), + ) + ) def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: df = pd.read_csv( instrument_path, sep=self.INSTRUMENTS_SEP, - names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD,], + names=[ + self.symbol_field_name, + self.INSTRUMENTS_START_FIELD, + self.INSTRUMENTS_END_FIELD, + ], ) return df diff --git a/setup.py b/setup.py index d8a9d9efa..83cf6e1b6 100644 --- a/setup.py +++ b/setup.py @@ -70,10 +70,16 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as f: # Cython Extensions extensions = [ Extension( - "qlib.data._libs.rolling", ["qlib/data/_libs/rolling.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE], + "qlib.data._libs.rolling", + ["qlib/data/_libs/rolling.pyx"], + language="c++", + include_dirs=[NUMPY_INCLUDE], ), Extension( - "qlib.data._libs.expanding", ["qlib/data/_libs/expanding.pyx"], language="c++", include_dirs=[NUMPY_INCLUDE], + "qlib.data._libs.expanding", + ["qlib/data/_libs/expanding.pyx"], + language="c++", + include_dirs=[NUMPY_INCLUDE], ), ] @@ -92,7 +98,9 @@ setup( # py_modules=['qlib'], entry_points={ # 'console_scripts': ['mycli=mymodule:cli'], - "console_scripts": ["qrun=qlib.workflow.cli:run",], + "console_scripts": [ + "qrun=qlib.workflow.cli:run", + ], }, ext_modules=extensions, install_requires=REQUIRED, diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 8b3819c83..f6e77cba4 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -78,7 +78,10 @@ port_analysis_config = { "strategy": { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.strategy", - "kwargs": {"topk": 50, "n_drop": 5,}, + "kwargs": { + "topk": 50, + "n_drop": 5, + }, }, "backtest": { "verbose": False, @@ -173,7 +176,9 @@ class TestAllFlow(TestAutoData): def test_1_backtest(self): analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID) self.assertGreaterEqual( - analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], 0.10, "backtest failed", + analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], + 0.10, + "backtest failed", ) diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py index de649c37e..dfa7f8556 100644 --- a/tests/test_dump_data.py +++ b/tests/test_dump_data.py @@ -40,7 +40,9 @@ class TestDumpData(unittest.TestCase): TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))) provider_uri = str(QLIB_DIR.resolve()) qlib.init( - provider_uri=provider_uri, expression_cache=None, dataset_cache=None, + provider_uri=provider_uri, + expression_cache=None, + dataset_cache=None, ) @classmethod @@ -52,7 +54,10 @@ class TestDumpData(unittest.TestCase): def test_1_dump_calendars(self): ori_calendars = set( - map(pd.Timestamp, pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,) + map( + pd.Timestamp, + pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values, + ) ) res_calendars = set(D.calendar()) assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed" diff --git a/tests/test_get_data.py b/tests/test_get_data.py index d5637b025..c511d1b91 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -26,7 +26,9 @@ class TestGetData(unittest.TestCase): def setUpClass(cls) -> None: provider_uri = str(QLIB_DIR.resolve()) qlib.init( - provider_uri=provider_uri, expression_cache=None, dataset_cache=None, + provider_uri=provider_uri, + expression_cache=None, + dataset_cache=None, ) @classmethod