From aef3f186c16ea3ea514710b915d6cd11cc9991f9 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Fri, 14 May 2021 06:58:02 +0000 Subject: [PATCH] format code --- examples/online_srv/online_management_simulate.py | 2 +- qlib/model/ens/ensemble.py | 5 +++-- qlib/model/ens/group.py | 2 +- qlib/model/trainer.py | 4 ++-- qlib/utils/serial.py | 1 - qlib/workflow/online/manager.py | 8 +++++--- qlib/workflow/online/strategy.py | 2 +- qlib/workflow/task/collect.py | 2 +- qlib/workflow/task/utils.py | 2 +- 9 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index c09b10aa7..4bb5022ee 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -113,7 +113,7 @@ class OnlineSimulationExample: self.rolling_gen = RollingGen( step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time. - self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR + self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR self.rolling_online_manager = OnlineManager( RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), trainer=self.trainer, diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 0f48ce728..4fa6a5ec6 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -16,9 +16,9 @@ class Ensemble: For example: {Rollinga_b: object, Rollingb_c: object} -> object When calling this class: - + Args: - ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging + ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging Returns: object: the ensemble object @@ -103,6 +103,7 @@ class AverageEnsemble(Ensemble): Returns: pd.DataFrame: the complete result of averaging and standardizing. """ + def __call__(self, ensemble_dict: dict) -> pd.DataFrame: # need to flatten the nested dict ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE) diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index 93903f433..7f45b06a5 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -64,7 +64,7 @@ class Group: else: raise NotImplementedError(f"Please specify valid `_ens_func`.") - def __call__(self, ungrouped_dict: dict, n_jobs:int=1, verbose:int=0, *args, **kwargs) -> dict: + def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict: """ Group the ungrouped_dict into different groups. diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 0c9c9e2c2..fd76e6728 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -240,13 +240,13 @@ class DelayTrainerR(TrainerR): """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. - + Args: recs (list): a list of Recorder, the tasks have been saved to them end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. experiment_name (str): the experiment name, None for use default name. kwargs: the params for end_train_func. - + Returns: List[Recorder]: a list of Recorders """ diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index c7c51bac2..263e632de 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -8,7 +8,6 @@ import dill from typing import Union - class Serializable: """ Serializable will change the behaviors of pickle. diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 7d1c723f3..f2a576560 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -41,8 +41,8 @@ class OnlineManager(Serializable): It also provides a history recording of which models are online at what time. """ - STATUS_SIMULATING = "simulating" # when calling `simulate` - STATUS_NORMAL = "normal" # the normal status + STATUS_SIMULATING = "simulating" # when calling `simulate` + STATUS_NORMAL = "normal" # the normal status def __init__( self, @@ -210,7 +210,9 @@ class OnlineManager(Serializable): SIM_LOG_LEVEL = logging.INFO + 1 SIM_LOG_NAME = "SIMULATE_INFO" - def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}) -> Union[pd.Series, pd.DataFrame]: + def simulate( + self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={} + ) -> Union[pd.Series, pd.DataFrame]: """ Starting from the current time, this method will simulate every routine in OnlineManager until the end time. diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 491b191dd..a54eb32bf 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -118,7 +118,7 @@ class RollingStrategy(OnlineStrategy): def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models. - + Assumption: the models can be distinguished based on the model name and rolling test segments. If you do not want this assumption, please implement your method or use another rec_key_func. diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 3a8bd1f2c..9410c2b9c 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -98,7 +98,7 @@ class MergeCollector(Collector): def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None): """ Init MergeCollector. - + Args: collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector} process_list (List[Callable]): the list of processors or the instance of processor to process dict. diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 89059e9f8..174b4b9bf 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -20,7 +20,7 @@ def get_mongodb() -> Database: """ Get database in MongoDB, which means you need to declare the address and the name of a database at first. - + For example: Using qlib.init():