From c0e7cbc9830c1149c7ef0823553f3b15a0936df1 Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 26 Jan 2021 12:33:36 +0800 Subject: [PATCH] Add filter_pipe API --- qlib/contrib/data/handler.py | 9 ++++++++- qlib/data/dataset/loader.py | 11 ++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 23e37a5e4..378a7ccb4 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -54,6 +54,7 @@ class Alpha360(DataHandlerLP): learn_processors=_DEFAULT_LEARN_PROCESSORS, fit_start_time=None, fit_end_time=None, + filter_pipe=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -66,6 +67,7 @@ class Alpha360(DataHandlerLP): "feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config()), }, + "filter_pipe": filter_pipe, }, } @@ -138,6 +140,7 @@ class Alpha158(DataHandlerLP): fit_start_time=None, fit_end_time=None, process_type=DataHandlerLP.PTYPE_A, + filter_pipe=None, **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) @@ -146,7 +149,11 @@ class Alpha158(DataHandlerLP): data_loader = { "class": "QlibDataLoader", "kwargs": { - "config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())}, + "config": { + "feature": self.get_feature_config(), + "label": kwargs.get("label", self.get_label_config()), + }, + "filter_pipe": filter_pipe, }, } super().__init__( diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 3b33ff749..096369736 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -10,7 +10,9 @@ import pandas as pd from typing import Tuple, Union from qlib.data import D -from qlib.utils import load_dataset +from qlib.data import filter as filter_module +from qlib.data.filter import BaseDFilter +from qlib.utils import load_dataset, init_instance_by_config class DataLoader(abc.ABC): @@ -145,6 +147,13 @@ class QlibDataLoader(DLWParser): swap_level : Whether to swap level of MultiIndex """ + if filter_pipe is not None: + assert isinstance(filter_pipe, list), "The type of `filter_pipe` must be list." + filter_pipe = [ + init_instance_by_config(fp, None if "module_path" in fp else filter_module, accept_types=BaseDFilter) + for fp in filter_pipe + ] + self.filter_pipe = filter_pipe self.swap_level = swap_level super().__init__(config)