diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py index 9650b7729..58664c46c 100644 --- a/qlib/contrib/data/utils/sepdf.py +++ b/qlib/contrib/data/utils/sepdf.py @@ -75,6 +75,10 @@ class SepDataFrame: def copy(self, *args, **kwargs): return self.apply_each("copy", True, *args, **kwargs) + def _update_join(self): + if self.join not in self: + self.join = next(iter(self._df_dict.keys())) + def __getitem__(self, item): return self._df_dict[item] @@ -82,9 +86,16 @@ class SepDataFrame: # TODO: consider the join behavior self._df_dict[item] = df + def __delitem__(self, item: str): + del self._df_dict[item] + self._update_join() + def __contains__(self, item): return item in self._df_dict + def __len__(self): + return len(self._df_dict[self.join]) + def droplevel(self, *args, **kwargs): raise NotImplementedError(f"Please implement the `droplevel` method") @@ -125,21 +136,26 @@ class SDFLoc: return self._sdf[args] elif isinstance(args, (tuple, list)): new_df_dict = {k: self._sdf[k] for k in args} - return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0]) + return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True) else: raise NotImplementedError(f"This type of input is not supported") elif self.axis == 0: - return SepDataFrame({k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join) + return SepDataFrame( + {k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True + ) else: - ax0, *ax1 = args - if len(ax1) == 0: - ax1 = None df = self._sdf - if ax1 is not None: - df = df.loc(axis=1)[ax1] - if ax0 is not None: - df = df.loc(axis=0)[ax0] - return df + if isinstance(args, tuple): + ax0, *ax1 = args + if len(ax1) == 0: + ax1 = None + if ax1 is not None: + df = df.loc(axis=1)[ax1] + if ax0 is not None: + df = df.loc(axis=0)[ax0] + return df + else: + return df.loc(axis=0)[args] # Patch pandas DataFrame