From d44c5bb2b2bfbffccb1959ea1e376f8c0dcb71b9 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Wed, 20 Jan 2021 21:14:03 +0800 Subject: [PATCH 1/5] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 735e080a0..d3f8170bc 100644 --- a/README.md +++ b/README.md @@ -312,6 +312,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into # Related Reports +- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/) - [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ) - [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ) - [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ) From fc81a393170306d508329f8d44b0b588c5d9ca9c Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 20 Jan 2021 11:33:29 +0000 Subject: [PATCH 2/5] Add dataset standalone usage example --- examples/workflow_by_code.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index b8cf3f935..ea9c70083 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -106,6 +106,11 @@ if __name__ == "__main__": model = init_instance_by_config(task["model"]) dataset = init_instance_by_config(task["dataset"]) + # NOTE: This line is optional + # It demonstrates that the dataset can be used standalone. + example_df = dataset.prepare("train") + print(example_df.head()) + # start exp with R.start(experiment_name="workflow"): R.log_params(**flatten_dict(task)) From e85646762cd4eb05009a3564197b614f39c004ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E9=9B=AA?= Date: Wed, 20 Jan 2021 21:58:26 +0800 Subject: [PATCH 3/5] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 0af365d7b..5b3745a02 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.pyc +*.pyd *.so *.ipynb .ipynb_checkpoints From 5ad1b4cc3376ca88ba90e12dc4131d4a67519eaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E9=9B=AA?= Date: Wed, 20 Jan 2021 19:52:41 +0800 Subject: [PATCH 4/5] for IDE auto-complete with global Wrapper R, D, Cal, Inst, FeatureD, ExpressionD, DatasetD, D --- qlib/data/data.py | 200 +++++++++++++++++++++++++++++--------- qlib/workflow/__init__.py | 15 ++- 2 files changed, 166 insertions(+), 49 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 3dcb22699..671f9ce58 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -25,7 +25,12 @@ from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path +from ..utils import ( + Wrapper, + init_instance_by_config, + register_wrapper, + get_module_by_module_path, +) class CalendarProvider(abc.ABC): @@ -54,7 +59,9 @@ class CalendarProvider(abc.ABC): list calendar list """ - raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method") + raise NotImplementedError( + "Subclass of CalendarProvider must implement `calendar` method" + ) def locate_index(self, start_time, end_time, freq, future): """Locate the start time index and end time index in a calendar under certain frequency. @@ -176,7 +183,9 @@ class InstrumentProvider(abc.ABC): return config @abc.abstractmethod - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): """List the instruments based on a certain stockpool config. Parameters @@ -195,9 +204,13 @@ class InstrumentProvider(abc.ABC): dict or list instruments list or dictionary with time spans """ - raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method") + raise NotImplementedError( + "Subclass of InstrumentProvider must implement `list_instruments` method" + ) - def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def _uri( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): return hash_args(instruments, start_time, end_time, freq, as_list) # instruments type @@ -221,10 +234,16 @@ class InstrumentProvider(abc.ABC): _df_list = [] # FIXME: each process will read these files for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"): - _df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + _df = pd.read_csv( + _path, + sep="\t", + names=["inst", "start_datetime", "end_datetime", "save_inst"], + ) _df_list.append(_df.iloc[:, [0, -1]]) df = pd.concat(_df_list, sort=False).sort_values("save_inst") - df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill") + df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna( + axis=1, method="ffill" + ) _instruments_map = df.set_index("inst").iloc[:, 0].to_dict() setattr(self, "_instruments_map", _instruments_map) return _instruments_map.get(instrument, instrument) @@ -258,7 +277,9 @@ class FeatureProvider(abc.ABC): pd.Series data of a certain feature """ - raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") + raise NotImplementedError( + "Subclass of FeatureProvider must implement `feature` method" + ) class ExpressionProvider(abc.ABC): @@ -279,11 +300,14 @@ class ExpressionProvider(abc.ABC): self.expression_instance_cache[field] = expression except NameError as e: get_module_logger("data").exception( - "ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1]) + "ERROR: field [%s] contains invalid operator/variable [%s]" + % (str(field), str(e).split()[1]) ) raise except SyntaxError: - get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field)) + get_module_logger("data").exception( + "ERROR: field [%s] contains invalid syntax" % str(field) + ) raise return expression @@ -309,7 +333,9 @@ class ExpressionProvider(abc.ABC): pd.Series data of a certain expression """ - raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") + raise NotImplementedError( + "Subclass of ExpressionProvider must implement `Expression` method" + ) class DatasetProvider(abc.ABC): @@ -340,7 +366,9 @@ class DatasetProvider(abc.ABC): pd.DataFrame a pandas dataframe with index. """ - raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") + raise NotImplementedError( + "Subclass of DatasetProvider must implement `Dataset` method" + ) def _uri( self, @@ -370,7 +398,9 @@ class DatasetProvider(abc.ABC): whether to skip(0)/use(1)/replace(2) disk_cache. """ - return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache) + return DiskDatasetCache._uri( + instruments, fields, start_time, end_time, freq, disk_cache + ) @staticmethod def get_instruments_d(instruments, freq): @@ -382,7 +412,9 @@ class DatasetProvider(abc.ABC): if isinstance(instruments, dict): if "market" in instruments: # dict of stockpool config - instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False) + instruments_d = Inst.list_instruments( + instruments=instruments, freq=freq, as_list=False + ) else: # dict of instruments and timestamp instruments_d = instruments @@ -472,7 +504,9 @@ class DatasetProvider(abc.ABC): return data @staticmethod - def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None): + def expression_calculator( + inst, start_time, end_time, freq, column_names, spans=None, g_config=None + ): """ Calculate the expressions for one instrument, return a df result. If the expression has been calculated before, load from cache. @@ -537,7 +571,9 @@ class LocalCalendarProvider(CalendarProvider): fname = self._uri_cal.format(freq + "_future") # if future calendar not exists, return current calendar if not os.path.exists(fname): - get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + get_module_logger("data").warning( + f"{freq}_future.txt not exists, return current calendar!" + ) fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) @@ -588,14 +624,20 @@ class LocalInstrumentProvider(InstrumentProvider): if not os.path.exists(fname): raise ValueError("instruments not exists for market " + market) _instruments = dict() - df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"]) + df = pd.read_csv( + fname, + sep="\t", + names=["inst", "start_datetime", "end_datetime", "save_inst"], + ) df["start_datetime"] = pd.to_datetime(df["start_datetime"]) df["end_datetime"] = pd.to_datetime(df["end_datetime"]) for row in df.itertuples(index=False): _instruments.setdefault(row[0], []).append((row[1], row[2])) return _instruments - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] @@ -616,14 +658,20 @@ class LocalInstrumentProvider(InstrumentProvider): ) for inst, spans in _instruments.items() } - _instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value} + _instruments_filtered = { + key: value for key, value in _instruments_filtered.items() if value + } # filter filter_pipe = instruments["filter_pipe"] for filter_config in filter_pipe: from . import filter as F - filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config) - _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq) + filter_t = getattr(F, filter_config["filter_type"]).from_config( + filter_config + ) + _instruments_filtered = filter_t( + _instruments_filtered, start_time, end_time, freq + ) # as list if as_list: return list(_instruments_filtered) @@ -650,7 +698,9 @@ class LocalFeatureProvider(FeatureProvider): instrument = Inst.convert_instruments(instrument) uri_data = self._uri_data.format(instrument.lower(), field, freq) if not os.path.exists(uri_data): - get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) + get_module_logger("data").warning( + "WARN: data not found for %s.%s" % (instrument, field) + ) return pd.Series(dtype=np.float32) # raise ValueError('uri_data not found: ' + uri_data) # load @@ -671,9 +721,13 @@ class LocalExpressionProvider(ExpressionProvider): expression = self.get_expression_instance(field) start_time = pd.Timestamp(start_time) end_time = pd.Timestamp(end_time) - _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) + _, _, start_index, end_index = Cal.locate_index( + start_time, end_time, freq, future=False + ) lft_etd, rght_etd = expression.get_extended_window_size() - series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) + series = expression.load( + instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq + ) # Ensure that each column type is consistent # FIXME: # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented. @@ -705,12 +759,16 @@ class LocalDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + data = self.dataset_processor( + instruments_d, column_names, start_time, end_time, freq + ) return data @staticmethod - def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"): + def multi_cache_walker( + instruments, fields, start_time=None, end_time=None, freq="day" + ): """ This method is used to prepare the expression cache for the client. Then the client will load the data from expression cache by itself. @@ -778,7 +836,9 @@ class ClientCalendarProvider(CalendarProvider): "future": future, }, msg_queue=self.queue, - msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content], + msg_proc_func=lambda response_content: [ + pd.Timestamp(c) for c in response_content + ], ) result = self.queue.get(timeout=C["timeout"]) return result @@ -797,11 +857,14 @@ class ClientInstrumentProvider(InstrumentProvider): def set_conn(self, conn): self.conn = conn - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): def inst_msg_proc_func(response_content): if isinstance(response_content, dict): instrument = { - i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items() + i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] + for i, t in response_content.items() } else: instrument = response_content @@ -887,7 +950,9 @@ class ClientDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) + data = self.dataset_processor( + instruments_d, column_names, start_time, end_time, freq + ) if return_uri: return data, feature_uri else: @@ -919,8 +984,12 @@ class ClientDatasetProvider(DatasetProvider): get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) - df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) + mnt_feature_uri = os.path.join( + C.get_data_path(), C.dataset_cache_dir_name, feature_uri + ) + df = DiskDatasetCache.read_data_from_cache( + mnt_feature_uri, start_time, end_time, fields + ) get_module_logger("data").debug("finish slicing data") if return_uri: return df, feature_uri @@ -938,7 +1007,9 @@ class BaseProvider: def calendar(self, start_time=None, end_time=None, freq="day", future=False): return Cal.calendar(start_time, end_time, freq, future=future) - def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None): + def instruments( + self, market="all", filter_pipe=None, start_time=None, end_time=None + ): if start_time is not None or end_time is not None: get_module_logger("Provider").warning( "The instruments corresponds to a stock pool. " @@ -946,7 +1017,9 @@ class BaseProvider: ) return InstrumentProvider.instruments(market, filter_pipe) - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): return Inst.list_instruments(instruments, start_time, end_time, freq, as_list) def features( @@ -972,7 +1045,9 @@ class BaseProvider: if C.disable_disk_cache: disk_cache = False try: - return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache) + return DatasetD.dataset( + instruments, fields, start_time, end_time, freq, disk_cache + ) except TypeError: return DatasetD.dataset(instruments, fields, start_time, end_time, freq) @@ -993,7 +1068,9 @@ class LocalProvider(BaseProvider): elif type == "feature": return DatasetD._uri(**kwargs) - def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1): + def features_uri( + self, instruments, fields, start_time, end_time, freq, disk_cache=1 + ): """features_uri Return the uri of the generated cache of features/dataset @@ -1005,7 +1082,9 @@ class LocalProvider(BaseProvider): :param end_time: :param freq: """ - return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache) + return DatasetD._dataset_uri( + instruments, fields, start_time, end_time, freq, disk_cache + ) class ClientProvider(BaseProvider): @@ -1035,12 +1114,31 @@ class ClientProvider(BaseProvider): DatasetD.set_conn(self.client) -Cal = Wrapper() -Inst = Wrapper() -FeatureD = Wrapper() -ExpressionD = Wrapper() -DatasetD = Wrapper() -D = Wrapper() +import sys + +if sys.version_info >= (3, 9): + from typing import Annotated + + CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper] + InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper] + FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper] + ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper] + DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper] + BaseProviderWrapper = Annotated[BaseProvider, Wrapper] +else: + CalendarProviderWrapper = CalendarProvider + InstrumentProviderWrapper = InstrumentProvider + FeatureProviderWrapper = FeatureProvider + ExpressionProviderWrapper = ExpressionProvider + DatasetProviderWrapper = DatasetProvider + BaseProviderWrapper = BaseProvider + +Cal: CalendarProviderWrapper = Wrapper() +Inst: InstrumentProviderWrapper = Wrapper() +FeatureD: FeatureProviderWrapper = Wrapper() +ExpressionD: ExpressionProviderWrapper = Wrapper() +DatasetD: DatasetProviderWrapper = Wrapper() +D: BaseProviderWrapper = Wrapper() def register_all_wrappers(): @@ -1050,7 +1148,9 @@ def register_all_wrappers(): _calendar_provider = init_instance_by_config(C.calendar_provider, module) if getattr(C, "calendar_cache", None) is not None: - _calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider) + _calendar_provider = init_instance_by_config( + C.calendar_cache, module, provide=_calendar_provider + ) register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") @@ -1066,13 +1166,19 @@ def register_all_wrappers(): # This provider is unnecessary in client provider _eprovider = init_instance_by_config(C.expression_provider, module) if getattr(C, "expression_cache", None) is not None: - _eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider) + _eprovider = init_instance_by_config( + C.expression_cache, module, provider=_eprovider + ) register_wrapper(ExpressionD, _eprovider, "qlib.data") - logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}") + logger.debug( + f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}" + ) _dprovider = init_instance_by_config(C.dataset_provider, module) if getattr(C, "dataset_cache", None) is not None: - _dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider) + _dprovider = init_instance_by_config( + C.dataset_cache, module, provider=_dprovider + ) register_wrapper(DatasetD, _dprovider, "qlib.data") logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}") diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index e65bfb03f..24e9cd22c 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -38,7 +38,9 @@ class QlibRecorder: try: yield run except Exception as e: - self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong + self.end_exp( + Recorder.STATUS_FA + ) # end the experiment if something went wrong raise e self.end_exp(Recorder.STATUS_FI) @@ -461,5 +463,14 @@ class QlibRecorder: self.get_exp().get_recorder().set_tags(**kwargs) +import sys + +if sys.version_info >= (3, 9): + from typing import Annotated + + QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper] +else: + QlibRecorderWrapper = QlibRecorder + # global record -R = Wrapper() +R: QlibRecorderWrapper = Wrapper() From 784e73bceb91c34ad61dbf4fb6ec9a6e0b1523f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E9=9B=AA?= Date: Wed, 20 Jan 2021 20:29:59 +0800 Subject: [PATCH 5/5] black formatting --- qlib/data/data.py | 150 ++++++++++---------------------------- qlib/workflow/__init__.py | 4 +- 2 files changed, 39 insertions(+), 115 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 671f9ce58..5dad558e7 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -59,9 +59,7 @@ class CalendarProvider(abc.ABC): list calendar list """ - raise NotImplementedError( - "Subclass of CalendarProvider must implement `calendar` method" - ) + raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method") def locate_index(self, start_time, end_time, freq, future): """Locate the start time index and end time index in a calendar under certain frequency. @@ -183,9 +181,7 @@ class InstrumentProvider(abc.ABC): return config @abc.abstractmethod - def list_instruments( - self, instruments, start_time=None, end_time=None, freq="day", as_list=False - ): + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): """List the instruments based on a certain stockpool config. Parameters @@ -204,13 +200,9 @@ class InstrumentProvider(abc.ABC): dict or list instruments list or dictionary with time spans """ - raise NotImplementedError( - "Subclass of InstrumentProvider must implement `list_instruments` method" - ) + raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method") - def _uri( - self, instruments, start_time=None, end_time=None, freq="day", as_list=False - ): + def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): return hash_args(instruments, start_time, end_time, freq, as_list) # instruments type @@ -241,9 +233,7 @@ class InstrumentProvider(abc.ABC): ) _df_list.append(_df.iloc[:, [0, -1]]) df = pd.concat(_df_list, sort=False).sort_values("save_inst") - df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna( - axis=1, method="ffill" - ) + df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill") _instruments_map = df.set_index("inst").iloc[:, 0].to_dict() setattr(self, "_instruments_map", _instruments_map) return _instruments_map.get(instrument, instrument) @@ -277,9 +267,7 @@ class FeatureProvider(abc.ABC): pd.Series data of a certain feature """ - raise NotImplementedError( - "Subclass of FeatureProvider must implement `feature` method" - ) + raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") class ExpressionProvider(abc.ABC): @@ -300,14 +288,11 @@ class ExpressionProvider(abc.ABC): self.expression_instance_cache[field] = expression except NameError as e: get_module_logger("data").exception( - "ERROR: field [%s] contains invalid operator/variable [%s]" - % (str(field), str(e).split()[1]) + "ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1]) ) raise except SyntaxError: - get_module_logger("data").exception( - "ERROR: field [%s] contains invalid syntax" % str(field) - ) + get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field)) raise return expression @@ -333,9 +318,7 @@ class ExpressionProvider(abc.ABC): pd.Series data of a certain expression """ - raise NotImplementedError( - "Subclass of ExpressionProvider must implement `Expression` method" - ) + raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") class DatasetProvider(abc.ABC): @@ -366,9 +349,7 @@ class DatasetProvider(abc.ABC): pd.DataFrame a pandas dataframe with index. """ - raise NotImplementedError( - "Subclass of DatasetProvider must implement `Dataset` method" - ) + raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") def _uri( self, @@ -398,9 +379,7 @@ class DatasetProvider(abc.ABC): whether to skip(0)/use(1)/replace(2) disk_cache. """ - return DiskDatasetCache._uri( - instruments, fields, start_time, end_time, freq, disk_cache - ) + return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache) @staticmethod def get_instruments_d(instruments, freq): @@ -412,9 +391,7 @@ class DatasetProvider(abc.ABC): if isinstance(instruments, dict): if "market" in instruments: # dict of stockpool config - instruments_d = Inst.list_instruments( - instruments=instruments, freq=freq, as_list=False - ) + instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False) else: # dict of instruments and timestamp instruments_d = instruments @@ -504,9 +481,7 @@ class DatasetProvider(abc.ABC): return data @staticmethod - def expression_calculator( - inst, start_time, end_time, freq, column_names, spans=None, g_config=None - ): + def expression_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None): """ Calculate the expressions for one instrument, return a df result. If the expression has been calculated before, load from cache. @@ -571,9 +546,7 @@ class LocalCalendarProvider(CalendarProvider): fname = self._uri_cal.format(freq + "_future") # if future calendar not exists, return current calendar if not os.path.exists(fname): - get_module_logger("data").warning( - f"{freq}_future.txt not exists, return current calendar!" - ) + get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) @@ -635,9 +608,7 @@ class LocalInstrumentProvider(InstrumentProvider): _instruments.setdefault(row[0], []).append((row[1], row[2])) return _instruments - def list_instruments( - self, instruments, start_time=None, end_time=None, freq="day", as_list=False - ): + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] @@ -658,20 +629,14 @@ class LocalInstrumentProvider(InstrumentProvider): ) for inst, spans in _instruments.items() } - _instruments_filtered = { - key: value for key, value in _instruments_filtered.items() if value - } + _instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value} # filter filter_pipe = instruments["filter_pipe"] for filter_config in filter_pipe: from . import filter as F - filter_t = getattr(F, filter_config["filter_type"]).from_config( - filter_config - ) - _instruments_filtered = filter_t( - _instruments_filtered, start_time, end_time, freq - ) + filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config) + _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq) # as list if as_list: return list(_instruments_filtered) @@ -698,9 +663,7 @@ class LocalFeatureProvider(FeatureProvider): instrument = Inst.convert_instruments(instrument) uri_data = self._uri_data.format(instrument.lower(), field, freq) if not os.path.exists(uri_data): - get_module_logger("data").warning( - "WARN: data not found for %s.%s" % (instrument, field) - ) + get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field)) return pd.Series(dtype=np.float32) # raise ValueError('uri_data not found: ' + uri_data) # load @@ -721,13 +684,9 @@ class LocalExpressionProvider(ExpressionProvider): expression = self.get_expression_instance(field) start_time = pd.Timestamp(start_time) end_time = pd.Timestamp(end_time) - _, _, start_index, end_index = Cal.locate_index( - start_time, end_time, freq, future=False - ) + _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) lft_etd, rght_etd = expression.get_extended_window_size() - series = expression.load( - instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq - ) + series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) # Ensure that each column type is consistent # FIXME: # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented. @@ -759,16 +718,12 @@ class LocalDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor( - instruments_d, column_names, start_time, end_time, freq - ) + data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) return data @staticmethod - def multi_cache_walker( - instruments, fields, start_time=None, end_time=None, freq="day" - ): + def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"): """ This method is used to prepare the expression cache for the client. Then the client will load the data from expression cache by itself. @@ -836,9 +791,7 @@ class ClientCalendarProvider(CalendarProvider): "future": future, }, msg_queue=self.queue, - msg_proc_func=lambda response_content: [ - pd.Timestamp(c) for c in response_content - ], + msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content], ) result = self.queue.get(timeout=C["timeout"]) return result @@ -857,14 +810,11 @@ class ClientInstrumentProvider(InstrumentProvider): def set_conn(self, conn): self.conn = conn - def list_instruments( - self, instruments, start_time=None, end_time=None, freq="day", as_list=False - ): + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): def inst_msg_proc_func(response_content): if isinstance(response_content, dict): instrument = { - i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] - for i, t in response_content.items() + i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items() } else: instrument = response_content @@ -950,9 +900,7 @@ class ClientDatasetProvider(DatasetProvider): start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor( - instruments_d, column_names, start_time, end_time, freq - ) + data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq) if return_uri: return data, feature_uri else: @@ -984,12 +932,8 @@ class ClientDatasetProvider(DatasetProvider): get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = os.path.join( - C.get_data_path(), C.dataset_cache_dir_name, feature_uri - ) - df = DiskDatasetCache.read_data_from_cache( - mnt_feature_uri, start_time, end_time, fields - ) + mnt_feature_uri = os.path.join(C.get_data_path(), C.dataset_cache_dir_name, feature_uri) + df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) get_module_logger("data").debug("finish slicing data") if return_uri: return df, feature_uri @@ -1007,9 +951,7 @@ class BaseProvider: def calendar(self, start_time=None, end_time=None, freq="day", future=False): return Cal.calendar(start_time, end_time, freq, future=future) - def instruments( - self, market="all", filter_pipe=None, start_time=None, end_time=None - ): + def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None): if start_time is not None or end_time is not None: get_module_logger("Provider").warning( "The instruments corresponds to a stock pool. " @@ -1017,9 +959,7 @@ class BaseProvider: ) return InstrumentProvider.instruments(market, filter_pipe) - def list_instruments( - self, instruments, start_time=None, end_time=None, freq="day", as_list=False - ): + def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): return Inst.list_instruments(instruments, start_time, end_time, freq, as_list) def features( @@ -1045,9 +985,7 @@ class BaseProvider: if C.disable_disk_cache: disk_cache = False try: - return DatasetD.dataset( - instruments, fields, start_time, end_time, freq, disk_cache - ) + return DatasetD.dataset(instruments, fields, start_time, end_time, freq, disk_cache) except TypeError: return DatasetD.dataset(instruments, fields, start_time, end_time, freq) @@ -1068,9 +1006,7 @@ class LocalProvider(BaseProvider): elif type == "feature": return DatasetD._uri(**kwargs) - def features_uri( - self, instruments, fields, start_time, end_time, freq, disk_cache=1 - ): + def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1): """features_uri Return the uri of the generated cache of features/dataset @@ -1082,9 +1018,7 @@ class LocalProvider(BaseProvider): :param end_time: :param freq: """ - return DatasetD._dataset_uri( - instruments, fields, start_time, end_time, freq, disk_cache - ) + return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache) class ClientProvider(BaseProvider): @@ -1148,9 +1082,7 @@ def register_all_wrappers(): _calendar_provider = init_instance_by_config(C.calendar_provider, module) if getattr(C, "calendar_cache", None) is not None: - _calendar_provider = init_instance_by_config( - C.calendar_cache, module, provide=_calendar_provider - ) + _calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider) register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") @@ -1166,19 +1098,13 @@ def register_all_wrappers(): # This provider is unnecessary in client provider _eprovider = init_instance_by_config(C.expression_provider, module) if getattr(C, "expression_cache", None) is not None: - _eprovider = init_instance_by_config( - C.expression_cache, module, provider=_eprovider - ) + _eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider) register_wrapper(ExpressionD, _eprovider, "qlib.data") - logger.debug( - f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}" - ) + logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}") _dprovider = init_instance_by_config(C.dataset_provider, module) if getattr(C, "dataset_cache", None) is not None: - _dprovider = init_instance_by_config( - C.dataset_cache, module, provider=_dprovider - ) + _dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider) register_wrapper(DatasetD, _dprovider, "qlib.data") logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}") diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 24e9cd22c..15faa0da1 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -38,9 +38,7 @@ class QlibRecorder: try: yield run except Exception as e: - self.end_exp( - Recorder.STATUS_FA - ) # end the experiment if something went wrong + self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong raise e self.end_exp(Recorder.STATUS_FI)