diff --git a/qlib/portfolio/optimizer/enhanced_indexing.py b/qlib/portfolio/optimizer/enhanced_indexing.py index a0d0bc050..1f7de6cb4 100644 --- a/qlib/portfolio/optimizer/enhanced_indexing.py +++ b/qlib/portfolio/optimizer/enhanced_indexing.py @@ -70,7 +70,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer): def __call__( self, - u: np.ndarray, + u: Union[np.ndarray, pd.Series], F: np.ndarray, covB: np.ndarray, varU: np.ndarray, @@ -80,7 +80,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer): ) -> Union[np.ndarray, pd.Series]: """ Args: - u (np.ndarray): expected returns (a.k.a., alpha) + u (np.ndarray or pd.Series): expected returns (a.k.a., alpha) F, covB, varU (np.ndarray): see StructuredCovEstimator w0 (np.ndarray): initial weights (for turnover control) w_bench (np.ndarray): benchmark weights @@ -91,6 +91,10 @@ class EnhancedIndexingOptimizer(BaseOptimizer): """ assert inds_onehot is not None or self.inds_dev is None, "Industry onehot vector is required." + # transform dataframe into array + if isinstance(u, pd.Series): + u = u.values + # scale alpha to match volatility if self.scale_alpha: u = u / u.std()