mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
format code
This commit is contained in:
@@ -113,7 +113,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
|
||||
@@ -16,9 +16,9 @@ class Ensemble:
|
||||
For example: {Rollinga_b: object, Rollingb_c: object} -> object
|
||||
|
||||
When calling this class:
|
||||
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
|
||||
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
|
||||
|
||||
Returns:
|
||||
object: the ensemble object
|
||||
@@ -103,6 +103,7 @@ class AverageEnsemble(Ensemble):
|
||||
Returns:
|
||||
pd.DataFrame: the complete result of averaging and standardizing.
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
|
||||
@@ -64,7 +64,7 @@ class Group:
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `_ens_func`.")
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs:int=1, verbose:int=0, *args, **kwargs) -> dict:
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Group the ungrouped_dict into different groups.
|
||||
|
||||
|
||||
@@ -240,13 +240,13 @@ class DelayTrainerR(TrainerR):
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them
|
||||
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
@@ -8,7 +8,6 @@ import dill
|
||||
from typing import Union
|
||||
|
||||
|
||||
|
||||
class Serializable:
|
||||
"""
|
||||
Serializable will change the behaviors of pickle.
|
||||
|
||||
@@ -41,8 +41,8 @@ class OnlineManager(Serializable):
|
||||
It also provides a history recording of which models are online at what time.
|
||||
"""
|
||||
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -210,7 +210,9 @@ class OnlineManager(Serializable):
|
||||
SIM_LOG_LEVEL = logging.INFO + 1
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}) -> Union[pd.Series, pd.DataFrame]:
|
||||
def simulate(
|
||||
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ class RollingStrategy(OnlineStrategy):
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.
|
||||
|
||||
|
||||
Assumption: the models can be distinguished based on the model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your method or use another rec_key_func.
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class MergeCollector(Collector):
|
||||
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
|
||||
"""
|
||||
Init MergeCollector.
|
||||
|
||||
|
||||
Args:
|
||||
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
|
||||
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
|
||||
|
||||
@@ -20,7 +20,7 @@ def get_mongodb() -> Database:
|
||||
|
||||
"""
|
||||
Get database in MongoDB, which means you need to declare the address and the name of a database at first.
|
||||
|
||||
|
||||
For example:
|
||||
|
||||
Using qlib.init():
|
||||
|
||||
Reference in New Issue
Block a user