diff --git a/examples/benchmarks/LightGBM/features_sample.py b/examples/benchmarks/LightGBM/features_sample.py index 95bdca51d..e06637b84 100644 --- a/examples/benchmarks/LightGBM/features_sample.py +++ b/examples/benchmarks/LightGBM/features_sample.py @@ -1,9 +1,16 @@ import datetime import pandas as pd +from qlib.data.inst_processor import InstProcessor -def resample_feature(df: pd.DataFrame) -> pd.DataFrame: - df = df.droplevel(level="instrument") - df = df.loc[df.index.time == datetime.time(13, 1)] - df.index = df.index.normalize() - return df + +class ResampleProcessor(InstProcessor): + def __init__(self, freq: str, hour: int, minute: int): + self.freq = freq + self.hour = hour + self.minute = minute + + def __call__(self, df: pd.DataFrame, *args, **kwargs): + df = df.loc[df.index.time == datetime.time(self.hour, self.minute)] + df.index = df.index.normalize() + return df diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml index b4074368e..9bedf241d 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml @@ -1,6 +1,5 @@ qlib_init: - provider_uri: "~/.qlib/qlib_data/cn_data" - backend_freq_config: + provider_uri: day: "~/.qlib/qlib_data/cn_data" 1min: "~/.qlib/qlib_data/cn_data_1min" region: cn @@ -19,13 +18,14 @@ data_handler_config: &data_handler_config # with label as reference sample_benchmark: label sample_config: - # using pandas.DataFrame.resample - feature: resample("1d", level="datetime").last() - # or - # using custom function, df.groupby(level="instrument").apply() -# feature: -# module_path: features_sample.py -# func: resample_feature + feature: + - class: ResampleProcessor + moudle_path: features_sample.py + kwargs: + freq: 1d + hour: 13 + minute: 1 + port_analysis_config: &port_analysis_config strategy: class: TopkDropoutStrategy diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 4475a16f3..271343a01 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -384,7 +384,9 @@ class DatasetCache(BaseProviderCache): HDF_KEY = "df" - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): + def dataset( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + ): """Get feature dataset. .. note:: Same interface as `dataset` method in dataset provider @@ -395,13 +397,19 @@ class DatasetCache(BaseProviderCache): """ if disk_cache == 0: # skip cache - return self.provider.dataset(instruments, fields, start_time, end_time, freq) + return self.provider.dataset( + instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + ) else: # use and replace cache try: - return self._dataset(instruments, fields, start_time, end_time, freq, disk_cache) + return self._dataset( + instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors + ) except NotImplementedError: - return self.provider.dataset(instruments, fields, start_time, end_time, freq) + return self.provider.dataset( + instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + ) def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs): """Get dataset cache file uri. @@ -410,14 +418,18 @@ class DatasetCache(BaseProviderCache): """ raise NotImplementedError("Implement this function to match your own cache mechanism") - def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): + def _dataset( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + ): """Get feature dataset using cache. Override this method to define how to get feature dataset corresponding to users' own cache mechanism. """ raise NotImplementedError("Implement this method if you want to use dataset feature cache") - def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): + def _dataset_uri( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + ): """Get a uri of feature dataset using cache. specially: disk_cache=1 means using data set cache and return the uri of cache file. @@ -639,8 +651,8 @@ class DiskDatasetCache(DatasetCache): self.remote = kwargs.get("remote", False) @staticmethod - def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): - return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache) + def _uri(instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs): + return hash_args(*DatasetCache.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors) def get_cache_dir(self, freq: str = None) -> Path: return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq) @@ -679,14 +691,29 @@ class DiskDatasetCache(DatasetCache): df = pd.DataFrame(columns=fields) return df - def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0): + def _dataset( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + ): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. - return self.provider.dataset(instruments, fields, start_time, end_time, freq) - + return self.provider.dataset( + instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + ) + # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date + if not inst_processors: + raise ValueError( + f"{self.__class__.__name__} does not support inst_processor. " + f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`" + ) _cache_uri = self._uri( - instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache + instruments=instruments, + fields=fields, + start_time=None, + end_time=None, + freq=freq, + disk_cache=disk_cache, + inst_processors=inst_processors, ) cache_path = self.get_cache_dir(freq).joinpath(_cache_uri) @@ -709,13 +736,19 @@ class DiskDatasetCache(DatasetCache): # cache unavailable, generate the cache with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): features = self.gen_dataset_cache( - cache_path=cache_path, instruments=instruments, fields=fields, freq=freq + cache_path=cache_path, + instruments=instruments, + fields=fields, + freq=freq, + inst_processors=inst_processors, ) if not features.empty: features = features.sort_index().loc(axis=0)[:, start_time:end_time] return features - def _dataset_uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0): + def _dataset_uri( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + ): if disk_cache == 0: # In this case, server only checks the expression cache. # The client will load the cache data by itself. @@ -723,9 +756,20 @@ class DiskDatasetCache(DatasetCache): LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq) return "" - + # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date + if not inst_processors: + raise ValueError( + f"{self.__class__.__name__} does not support inst_processor. " + f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`" + ) _cache_uri = self._uri( - instruments=instruments, fields=fields, start_time=None, end_time=None, freq=freq, disk_cache=disk_cache + instruments=instruments, + fields=fields, + start_time=None, + end_time=None, + freq=freq, + disk_cache=disk_cache, + inst_processors=inst_processors, ) cache_path = self.get_cache_dir(freq).joinpath(_cache_uri) @@ -737,7 +781,13 @@ class DiskDatasetCache(DatasetCache): else: # cache unavailable, generate the cache with CacheUtils.writer_lock(self.r, f"{str(C.get_data_path(freq))}:dataset-{_cache_uri}"): - self.gen_dataset_cache(cache_path=cache_path, instruments=instruments, fields=fields, freq=freq) + self.gen_dataset_cache( + cache_path=cache_path, + instruments=instruments, + fields=fields, + freq=freq, + inst_processors=inst_processors, + ) return _cache_uri class IndexManager: @@ -804,7 +854,7 @@ class DiskDatasetCache(DatasetCache): index_data += start_index return index_data - def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq): + def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=None): """gen_dataset_cache .. note:: This function does not consider the cache read write lock. Please @@ -839,6 +889,7 @@ class DiskDatasetCache(DatasetCache): :param instruments: The instruments to store the cache. :param fields: The fields to store the cache. :param freq: The freq to store the cache. + :param inst_processors: Instrument processors. :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function. """ @@ -852,7 +903,9 @@ class DiskDatasetCache(DatasetCache): # while running self.clear_cache(cache_path) - features = self.provider.dataset(instruments, fields, _calendar[0], _calendar[-1], freq) + features = self.provider.dataset( + instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors + ) if features.empty: return features @@ -877,6 +930,7 @@ class DiskDatasetCache(DatasetCache): "fields": cache_columns, "freq": freq, "last_update": str(_calendar[-1]), # The last_update to store the cache + "inst_processors": inst_processors, # The last_update to store the cache }, "meta": {"last_visit": time.time(), "visits": 1}, } @@ -913,6 +967,7 @@ class DiskDatasetCache(DatasetCache): fields = d["info"]["fields"] freq = d["info"]["freq"] last_update_time = d["info"]["last_update"] + inst_processors = d["info"]["inst_processors"] index_data = im.get_index() self.logger.debug("Updating dataset: {}".format(d)) @@ -963,7 +1018,12 @@ class DiskDatasetCache(DatasetCache): ) data = self.provider.dataset( - instruments, fields, whole_calendar[current_index - rm_n_period], new_calendar[-1], freq + instruments, + fields, + whole_calendar[current_index - rm_n_period], + new_calendar[-1], + freq, + inst_processors=inst_processors, ) if not data.empty: @@ -1013,17 +1073,23 @@ class SimpleDatasetCache(DatasetCache): except KeyError as e: self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism") - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs): instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq) - return hash_args(instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path)) + return hash_args( + instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors + ) - def _dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): + def _dataset( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=None + ): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. return self.provider.dataset(instruments, fields, start_time, end_time, freq) self.local_cache_path.mkdir(exist_ok=True, parents=True) cache_file = self.local_cache_path.joinpath( - self._uri(instruments, fields, start_time, end_time, freq, disk_cache=disk_cache) + self._uri( + instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors + ) ) gen_flag = False @@ -1039,7 +1105,9 @@ class SimpleDatasetCache(DatasetCache): gen_flag = True if gen_flag: - data = self.provider.dataset(instruments, normalize_cache_fields(fields), start_time, end_time, freq) + data = self.provider.dataset( + instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors + ) data.to_pickle(cache_file) return self.cache_to_origin_data(data, fields) @@ -1047,26 +1115,53 @@ class SimpleDatasetCache(DatasetCache): class DatasetURICache(DatasetCache): """Prepared cache mechanism for server.""" - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs): - return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache) + def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=None, **kwargs): + return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors) - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0): + def dataset( + self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=None + ): if "local" in C.dataset_provider.lower(): # use LocalDatasetProvider - return self.provider.dataset(instruments, fields, start_time, end_time, freq) + return self.provider.dataset( + instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + ) if disk_cache == 0: # do not use data_set cache, load data from remote expression cache directly - return self.provider.dataset(instruments, fields, start_time, end_time, freq, disk_cache, return_uri=False) - + return self.provider.dataset( + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + return_uri=False, + inst_processors=inst_processors, + ) + # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date + if not inst_processors: + raise ValueError( + f"{self.__class__.__name__} does not support inst_processor. " + f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`" + ) # use ClientDatasetProvider - feature_uri = self._uri(instruments, fields, None, None, freq, disk_cache=disk_cache) + feature_uri = self._uri( + instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors + ) value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) mnt_feature_uri = C.get_data_path(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri) if value is None or expire or not mnt_feature_uri.exists(): df, uri = self.provider.dataset( - instruments, fields, start_time, end_time, freq, disk_cache, return_uri=True + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + return_uri=True, + inst_processors=inst_processors, ) # cache uri MemCacheExpire.set_cache(H["f"], uri, uri) diff --git a/qlib/data/data.py b/qlib/data/data.py index 82f8c3745..31c4ac6a8 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -13,15 +13,26 @@ import bisect import numpy as np import pandas as pd from multiprocessing import Pool +from typing import Iterable, Union from .cache import H from ..config import C -from .ops import Operators -from ..log import get_module_logger -from ..utils import parse_field, hash_args, normalize_cache_fields, code_to_fname from .base import Feature +from .ops import Operators +from .inst_processor import InstProcessor + +from ..log import get_module_logger from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path +from ..utils import ( + Wrapper, + init_instance_by_config, + register_wrapper, + get_module_by_module_path, + parse_field, + hash_args, + normalize_cache_fields, + code_to_fname, +) class ProviderBackendMixin: @@ -342,7 +353,7 @@ class DatasetProvider(abc.ABC): """ @abc.abstractmethod - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"): + def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=None): """Get dataset data. Parameters @@ -357,6 +368,8 @@ class DatasetProvider(abc.ABC): end of the time range. freq : str time frequency. + inst_processors: Iterable[Union[dict, InstProcessor]] + the operations performed on each instrument Returns ---------- @@ -373,6 +386,7 @@ class DatasetProvider(abc.ABC): end_time=None, freq="day", disk_cache=1, + inst_processors=None, **kwargs, ): """Get task uri, used when generating rabbitmq task in qlib_server @@ -393,7 +407,8 @@ class DatasetProvider(abc.ABC): whether to skip(0)/use(1)/replace(2) disk_cache. """ - return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache) + # TODO: qlib-server support inst_processors + return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors) @staticmethod def get_instruments_d(instruments, freq): @@ -434,7 +449,7 @@ class DatasetProvider(abc.ABC): return [ExpressionD.get_expression_instance(f) for f in fields] @staticmethod - def dataset_processor(instruments_d, column_names, start_time, end_time, freq): + def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=None): """ Load and process the data, return the data set. - default using multi-kernel method. @@ -460,6 +475,7 @@ class DatasetProvider(abc.ABC): normalize_column_names, spans, C, + inst_processors, ), ) else: @@ -474,6 +490,7 @@ class DatasetProvider(abc.ABC): normalize_column_names, None, C, + inst_processors, ), ) @@ -495,7 +512,9 @@ class DatasetProvider(abc.ABC): return data @staticmethod - def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None): + def expression_calculator( + inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=None + ): """ Calculate the expressions for one instrument, return a df result. If the expression has been calculated before, load from cache. @@ -519,13 +538,17 @@ class DatasetProvider(abc.ABC): data.index = _calendar[data.index.values.astype(int)] data.index.names = ["datetime"] - if spans is None: - return data - else: + if spans is not None: mask = np.zeros(len(data), dtype=bool) for begin, end in spans: mask |= (data.index >= begin) & (data.index <= end) - return data[mask] + data = data[mask] + + if inst_processors is not None: + for _processor in inst_processors if isinstance(inst_processors, (list, tuple, set)) else [inst_processors]: + _processor = init_instance_by_config(_processor, accept_types=InstProcessor) + data = _processor(data) + return data class LocalCalendarProvider(CalendarProvider): @@ -689,7 +712,15 @@ class LocalDatasetProvider(DatasetProvider): def __init__(self): pass - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day"): + def dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + inst_processors=None, + ): instruments_d = self.get_instruments_d(instruments, freq) column_names = self.get_column_names(fields) cal = Cal.calendar(start_time, end_time, freq) @@ -698,7 +729,9 @@ class LocalDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + data = self.dataset_processor( + instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors + ) return data @@ -841,6 +874,7 @@ class ClientDatasetProvider(DatasetProvider): freq="day", disk_cache=0, return_uri=False, + inst_processors=None, ): if Inst.get_inst_type(instruments) == Inst.DICT: get_module_logger("data").warning( @@ -893,6 +927,13 @@ class ClientDatasetProvider(DatasetProvider): - using single-process implementation. """ + # TODO: support inst_processors, need to change the code of qlib-server at the same time + # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date + if not inst_processors: + raise ValueError( + f"{self.__class__.__name__} does not support inst_processor. " + f"Please use `D.features(disk_cache=0)` or `qlib.init(dataset_cache=None)`" + ) self.conn.send_request( request_type="feature", request_content={ @@ -950,6 +991,7 @@ class BaseProvider: end_time=None, freq="day", disk_cache=None, + inst_processors=None, ): """ Parameters: @@ -964,9 +1006,11 @@ class BaseProvider: disk_cache = C.default_disk_cache if disk_cache is None else disk_cache fields = list(fields) # In case of tuple. try: - return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache) + return DatasetD.dataset( + instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors + ) except TypeError: - return DatasetD.dataset(instruments, fields, start_time, end_time, freq) + return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors) class LocalProvider(BaseProvider): diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index beef5c9fb..0b77d3a1d 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -181,7 +181,7 @@ class QlibDataLoader(DLWParser): self.freq = freq # sample - self.sample_config = sample_config + self.sample_config = sample_config if sample_config is not None else {} self.sample_benchmark = sample_benchmark self.can_sample = False super().__init__(config) @@ -192,30 +192,12 @@ class QlibDataLoader(DLWParser): for _gp in config.keys(): if _gp not in freq: raise ValueError(f"freq(={freq}) missing group(={_gp})") - if len(set(freq.values())) == 1: - self.freq = list(freq.values())[0] - else: - assert self.sample_config, f"freq(={self.freq}), sample_config cannot be None/empty" - assert isinstance(self.sample_config, dict), f"sample_config(={self.sample_config}) must be dict" - assert ( - self.sample_benchmark and self.sample_benchmark in self.fields - ), f"sample_benchmark not to specification" - self.can_sample = True - - def _get_sample_method(self, gp_name: str) -> Union[str, Type]: - _method = self.sample_config.get(gp_name, None) - if _method is None: - return _method - if isinstance(_method, str): - # pandas.DataFrame.resample - if not _method.startswith("resample"): - raise ValueError(f"sample method error, only pandas.DataFrame.resample is supported") - elif isinstance(_method, dict): - # module_path && func name - _method, _ = get_callable_kwargs(_method) - else: - raise TypeError(f"sample_method only supports [str, dict], currently it is {_method}") - return _method + assert sample_config, f"freq(={self.freq}), sample_config(={sample_config}) cannot be None/empty" + assert isinstance(sample_config, dict), f"sample_config(={sample_config}) must be dict" + assert ( + self.sample_benchmark and self.sample_benchmark in self.fields + ), f"sample_benchmark not to specification" + self.can_sample = True def load_group_df( self, @@ -235,17 +217,10 @@ class QlibDataLoader(DLWParser): warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") freq = self.freq[gp_name] if self.can_sample else self.freq - df = D.features(instruments, exprs, start_time, end_time, freq) + df = D.features( + instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.sample_config.get(freq, None) + ) df.columns = names - - if self.can_sample and self.sample_benchmark != gp_name: - sample_method = self._get_sample_method(gp_name) - if sample_method is None: - warnings.warn(f"{gp_name} sample_method is None") - if isinstance(sample_method, str): - df = eval(f"df.groupby(level='instrument').{sample_method}") - else: - df = df.groupby(level="instrument").apply(sample_method) if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df @@ -256,11 +231,13 @@ class QlibDataLoader(DLWParser): grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp) for grp, (exprs, names) in self.fields.items() } - for grp, _df in group.items(): - if grp == self.sample_benchmark: - continue - else: - group[grp] = _df.reindex(group[self.sample_benchmark].index) + if self.can_sample: + # reindex: alignment to index of sample_benchmark + for grp, _df in group.items(): + if grp == self.sample_benchmark: + continue + else: + group[grp] = _df.reindex(group[self.sample_benchmark].index) df = pd.concat(group, axis=1) else: exprs, names = self.fields diff --git a/qlib/data/inst_processor.py b/qlib/data/inst_processor.py new file mode 100644 index 000000000..34c2d517b --- /dev/null +++ b/qlib/data/inst_processor.py @@ -0,0 +1,30 @@ +import abc +import pandas as pd + + +class InstProcessor: + @abc.abstractmethod + def __call__(self, df: pd.DataFrame, *args, **kwargs): + """ + process the data + + NOTE: **The processor could change the content of `df` inplace !!!!! ** + User should keep a copy of data outside + + Parameters + ---------- + df : pd.DataFrame + The raw_df of handler or result from previous processor. + """ + pass + + +class ResampleProcessor(InstProcessor): + """resample data""" + + def __init__(self, freq: str, func: str, *args, **kwargs): + self.freq = freq + self.func = func + + def __call__(self, df: pd.DataFrame, *args, **kwargs): + return getattr(df.resample(self.freq, level="datetime"), self.func)().dropna(how="all") diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 4189f8e61..f801b125c 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -1,10 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from pathlib import Path import pickle -import typing import dill +from pathlib import Path from typing import Union @@ -18,6 +17,7 @@ class Serializable: pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python. default_dump_all = False # if dump all things + FLAG_KEY = "_qlib_serial_flag" def __init__(self): self._dump_all = self.default_dump_all @@ -45,8 +45,6 @@ class Serializable: """ return getattr(self, "_exclude", []) - FLAG_KEY = "_qlib_serial_flag" - def config(self, dump_all: bool = None, exclude: list = None, recursive=False): """ configure the serializable object