From 7a79028a720bdccc4f9971212a24c1cd44659523 Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 14 Nov 2020 08:23:19 +0000 Subject: [PATCH] fix some small bug --- qlib/data/data.py | 6 ++++-- qlib/data/dataset/handler.py | 9 +++++---- qlib/data/dataset/utils.py | 2 +- qlib/utils/__init__.py | 24 +++++++++++++++++++++++- qlib/workflow/expm.py | 10 ++++------ qlib/workflow/recorder.py | 6 +++--- 6 files changed, 40 insertions(+), 17 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 6298cfa85..8331b1802 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -664,9 +664,11 @@ class LocalExpressionProvider(ExpressionProvider): lft_etd, rght_etd = expression.get_extended_window_size() series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) # Ensure that each column type is consistent - # FIXME: The stock data is currently float. If there is other types of data, this part needs to be re-implemented. + # FIXME: + # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented. + # 2) The the precision should be configurable try: - series = series.astype(float) + series = series.astype(np.float32) except ValueError: pass if not series.empty: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index f6c097d22..9812864af 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -99,10 +99,11 @@ class DataHandler(Serializable): self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time) # TODO: cache - CS_ALL = "__all" + CS_ALL = "__all" # return all columns with single-level index column + CS_RAW = "__raw" # return raw data with multi-level index column def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame: - if not isinstance(df.columns, pd.MultiIndex): + if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW: return df elif col_set == self.CS_ALL: return df.droplevel(axis=1, level=0) @@ -111,7 +112,7 @@ class DataHandler(Serializable): def fetch( self, - selector: Union[pd.Timestamp, slice, str], + selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = CS_ALL, squeeze: bool = False, @@ -371,7 +372,7 @@ class DataHandlerLP(DataHandler): def fetch( self, - selector: Union[pd.Timestamp, slice, str], + selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, data_key: str = DK_I, diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 6eb00ffee..d82a2d5b5 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -52,4 +52,4 @@ def fetch_df_by_index( idx_slc = (selector, slice(None, None)) if get_level_index(df, level) == 1: idx_slc = idx_slc[1], idx_slc[0] - return df.loc(axis=0)[idx_slc] + return df.loc[pd.IndexSlice[idx_slc], ] # This could be faster than df.loc(axis=0)[idx_slc] diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index d9ae98bd5..c469829d2 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -185,7 +185,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict): if isinstance(config, dict): # raise AttributeError klass = getattr(module, config["class"]) - kwargs = config["kwargs"] + kwargs = config.get("kwargs", {}) elif isinstance(config, str): klass = getattr(module, config) kwargs = {} @@ -619,6 +619,28 @@ def exists_qlib_data(qlib_dir): return True +def lexsort_index(df: pd.DataFrame) -> pd.DataFrame: + """ + make the df index lexsorted + + df.sort_index() will take a lot of time even when `df.is_lexsorted() == True` + This function could avoid such case + + Parameters + ---------- + df : pd.DataFrame + + Returns + ------- + pd.DataFrame: + sorted dataframe + """ + if df.index.is_lexsorted(): + return df + else: + return df.sort_index() + + #################### Wrapper ##################### class Wrapper(object): """Data Provider Wrapper""" diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index ebf6aeb7f..04f9c080f 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -14,7 +14,7 @@ logger = get_module_logger("workflow", "INFO") class ExpManager: """ - This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow. + This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow. (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ @@ -34,10 +34,6 @@ class ExpManager: name of the active experiment. uri : str the current tracking URI. - artifact_location : str - the location to store all the artifacts. - nested : boolean - controls whether run is nested in parent run. Returns ------- @@ -99,7 +95,7 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `create_exp` method.") - def get_exp(self, experiment_id=None, experiment_name=None): + def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True): """ Retrieve an experiment by experiment_id from the backend store. @@ -107,6 +103,8 @@ class ExpManager: ---------- experiment_id : str the experiment id to return. + create : boolean + create the experiment if it does not exists Returns ------- diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 1adaa3f8a..e5ea8d07a 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -139,13 +139,13 @@ class Recorder: """ raise NotImplementedError(f"Please implement the `delete_tags` method.") - def list_artifacts(self, artifact_path=None): + def list_artifacts(self, artifact_path: str = None): """ List all the artifacts of a recorder. Parameters ---------- - artifact_path=None : str + artifact_path : str the relative path for the artifact to be stored in the URI. Returns @@ -186,7 +186,7 @@ class MLflowRecorder(Recorder): assert status in ["SCHEDULED", "RUNNING", "FINISHED", "FAILED"], f"The status type {status} is not supported." mlflow.end_run(status) self.end_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - if self.status is not "FINISHED": + if self.status != "FINISHED": self.status = status shutil.rmtree(self.temp_dir)