From 87cee85ceae13d9fc4d97e07bccda6d5c832cfb4 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 26 Nov 2020 00:55:26 +0800 Subject: [PATCH] Update docs and fix tabnet --- docs/component/data.rst | 234 ++++++++++++++---- docs/component/model.rst | 214 +++++++--------- docs/component/recorder.rst | 6 +- docs/reference/api.rst | 14 ++ docs/start/integration.rst | 103 ++++---- examples/benchmarks/HATS/README.md | 4 +- examples/benchmarks/TFT/README.md | 2 +- .../TabNet/workflow_config_tabnet.yaml | 2 +- qlib/contrib/evaluate.py | 84 +++---- .../analysis_model_performance.py | 16 +- .../analysis_position/cumulative_return.py | 6 +- .../report/analysis_position/rank_label.py | 6 +- .../report/analysis_position/report.py | 6 +- .../report/analysis_position/risk_analysis.py | 6 +- .../report/analysis_position/score_ic.py | 6 +- qlib/contrib/strategy/strategy.py | 70 +++--- qlib/data/base.py | 8 +- qlib/data/cache.py | 40 +-- qlib/data/client.py | 8 +- qlib/data/data.py | 92 +++---- qlib/data/dataset/__init__.py | 30 +-- qlib/data/dataset/handler.py | 38 +-- qlib/data/dataset/loader.py | 8 +- qlib/data/dataset/processor.py | 6 +- qlib/data/filter.py | 60 ++--- qlib/model/base.py | 6 +- qlib/model/riskmodel.py | 44 ++-- 27 files changed, 624 insertions(+), 495 deletions(-) diff --git a/docs/component/data.rst b/docs/component/data.rst index 9ef71a6cb..e14caff3e 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -1,7 +1,7 @@ .. _data: ================================ -Data Layer: Data Framework&Usage +Data Layer: Data Framework & Usage ================================ Introduction @@ -15,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts. - Data Preparation - Data API +- Data Loader - Data Handler +- Dataset - Cache - Data and Cache File Structure @@ -146,43 +148,161 @@ Filter To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_. - Reference ------------- To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_. + +Data Loader +================= + +``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module. + +The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the data source. + +Interface +------------ + +Here are some interfaces of the ``QlibDataLoader`` class: + +- `load(instruments, start_time=None, end_time=None)` + - This method loads the data as pd.DataFrame + - Parameters: + - `instruments` : str or dict + it can either be the market name or the config file of instruments generated by InstrumentProvider. + - `start_time` : str + start of the time range. + - `end_time` : str + end of the time range. + - Returns: + - The data being loaded with type `pd.DataFrame` + +- `load_group_df(instruments, exprs: list, names: list, start_time=None, end_time=None)` + - This method loads the dataframe for specific group. + - Parameters: + - `instruments` : str or dict + it can either be the market name or the config file of instruments generated by InstrumentProvider. + - `exprs` : list + the expressions to describe the content of the data. + - `names` : list + the name of the data. + - `start_time` : str + start of the time range. + - `end_time` : str + end of the time range. + - Returns: + - The queried data in type `pd.DataFrame`. + +API +----------- + +To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_. + + Data Handler ================= -Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management `_ for more details. +The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models. + +Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management `_ for more details. -Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.data.dataset.handler.DataHandlerLP``, which provides some interfaces as follows. Base Class & Interface ---------------------- -Qlib provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_, which provides the following interfaces: +In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets. -- `load_feature` - Implement the interface to load the data features. +In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can then infer on the new data and thus processing real-time data in an efficient way. More information about ``Processors`` will be listed in the next subsection. -- `load_label` - Implement the interface to load the data labels and calculate the users' labels. +Here are some important interfaces that ``DataHandlerLP`` provides: -- `setup_processed_data` - Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on. +- `__init__(instruments=None, start_time=None, end_time=None, data_loader: Tuple[dict, str, DataLoader] = None, infer_processors=[], learn_processors=[], process_type=PTYPE_A, **kwargs)` + - Initialization of the class. + - Parameters: + - `infer_processors` : list + - list of of processors to generate data for inference + - example of : -Qlib also provides two functions to help users init the data handler, users can override them for users' needs. + .. code-block:: + + 1) classname & kwargs: + { + "class": "MinMaxNorm", + "kwargs": { + "fit_start_time": "20080101", + "fit_end_time": "20121231" + } + } + 2) Only classname: + "DropnaFeature" + 3) object instance of Processor -- `_init_raw_data` - Users can init the raw df, feature names, and label names of data handler in this function. - If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge). + - `learn_processors` : list + similar to infer_processors, but for generating data for learning models + + - `process_type`: str + - PTYPE_I = 'independent' + - self._infer will processed by infer_processors + - self._learn will be processed by learn_processors + - PTYPE_A = 'append' + - self._infer will processed by infer_processors + - self._learn will be processed by infer_processors + learn_processors + - (e.g. self._infer processed by learn_processors ) + +- `fetch(selector: Union[pd.Timestamp, slice, str] = slice(None, None), level: Union[str, int] = "datetime", col_set=DataHandler.CS_ALL, data_key: str = DK_I)` + - This method fetches data from underlying data source + - Parameters: + - `selector` : Union[pd.Timestamp, slice, str] + describe how to select data by index. + - `level` : Union[str, int] + which index level to select the data. + - `col_set` : str + select a set of meaningful columns.(e.g. features, columns). + - `data_key` : str + The data to fetch: DK_*. + - Returns: + - The retrieved results in the type: `pd.DataFrame`. + +- `get_cols(col_set=DataHandler.CS_ALL, data_key: str = DK_I)` + - This method gets the column names. + - Parameters: + - `col_set` : str + select a set of meaningful columns.(e.g. features, columns). + - `data_key` : str + the data to fetch: DK_*. + - Returns: + - A list of column names. If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass. If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`. +Processor +---------- + +The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`. + +``Qlib`` provides the following ``Processors``: + +- ``DropnaProcessor``: `processor` that drops N/A features. +- ``DropnaLabel``: `processor` that drops N/A labels. +- ``TanhProcess``: `processor` that uses `tanh` to process noise data. +- ``ProcessInf``: `processor` that handles infinity values, it will be replaces by the mean of the column. +- ``Fillna``: `processor` that handles N/A values, which will fill the N/A value by 0 or other given number. +- ``MinMaxNorm``: `processor` that applies min-max normalization. +- ``ZscoreNorm``: `processor` that applies z-score normalization. +- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization. +- ``CSRankNorm``: `processor` that applies cross sectional rank normalization. + +Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link `_). + +API +--------- + +To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_. + + Usage -------------- @@ -194,15 +314,12 @@ Usage - `get_rolling_data` - According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling. - - - Example -------------- -``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module. +``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module. -Know more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management `_ +Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management `_ Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module. @@ -211,45 +328,70 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h .. code-block:: Python + import qlib from qlib.contrib.data.handler import Alpha158 - from qlib.contrib.model.gbdt import LGBModel - DATA_HANDLER_CONFIG = { - "dropna_label": True, - "start_date": "2007-01-01", - "end_date": "2020-08-01", - "market": "csi300", + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi300", } - TRAINER_CONFIG = { - "train_start_date": "2007-01-01", - "train_end_date": "2014-12-31", - "validate_start_date": "2015-01-01", - "validate_end_date": "2016-12-31", - "test_start_date": "2017-01-01", - "test_end_date": "2020-08-01", - } + if __name__ == "__main__": + qlib.init() + h = Alpha158(**data_handler_config) - exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG) + # get all the columns of the data + print(h.get_cols()) - # example of 'get_split_data' - x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG) + # fetch all the labels + print(h.fetch(col_set="label")) - # example of 'get_rolling_data' - - for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG): - print(x_train, y_train, x_validate, y_validate, x_test, y_test) - - -.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predic``, and `score` methods of the ``Interday Model`` , please refer to `Model `_. - -Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``. + # fetch all the features + print(h.fetch(col_set="feature")) API --------- To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_. + +Dataset +================= + +The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing. + +The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the rights to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``DNN`` will break down on such data. + +The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class: + +- `prepare(segments: Union[List[str], Tuple[str], str, slice], col_set=DataHandler.CS_ALL, data_key=DataHandlerLP.DK_I, **kwargs)` + - This method prepares the data for learning and inference. + - Parameters: + - `segments` : Union[List[str], Tuple[str], str, slice] + Describe the scope of the data to be prepared + Here are some examples: + + - 'train' + + - ['train', 'valid'] + + - `col_set` : str + The col_set will be passed to self._handler when fetching data. + - `data_key` : str + The data to fetch: DK_* + Default is DK_I, which indicate fetching data for **inference**. + + +API +--------- + +To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_. + + + Cache ========== diff --git a/docs/component/model.rst b/docs/component/model.rst index 6a6b02f86..b4e341df8 100644 --- a/docs/component/model.rst +++ b/docs/component/model.rst @@ -7,7 +7,7 @@ Interday Model: Model Training & Prediction Introduction =================== -``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management `_. +``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management `_. Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also. @@ -20,151 +20,125 @@ The base class provides the following interfaces: - `__init__(**kwargs)` - Initialization. - - If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file. -- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)` +- `fit(self, dataset, **kwargs)` - Train model. - Parameter: - - `x_train`, pd.DataFrame type, train feature - The following example explains the value of `x_train`: + - `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_. + The `dataset` is passed into the `model`'s method because there are some unique data preprocessing procedures for each, we want to give each model maximum flexibility to handle the data that is suitable for their own. + The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`: - .. code-block:: YAML - - KMID KLEN KMID2 KUP KUP2 - instrument datetime - SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275 - 2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998 - 2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666 - 2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001 - 2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648 - ... ... ... ... ... ... - SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052 - 2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824 - 2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000 - 2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433 - 2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216 + .. code-block:: Python - - `x_train` is a pandas DataFrame, whose index is MultiIndex . Each column of `x_train` corresponds to a feature, and the column name is the feature name. - - .. note:: - - The number and names of the columns are determined by the data handler, please refer to `Data Handler `_ and `Estimator Data Section `_. - - - `y_train`, pd.DataFrame type, train label - The following example explains the value of `y_train`: + # get features and labels + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_valid["feature"], df_valid["label"] - .. code-block:: YAML - - LABEL - instrument datetime - SH600004 2012-01-04 -0.798456 - 2012-01-05 -1.366716 - 2012-01-06 -0.491026 - 2012-01-09 0.296900 - 2012-01-10 0.501426 - ... ... - SZ300273 2014-12-25 -0.465540 - 2014-12-26 0.233864 - 2014-12-29 0.471368 - 2014-12-30 0.411914 - 2014-12-31 1.342723 - - `y_train` is a pandas DataFrame, whose index is MultiIndex . The `LABEL` column represents the value of train label. - - .. note:: - - The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler `_. - - - `x_valid`, pd.DataFrame type, validation feature - The format of `x_valid` is same as `x_train` - - - - `y_valid`, pd.DataFrame type, validation label - The format of `y_valid` is same as `y_train` - - - `w_train`(Optional args, default is None), pd.DataFrame type, train weight - `w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`. - - - `w_train`(Optional args, default is None), pd.DataFrame type, validation weight - `w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`. - -- `predict(self, x_test, **kwargs)` - - Predict test data 'x_test' - - Parameter: - - `x_test`, pd.DataFrame type, test features - The form of `x_test` is same as `x_train` in 'fit' method. - - Return: - - `label`, np.ndarray type, test label - The label of `x_test` that predicted by model. - -- `score(self, x_test, y_test, w_test=None, **kwargs)` - - Evaluate model with test feature/label - - Parameter: - - `x_test`, pd.DataFrame type, test feature - The format of `x_test` is same as `x_train` in `fit` method. + # get weights + try: + wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L) + w_train, w_valid = wdf_train["weight"], wdf_valid["weight"] + except KeyError as e: + w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) + w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index) - - `x_test`, pd.DataFrame type, test label - The format of `y_test` is same as `y_train` in `fit` method. +- `predict(self, dataset, **kwargs)` + - Predict test data. + - Parameter: + - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above. + - Returns: + - Predic results with type: `pandas.Series`. - - `w_test`, pd.DataFrame type, test weight - The format of `w_test` is same as `w_train` in `fit` method. - - Return: float type, evaluation score +- `finetune(self, dataset, **kwargs)` + - Finetune the model. + - Parameter: + - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above. -For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_. + +For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_. Example ================== -``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module. +``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. The following steps show how to run`` LightGBM`` as an independent module. - Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_. - Run the following code to get the `prediction score` `pred_score` .. code-block:: Python - from qlib.contrib.data.handler import Alpha158 from qlib.contrib.model.gbdt import LGBModel + from qlib.contrib.data.handler import Alpha158 + from qlib.utils import init_instance_by_config, flatten_dict + from qlib.workflow import R + from qlib.workflow.record_temp import SignalRecord, PortAnaRecord - DATA_HANDLER_CONFIG = { - "dropna_label": True, - "start_date": "2007-01-01", - "end_date": "2020-08-01", - "market": MARKET, + market = "csi300" + benchmark = "SH000300" + + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, } - TRAINER_CONFIG = { - "train_start_date": "2007-01-01", - "train_end_date": "2014-12-31", - "validate_start_date": "2015-01-01", - "validate_end_date": "2016-12-31", - "test_start_date": "2017-01-01", - "test_end_date": "2020-08-01", + task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, } + + # model initiaiton + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) - x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158( - **DATA_HANDLER_CONFIG - ).get_split_data(**TRAINER_CONFIG) + # start exp + with R.start(experiment_name="workflow"): + # train + R.log_params(**flatten_dict(task)) + model.fit(dataset) + # prediction + recorder = R.get_recorder() + sr = SignalRecord(model, dataset, recorder) + sr.generate() - MODEL_CONFIG = { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - } - # use default model - model = LGBModel(**MODEL_CONFIG) - model.fit(x_train, y_train, x_validate, y_validate) - _pred = model.predict(x_test) - pred_score = pd.DataFrame(index=_pred.index) - pred_score["score"] = _pred.iloc(axis=1)[0] - - .. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_. + .. note:: + + `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_. + `SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow `_. Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``. diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst index efd67e859..0d1e83168 100644 --- a/docs/component/recorder.rst +++ b/docs/component/recorder.rst @@ -402,8 +402,8 @@ Record Template The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class: -- ``SignalRecord``: This class generates the `preidction` of the model. -- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR`. +- ``SignalRecord``: This class generates the `preidction` results of the model. +- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model. - ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_. -For more information, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_. \ No newline at end of file +For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_. \ No newline at end of file diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 76d2a74a5..d99a26f49 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -60,12 +60,26 @@ Cache Contrib ==================== +Data Loader +--------------- +.. automodule:: qlib.data.dataset.loader + :members: Data Handler --------------- .. automodule:: qlib.data.dataset.handler :members: +Processor +--------------- +.. automodule:: qlib.data.dataset.processor + :members: + +Dataset +--------------- +.. automodule:: qlib.data.dataset.__init__ + :members: + Model -------------------- .. automodule:: qlib.model.base diff --git a/docs/start/integration.rst b/docs/start/integration.rst index 5276729b5..102d88425 100644 --- a/docs/start/integration.rst +++ b/docs/start/integration.rst @@ -5,7 +5,7 @@ Custom Model Integration Introduction =================== -``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. +``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``. Users can integrate their own custom models according to the following steps. @@ -32,79 +32,76 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html# - Override the `fit` method - ``Qlib`` calls the fit method to train the model - - The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least. - - The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`. + - The parameters must include training feature `dataset`. + - The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`. - Code Example: In the following example, `num_boost_round = 1000` is an optional parameter. .. code-block:: Python - def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame, - w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs): + def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs): + + # prepare dataset for lgb training and evaluation + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_valid["feature"], df_valid["label"] # Lightgbm need 1D array as its label if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values) + y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values) else: - raise ValueError('LightGBM doesn\'t support multi-label training') + raise ValueError("LightGBM doesn't support multi-label training") - w_train_weight = None if w_train is None else w_train.values - w_valid_weight = None if w_valid is None else w_valid.values + dtrain = lgb.Dataset(x_train.values, label=y_train) + dvalid = lgb.Dataset(x_valid.values, label=y_valid) - dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight) - dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight) - self._model = lgb.train( - self._params, - dtrain, + # fit the model + self.model = lgb.train( + self.params, + dtrain, num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], - valid_names=['train', 'valid'], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, **kwargs ) - Override the `predict` method - - The parameters include the test features. + - The parameters must include training feature `dataset`, which will be userd to get the test dataset. - Return the `prediction score`. - Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method. - - Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it. + - Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it. .. code-block:: Python - def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray: - if self._model is None: - raise ValueError('model is not fitted yet!') - return self._model.predict(x_test.values) + def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series: + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + return pd.Series(self.model.predict(x_test.values), index=x_test.index) -- Override the `save` method & `load` method - - The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path. - - The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`. - - Code Example: +- Override the `finetune` method + - The parameters must include training feature `dataset`. + - Code Example: In the following example, users will use `LightGBM` as the model and finetune it. .. code-block:: Python - def save(self, filename): - if self._model is None: - raise ValueError('model is not fitted yet!') - self._model.save_model(filename) - - def load(self, buffer): - self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')}) - -.. Without tuner, this part will not be used -.. - Override the `score` method(This step is optional) -.. - The parameters include the test features and test labels. -.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`. -.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`. -.. .. code-block:: Python -.. -.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float: -.. # Remove rows from x, y and w, which contain Nan in any columns in y_test. -.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test) -.. preds = self.predict(x_test) -.. w_test_weight = None if w_test is None else w_test.values -.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score -.. return scorer(y_test.values, preds, sample_weight=w_test_weight) + def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): + dtrain, _ = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + init_model=self.model, + valid_sets=[dtrain], + valid_names=["train"], + verbose_eval=verbose_eval, + ) Configuration File ======================= -The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. +The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. - Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`. @@ -124,20 +121,20 @@ The configuration file is described in detail in the `estimator <../component/es num_leaves: 210 num_threads: 20 -Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml`` +Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder. Model Testing ===================== -Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model: +Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model: .. code-block:: bash cd examples # Avoid running program under the directory contains `qlib` - estimator -c estimator/estimator_config.yaml + qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml -.. note:: ``estimator`` is a built-in command of ``Qlib``. +.. note:: ``qrun`` is a built-in command of ``Qlib``. -Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``. +Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``. Reference diff --git a/examples/benchmarks/HATS/README.md b/examples/benchmarks/HATS/README.md index 95619e1ee..b70dbff25 100644 --- a/examples/benchmarks/HATS/README.md +++ b/examples/benchmarks/HATS/README.md @@ -1,11 +1,11 @@ -##Requirement +## Requirement * pandas==1.1.2 * numpy==1.17.4 * scikit_learn==0.23.2 * torch==1.7.0 -##HATS +## HATS * HATS is a a hierarchical attention network for stock prediction which uses relational data for stock market prediction. HATS selectively aggregates information on different relation types and adds the information to the representations of each company. HATS is used as a relational modeling module with initialized node representations.Furthermore, HATS diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md index 6d605a1bd..a64ca0129 100644 --- a/examples/benchmarks/TFT/README.md +++ b/examples/benchmarks/TFT/README.md @@ -5,7 +5,7 @@ **GitHub**: https://github.com/google-research/google-research/tree/master/tft ## Run the Workflow -Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. +Users can follow the ``workflow_by_code_tft.py`` to run the benchmark. Please be **aware** that this script can only support Python 3.5 - 3.8. ### Notes 1. The model must run in GPU, or an error will be raised. diff --git a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml index 0ee95f238..5f6aa8b6d 100644 --- a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml +++ b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml @@ -44,7 +44,7 @@ task: module_path: qlib.data.dataset kwargs: handler: - class: Alpha158 + class: ALPHA360_Denoise module_path: qlib.contrib.data.handler kwargs: *data_handler_config segments: diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 2b85f1a9b..4bb5e4372 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -26,9 +26,9 @@ def risk_analysis(r, N=252): Parameters ---------- r : pandas.Series - daily return series + daily return series. N: int - scaler for annualizing information_ratio (day: 250, week: 50, month: 12) + scaler for annualizing information_ratio (day: 250, week: 50, month: 12). """ mean = r.mean() std = r.std(ddof=1) @@ -61,7 +61,7 @@ def get_strategy( ---------- strategy : Strategy() - strategy used in backtest + strategy used in backtest. topk : int (Default value: 50) top-N stocks to buy. margin : int or float(Default value: 0.5) @@ -73,14 +73,14 @@ def get_strategy( sell_limit = pred_in_a_day.count() * margin - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) - sell_limit should be no less than topk + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). + sell_limit should be no less than topk. n_drop : int - number of stocks to be replaced in each trading date + number of stocks to be replaced in each trading date. risk_degree: float - 0-1, 0.95 for example, use 95% money to trade + 0-1, 0.95 for example, use 95% money to trade. str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy + strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. Returns ------- @@ -126,21 +126,21 @@ def get_exchange( ---------- # exchange related arguments - exchange: Exchange() + exchange: Exchange(). subscribe_fields: list - subscribe fields + subscribe fields. open_cost : float - open transaction cost + open transaction cost. close_cost : float - close transaction cost + close transaction cost. min_cost : float - min transaction cost + min transaction cost. trade_unit : int - 100 for China A + 100 for China A. deal_price: str - dealing price type: 'close', 'open', 'vwap' + dealing price type: 'close', 'open', 'vwap'. limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit + limit move 0.1 (10%) for example, long and short with same limit. extract_codes: bool will we pass the codes extracted from the pred to the exchange. NOTE: This will be faster with offline qlib. @@ -193,20 +193,20 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k - **backtest workflow related or commmon arguments** pred : pandas.DataFrame - predict should has index and one `score` column + predict should has index and one `score` column. account : float - init account value + init account value. shift : int - whether to shift prediction by one day + whether to shift prediction by one day. benchmark : str - benchmark code, default is SH000905 CSI 500 + benchmark code, default is SH000905 CSI 500. verbose : bool - whether to print log + whether to print log. - **strategy related arguments** strategy : Strategy() - strategy used in backtest + strategy used in backtest. topk : int (Default value: 50) top-N stocks to buy. margin : int or float(Default value: 0.5) @@ -218,33 +218,33 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k sell_limit = pred_in_a_day.count() * margin - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit) - sell_limit should be no less than topk + buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). + sell_limit should be no less than topk. n_drop : int - number of stocks to be replaced in each trading date + number of stocks to be replaced in each trading date. risk_degree: float - 0-1, 0.95 for example, use 95% money to trade + 0-1, 0.95 for example, use 95% money to trade. str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy + strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. - **exchange related arguments** exchange: Exchange() pass the exchange for speeding up. subscribe_fields: list - subscribe fields + subscribe fields. open_cost : float open transaction cost. The default value is 0.002(0.2%). close_cost : float close transaction cost. The default value is 0.002(0.2%). min_cost : float - min transaction cost + min transaction cost. trade_unit : int - 100 for China A + 100 for China A. deal_price: str - dealing price type: 'close', 'open', 'vwap' + dealing price type: 'close', 'open', 'vwap'. limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit + limit move 0.1 (10%) for example, long and short with same limit. extract_codes: bool will we pass the codes extracted from the pred to the exchange. @@ -291,17 +291,17 @@ def long_short_backtest( """ A backtest for long-short strategy - :param pred: The trading signal produced on day `T` - :param topk: The short topk securities and long topk securities - :param deal_price: The price to deal the trading + :param pred: The trading signal produced on day `T`. + :param topk: The short topk securities and long topk securities. + :param deal_price: The price to deal the trading. :param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1. - :param open_cost: open transaction cost - :param close_cost: close transaction cost - :param trade_unit: 100 for China A - :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit - :param min_cost: min transaction cost - :param subscribe_fields: subscribe fields - :param extract_codes: bool + :param open_cost: open transaction cost. + :param close_cost: close transaction cost. + :param trade_unit: 100 for China A. + :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit. + :param min_cost: min transaction cost. + :param subscribe_fields: subscribe fields. + :param extract_codes: bool. will we pass the codes extracted from the pred to the exchange. NOTE: This will be faster with offline qlib. :return: The result of backtest, it is represented by a dict. diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 1c69145db..1cb14d261 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -252,7 +252,7 @@ def model_performance_graph( """Model performance :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, - label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1") + label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1"). .. code-block:: python @@ -266,13 +266,13 @@ def model_performance_graph( :param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing. - :param N: group number, default 5 - :param reverse: if `True`, `pred['score'] *= -1` - :param rank: if **True**, calculate rank ic - :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'] - :param show_notebook: whether to display graphics in notebook, the default is `True` - :param show_nature_day: whether to display the abscissa of non-trading day - :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list + :param N: group number, default 5. + :param reverse: if `True`, `pred['score'] *= -1`. + :param rank: if **True**, calculate rank ic. + :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover']. + :param show_notebook: whether to display graphics in notebook, the default is `True`. + :param show_nature_day: whether to display the abscissa of non-trading day. + :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list. """ figure_list = [] for graph_name in graph_names: diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index 941785e83..abb68ea60 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -218,10 +218,10 @@ def cumulative_return_graph( Graph desc: - - Axis X: Trading day + - Axis X: Trading day. - Axis Y: - - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()` - - Below axis Y: Daily weight sum + - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`. + - Below axis Y: Daily weight sum. - In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit. - In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`. - In each graph, the **red line** in the histogram on the right represents the average. diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index e2f7fe1cf..72a358adc 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -97,9 +97,9 @@ def rank_label_graph( qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) - :param position: position data; **qlib.contrib.backtest.backtest.backtest** result + :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. :param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**. - **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])` + **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`. .. code-block:: python @@ -115,7 +115,7 @@ def rank_label_graph( :param start_date: start date :param end_date: end_date - :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures + :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures. :return: """ position = copy.deepcopy(position) diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index 6d108cabf..438aab8b9 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, qcr.report_graph(report_normal_df) - :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench** + :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**. .. code-block:: python @@ -200,8 +200,8 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, 2017-01-10 -0.000416 0.000440 -0.003350 0.208396 - :param show_notebook: whether to display graphics in notebook, the default is **True** - :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list + :param show_notebook: whether to display graphics in notebook, the default is **True**. + :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list. """ report_df = report_df.copy() fig_list = _report_figure(report_df) diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py index 124a9b3b0..051c78035 100644 --- a/qlib/contrib/report/analysis_position/risk_analysis.py +++ b/qlib/contrib/report/analysis_position/risk_analysis.py @@ -218,7 +218,7 @@ def risk_analysis_graph( max_drawdown -0.088263 - :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench** + :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**. .. code-block:: python @@ -232,7 +232,7 @@ def risk_analysis_graph( 2017-01-10 -0.000416 0.000440 -0.003350 0.208396 - :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short** + :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**. .. code-block:: python @@ -246,7 +246,7 @@ def risk_analysis_graph( 2017-01-10 0.000824 -0.001944 -0.001120 - :param show_notebook: Whether to display graphics in a notebook, default **True** + :param show_notebook: Whether to display graphics in a notebook, default **True**. If True, show graph in notebook If False, return graph figure :return: diff --git a/qlib/contrib/report/analysis_position/score_ic.py b/qlib/contrib/report/analysis_position/score_ic.py index 9a2fc8560..a6a7a8b0e 100644 --- a/qlib/contrib/report/analysis_position/score_ic.py +++ b/qlib/contrib/report/analysis_position/score_ic.py @@ -36,7 +36,7 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis analysis_position.score_ic_graph(pred_label) - :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]** + :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**. .. code-block:: python @@ -49,8 +49,8 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis 2017-12-15 -0.102778 -0.102778 - :param show_notebook: whether to display graphics in notebook, the default is **True** - :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list + :param show_notebook: whether to display graphics in notebook, the default is **True**. + :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list. """ _ic_df = _get_score_ic(pred_label) # FIXME: support HIGH-FREQ diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py index f2e2a4554..23e8b5185 100644 --- a/qlib/contrib/strategy/strategy.py +++ b/qlib/contrib/strategy/strategy.py @@ -31,16 +31,16 @@ class BaseStrategy: Parameters ----------- score_series : pd.Seires - stock_id , score + stock_id , score. current : Position() - current state of position - DO NOT directly change the state of current + current state of position. + DO NOT directly change the state of current. trade_exchange : Exchange() - trade exchange + trade exchange. pred_date : pd.Timestamp - predict date + predict date. trade_date : pd.Timestamp - trade date + trade date. """ pass @@ -49,11 +49,11 @@ class BaseStrategy: Parameters ----------- score_series : pd.Series - stock_id , score + stock_id , score. pred_date : pd.Timestamp - oredict date + oredict date. trade_date : pd.Timestamp - trade date + trade date. """ pass @@ -67,7 +67,7 @@ class BaseStrategy: """ This method only be used in 'online' module, it will generate the *args to initial the strategy. :param - mode : model used in 'online' module + mode : model used in 'online' module. """ return {} @@ -82,7 +82,7 @@ class StrategyWrapper: def __init__(self, inner_strategy): """__init__ - :param inner_strategy: set the inner strategy + :param inner_strategy: set the inner strategy. """ self.inner_strategy = inner_strategy @@ -99,9 +99,9 @@ class AdjustTimer: Responsible for timing of position adjusting This is designed as multiple inheritance mechanism due to: - - the is_adjust may need access to the internel state of a strategy + - the is_adjust may need access to the internel state of a strategy. - - it can be reguard as a enhancement to the existing strategy + - it can be reguard as a enhancement to the existing strategy. """ # adjust position in each trade date @@ -146,12 +146,12 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer): Parameters ----------- score : pd.Series - pred score for this trade date, index is stock_id, contain 'score' column + pred score for this trade date, index is stock_id, contain 'score' column. current : Position() - current position + current position. trade_exchange : Exchange() trade_date : pd.Timestamp - trade date + trade date. """ raise NotImplementedError() @@ -160,13 +160,13 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer): Parameters ----------- score_series : pd.Seires - stock_id , score + stock_id , score. current : Position() - current of account + current of account. trade_exchange : Exchange() - exchange + exchange. trade_date : pd.Timestamp - date + date. """ # judge if to adjust if not self.is_adjust(trade_date): @@ -206,26 +206,26 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): Parameters ----------- topk : int - The number of stocks in the portfolio + the number of stocks in the portfolio. n_drop : int - number of stocks to be replaced in each trading date + number of stocks to be replaced in each trading date. method_sell : str - dropout method_sell, random/bottom + dropout method_sell, random/bottom. method_buy : str - dropout method_buy, random/top + dropout method_buy, random/top. risk_degree : float - position percentage of total value + position percentage of total value. thresh : int - minimun holding days since last buy singal of the stock + minimun holding days since last buy singal of the stock. hold_thresh : int minimum holding days - before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh + before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh. only_tradable : bool will the strategy only consider the tradable stock when buying and selling. if only_tradable: - strategy will make buy sell decision without checking the tradable state of the stock + strategy will make buy sell decision without checking the tradable state of the stock. else: - strategy will make decision with the tradable state of the stock info and avoid buy and sell them + strategy will make decision with the tradable state of the stock info and avoid buy and sell them. """ super(TopkDropoutStrategy, self).__init__() ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None)) @@ -245,7 +245,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): def get_risk_degree(self, date): """get_risk_degree Return the proportion of your total value you will used in investment. - Dynamically risk_degree will result in Market timing + Dynamically risk_degree will result in Market timing. """ # It will use 95% amoutn of your total value by default return self.risk_degree @@ -257,15 +257,15 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): Parameters ----------- score_series : pd.Series - stock_id , score + stock_id , score. current : Position() - current of account + current of account. trade_exchange : Exchange() - exchange + exchange. pred_date : pd.Timestamp - predict date + predict date. trade_date : pd.Timestamp - trade date + trade date. """ if not self.is_adjust(trade_date): return [] diff --git a/qlib/data/base.py b/qlib/data/base.py index 433b6585a..92fc57ffe 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -129,13 +129,13 @@ class Expression(abc.ABC): Parameters ---------- instrument : str - instrument code + instrument code. start_index : str - feature start index [in calendar] + feature start index [in calendar]. end_index : str - feature end index [in calendar] + feature end index [in calendar]. freq : str - feature frequency + feature frequency. Returns ---------- diff --git a/qlib/data/cache.py b/qlib/data/cache.py index bf8baab31..3fab2b527 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -76,8 +76,8 @@ class MemCache(object): Parameters ---------- - mem_cache_size_limit: cache max size - limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof) + mem_cache_size_limit: cache max size. + limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof). """ if limit_type not in ["length", "sizeof"]: raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}") @@ -118,9 +118,9 @@ class MemCacheExpire: def set_cache(mem_cache, key, value): """set cache - :param mem_cache: MemCache attribute('c'/'i'/'f') - :param key: cache key - :param value: cache value + :param mem_cache: MemCache attribute('c'/'i'/'f'). + :param key: cache key. + :param value: cache value. """ mem_cache[key] = value, time.time() @@ -128,9 +128,9 @@ class MemCacheExpire: def get_cache(mem_cache, key): """get mem cache - :param mem_cache: MemCache attribute('c'/'i'/'f') - :param key: cache key - :return: cache value; if cache not exist, return None + :param mem_cache: MemCache attribute('c'/'i'/'f'). + :param key: cache key. + :return: cache value; if cache not exist, return None. """ value = None expire = False @@ -275,12 +275,12 @@ class ExpressionCache(BaseProviderCache): Parameters ---------- cache_uri : str - the complete uri of expression cache file (include dir path) + the complete uri of expression cache file (include dir path). Returns ------- int - 0(successful update)/ 1(no need to update)/ 2(update failure) + 0(successful update)/ 1(no need to update)/ 2(update failure). """ raise NotImplementedError("Implement this method if you want to make expression cache up to date") @@ -348,7 +348,7 @@ class DatasetCache(BaseProviderCache): Parameters ---------- cache_uri : str - the complete uri of dataset cache file (include dir path) + the complete uri of dataset cache file (include dir path). Returns ------- @@ -361,9 +361,9 @@ class DatasetCache(BaseProviderCache): def cache_to_origin_data(data, fields): """cache data to origin data - :param data: pd.DataFrame, cache data - :param fields: feature fields - :return: pd.DataFrame + :param data: pd.DataFrame, cache data. + :param fields: feature fields. + :return: pd.DataFrame. """ not_space_fields = remove_fields_space(fields) data = data.loc[:, not_space_fields] @@ -583,7 +583,7 @@ class DiskDatasetCache(DatasetCache): :param cache_path: :param start_time: :param end_time: - :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent + :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent. :return: """ @@ -771,12 +771,12 @@ class DiskDatasetCache(DatasetCache): - This is a hdf file sorted by datetime - :param cache_path: The path to store the cache - :param instruments: The instruments to store the cache - :param fields: The fields to store the cache - :param freq: The freq to store the cache + :param cache_path: The path to store the cache. + :param instruments: The instruments to store the cache. + :param fields: The fields to store the cache. + :param freq: The freq to store the cache. - :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function + :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function. """ # get calendar from .data import Cal diff --git a/qlib/data/client.py b/qlib/data/client.py index 928faaa72..65a830f20 100644 --- a/qlib/data/client.py +++ b/qlib/data/client.py @@ -51,13 +51,13 @@ class Client(object): Parameters ---------- request_type : str - type of proposed request, 'calendar'/'instrument'/'feature' + type of proposed request, 'calendar'/'instrument'/'feature'. request_content : dict - records the information of the request + records the information of the request. msg_proc_func : func - the function to process the message when receiving response, should have arg `*args` + the function to process the message when receiving response, should have arg `*args`. msg_queue: Queue - The queue to pass the messsage after callback + The queue to pass the messsage after callback. """ head_info = {"version": qlib.__version__} diff --git a/qlib/data/data.py b/qlib/data/data.py index ef5e7fe8a..a4c3d63f2 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -41,13 +41,13 @@ class CalendarProvider(abc.ABC): Parameters ---------- start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency, available: year/quarter/month/week/day + time frequency, available: year/quarter/month/week/day. future : bool - whether including future trading day + whether including future trading day. Returns ---------- @@ -62,24 +62,24 @@ class CalendarProvider(abc.ABC): Parameters ---------- start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency, available: year/quarter/month/week/day + time frequency, available: year/quarter/month/week/day. future : bool - whether including future trading day + whether including future trading day. Returns ------- pd.Timestamp - the real start time + the real start time. pd.Timestamp - the real end time + the real end time. int - the index of start time + the index of start time. int - the index of end time + the index of end time. """ start_time = pd.Timestamp(start_time) end_time = pd.Timestamp(end_time) @@ -103,16 +103,16 @@ class CalendarProvider(abc.ABC): Parameters ---------- freq : str - frequency of read calendar file + frequency of read calendar file. future : bool - whether including future trading day + whether including future trading day. Returns ------- list - list of timestamps + list of timestamps. dict - dict composed by timestamp as key and index as value for fast search + dict composed by timestamp as key and index as value for fast search. """ flag = f"{freq}_future_{future}" if flag in H["c"]: @@ -141,14 +141,14 @@ class InstrumentProvider(abc.ABC): Parameters ---------- market : str - market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500 + market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500. filter_pipe : list - the list of dynamic filters + the list of dynamic filters. Returns ---------- dict - dict of stockpool config + dict of stockpool config. {`market`=>base market name, `filter_pipe`=>list of filters} example : @@ -182,13 +182,13 @@ class InstrumentProvider(abc.ABC): Parameters ---------- instruments : dict - stockpool config + stockpool config. start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. as_list : bool - return instruments as list or dict + return instruments as list or dict. Returns ------- @@ -243,15 +243,15 @@ class FeatureProvider(abc.ABC): Parameters ---------- instrument : str - a certain instrument + a certain instrument. field : str - a certain field of feature + a certain field of feature. start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency, available: year/quarter/month/week/day + time frequency, available: year/quarter/month/week/day. Returns ------- @@ -294,15 +294,15 @@ class ExpressionProvider(abc.ABC): Parameters ---------- instrument : str - a certain instrument + a certain instrument. field : str - a certain field of feature + a certain field of feature. start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency, available: year/quarter/month/week/day + time frequency, available: year/quarter/month/week/day. Returns ------- @@ -325,20 +325,20 @@ class DatasetProvider(abc.ABC): Parameters ---------- instruments : list or dict - list/dict of instruments or dict of stockpool config + list/dict of instruments or dict of stockpool config. fields : list - list of feature instances + list of feature instances. start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency + time frequency. Returns ---------- pd.DataFrame - a pandas dataframe with index + a pandas dataframe with index. """ raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") @@ -357,17 +357,17 @@ class DatasetProvider(abc.ABC): Parameters ---------- instruments : list or dict - list/dict of instruments or dict of stockpool config + list/dict of instruments or dict of stockpool config. fields : list - list of feature instances + list of feature instances. start_time : str - start of the time range + start of the time range. end_time : str - end of the time range + end of the time range. freq : str - time frequency + time frequency. disk_cache : int - whether to skip(0)/use(1)/replace(2) disk_cache + whether to skip(0)/use(1)/replace(2) disk_cache. """ return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache) @@ -526,7 +526,7 @@ class LocalCalendarProvider(CalendarProvider): Parameters ---------- freq : str - frequency of read calendar file + frequency of read calendar file. Returns ---------- diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index e972aba3c..74e14f47a 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -17,7 +17,7 @@ class Dataset(Serializable): init is designed to finish following steps: - setup data - - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing + - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - initialize the state of the dataset(info to prepare the data) - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing. @@ -29,17 +29,17 @@ class Dataset(Serializable): def setup_data(self, *args, **kwargs): """ - setup the data + Setup the data. We split the setup_data function for following situation: - - User have a Dataset object with learned status on disk + - User have a Dataset object with learned status on disk. - - User load the Dataset object from the disk(Note the init function is skiped) + - User load the Dataset object from the disk(Note the init function is skiped). - - User call `setup_data` to load new data + - User call `setup_data` to load new data. - - User prepare data for model based on previous status + - User prepare data for model based on previous status. """ pass @@ -66,9 +66,10 @@ class DatasetH(Dataset): User should try to put the data preprocessing functions into handler. Only following data processing functions should be placed in Dataset: + - The processing is related to specific model. - - The processing is related to data split + - The processing is related to data split. """ def __init__(self, handler: Union[dict, DataHandler], segments: list): @@ -76,15 +77,15 @@ class DatasetH(Dataset): Parameters ---------- handler : Union[dict, DataHandler] - handler will be passed into setup_data + handler will be passed into setup_data. segments : list - handler will be passed into setup_data + handler will be passed into setup_data. """ super().__init__(handler, segments) def setup_data(self, handler: Union[dict, DataHandler], segments: list): """ - setup the underlying data + Setup the underlying data. Parameters ---------- @@ -121,7 +122,7 @@ class DatasetH(Dataset): **kwargs, ) -> Union[List[pd.DataFrame], pd.DataFrame]: """ - prepare the data for learning and inference + Prepare the data for learning and inference. Parameters ---------- @@ -132,11 +133,12 @@ class DatasetH(Dataset): - 'train' - ['train', 'valid'] + col_set : str - The col_set will be passed to self._handler when fetching data - data_key: str + The col_set will be passed to self._handler when fetching data. + data_key : str The data to fetch: DK_* - Default is DK_I, which indicate fetching data for **inference** + Default is DK_I, which indicate fetching data for **inference**. Returns ------- diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 4d3d88c38..1710ff9e3 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -29,7 +29,7 @@ class DataHandler(Serializable): """ The steps to using a handler 1. initialized data handler (call by `init`). - 2. use the data + 2. use the data. The data handler try to maintain a handler with 2 level. @@ -65,17 +65,17 @@ class DataHandler(Serializable): Parameters ---------- instruments : - The stock list to retrive + The stock list to retrive. start_time : - start_time of the original data + start_time of the original data. end_time : - end_time of the original data + end_time of the original data. data_loader : Tuple[dict, str, DataLoader] - data loader to load the data + data loader to load the data. init_data : - intialize the original data in the constructor + intialize the original data in the constructor. fetch_orig : bool - Return the original data instead of copy if possible + Return the original data instead of copy if possible. """ # Set logger self.logger = get_module_logger("DataHandler") @@ -219,9 +219,9 @@ class DataHandler(Serializable): get a iterator of sliced data with given periods Args: - periods (int): number of periods - min_periods (int): minimum periods for sliced dataframe - kwargs (dict): will be passed to `self.fetch` + periods (int): number of periods. + min_periods (int): minimum periods for sliced dataframe. + kwargs (dict): will be passed to `self.fetch`. """ trading_dates = self._data.index.unique(level="datetime") if min_periods is None: @@ -377,7 +377,7 @@ class DataHandlerLP(DataHandler): Parameters ---------- init_type : str - The type `IT_*` listed above + The type `IT_*` listed above. enable_cache : bool default value is false: @@ -419,13 +419,13 @@ class DataHandlerLP(DataHandler): Parameters ---------- selector : Union[pd.Timestamp, slice, str] - describe how to select data by index + describe how to select data by index. level : Union[str, int] - which index level to select the data + which index level to select the data. col_set : str - select a set of meaningful columns.(e.g. features, columns) - data_key: str - The data to fetch: DK_* + select a set of meaningful columns.(e.g. features, columns). + data_key : str + the data to fetch: DK_*. Returns ------- @@ -443,9 +443,9 @@ class DataHandlerLP(DataHandler): Parameters ---------- col_set : str - select a set of meaningful columns.(e.g. features, columns) - data_key: str - The data to fetch: DK_* + select a set of meaningful columns.(e.g. features, columns). + data_key : str + the data to fetch: DK_*. Returns ------- diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index db6b1440d..d1de4821c 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -100,16 +100,16 @@ class DLWParser(DataLoader): Parameters ---------- instruments : - the instruments + the instruments. exprs : list - The expressions to describe the content of the data + the expressions to describe the content of the data. names : list - The name of the data + the name of the data. Returns ------- pd.DataFrame: - the queried dataframe + the queried dataframe. """ pass diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index 2201c0891..e4003a1f5 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -21,7 +21,7 @@ def get_group_columns(df: pd.DataFrame, group: str): Parameters ---------- df : pd.DataFrame - with multi of columns + with multi of columns. group : str the name of the feature group, i.e. the first level value of the group index. """ @@ -56,7 +56,7 @@ class Processor(Serializable): Parameters ---------- df : pd.DataFrame - The raw_df of handler or result from previous processor + The raw_df of handler or result from previous processor. """ pass @@ -68,7 +68,7 @@ class Processor(Serializable): Returns ------- bool: - if it is usable for infenrece + if it is usable for infenrece. """ return True diff --git a/qlib/data/filter.py b/qlib/data/filter.py index 47b093b67..70f9d3278 100644 --- a/qlib/data/filter.py +++ b/qlib/data/filter.py @@ -32,7 +32,7 @@ class BaseDFilter(abc.ABC): Parameters ---------- config : dict - dict of config parameters + dict of config parameters. """ raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method") @@ -43,7 +43,7 @@ class BaseDFilter(abc.ABC): Returns ---------- dict - return the dict of config parameters + return the dict of config parameters. """ raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method") @@ -69,9 +69,9 @@ class SeriesDFilter(BaseDFilter): Parameters ---------- fstart_time: str - the time for the filter rule to start filter the instruments + the time for the filter rule to start filter the instruments. fend_time: str - the time for the filter rule to stop filter the instruments + the time for the filter rule to stop filter the instruments. """ super(SeriesDFilter, self).__init__() self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None @@ -83,12 +83,12 @@ class SeriesDFilter(BaseDFilter): Parameters ---------- instruments: dict - the dict of instruments in the form {instrument_name => list of timestamp tuple} + the dict of instruments in the form {instrument_name => list of timestamp tuple}. Returns ---------- pd.Timestamp, pd.Timestamp - the lower time bound and upper time bound of all the instruments + the lower time bound and upper time bound of all the instruments. """ trange = Cal.calendar(freq=self.filter_freq) ubound, lbound = trange[0], trange[-1] @@ -105,14 +105,14 @@ class SeriesDFilter(BaseDFilter): Parameters ---------- time_range : D.calendar - the time range of the instruments + the time range of the instruments. target_timestamp : list - the list of tuple (timestamp, timestamp) + the list of tuple (timestamp, timestamp). Returns ---------- pd.Series - the series of bool value for an instrument + the series of bool value for an instrument. """ # Construct a whole dict of {date => bool} timestamp_series = {timestamp: False for timestamp in time_range} @@ -124,19 +124,19 @@ class SeriesDFilter(BaseDFilter): return timestamp_series def _filterSeries(self, timestamp_series, filter_series): - """Filter the timestamp series with filter series by using element-wise AND operation of the two series + """Filter the timestamp series with filter series by using element-wise AND operation of the two series. Parameters ---------- timestamp_series : pd.Series - the series of bool value indicating existing time + the series of bool value indicating existing time. filter_series : pd.Series - the series of bool value indicating filter feature + the series of bool value indicating filter feature. Returns ---------- pd.Series - the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp + the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp. """ fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1] filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean @@ -144,17 +144,17 @@ class SeriesDFilter(BaseDFilter): return timestamp_series def _toTimestamp(self, timestamp_series): - """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE + """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE. Parameters ---------- timestamp_series: pd.Series - the series of bool value after being filtered + the series of bool value after being filtered. Returns ---------- list - the list of tuple (timestamp, timestamp) + the list of tuple (timestamp, timestamp). """ # sort the timestamp_series according to the timestamps timestamp_series.sort_index() @@ -194,18 +194,18 @@ class SeriesDFilter(BaseDFilter): Parameters ---------- instruments : dict - the dict of instruments to be filtered + the dict of instruments to be filtered. fstart : pd.Timestamp - start time of filter + start time of filter. fend : pd.Timestamp - end time of filter + end time of filter. - .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time + .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time. Returns ---------- pd.Dataframe - a series of {pd.Timestamp => bool} + a series of {pd.Timestamp => bool}. """ raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method") @@ -215,16 +215,16 @@ class SeriesDFilter(BaseDFilter): Parameters ---------- instruments: dict - input instruments to be filtered + input instruments to be filtered. start_time: str - start of the time range + start of the time range. end_time: str - end of the time range + end of the time range. Returns ---------- dict - filtered instruments, same structure as input instruments + filtered instruments, same structure as input instruments. """ lbound, ubound = self._getTimeBound(instruments) start_time = pd.Timestamp(start_time or lbound) @@ -272,7 +272,7 @@ class NameDFilter(SeriesDFilter): params: ------ name_rule_re: str - regular expression for the name rule + regular expression for the name rule. """ super(NameDFilter, self).__init__(fstart_time, fend_time) self.name_rule_re = name_rule_re @@ -325,13 +325,13 @@ class ExpressionDFilter(SeriesDFilter): params: ------ fstart_time: str - filter the feature starting from this time + filter the feature starting from this time. fend_time: str - filter the feature ending by this time + filter the feature ending by this time. rule_expression: str - an input expression for the rule + an input expression for the rule. keep: bool - whether to keep the instruments of which features don't exist in the filter time span + whether to keep the instruments of which features don't exist in the filter time span. """ super(ExpressionDFilter, self).__init__(fstart_time, fend_time) self.rule_expression = rule_expression diff --git a/qlib/model/base.py b/qlib/model/base.py index d6ee50e33..fd220cd7e 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -33,7 +33,7 @@ class Model(BaseModel): Parameters ---------- dataset : Dataset - dataset will generate the processed data from model training + dataset will generate the processed data from model training. """ raise NotImplementedError() @@ -44,7 +44,7 @@ class Model(BaseModel): Parameters ---------- dataset : Dataset - dataset will generate the processed dataset from model training + dataset will generate the processed dataset from model training. """ raise NotImplementedError() @@ -59,6 +59,6 @@ class ModelFT(Model): Parameters ---------- dataset : Dataset - dataset will generate the processed dataset from model training + dataset will generate the processed dataset from model training. """ raise NotImplementedError() diff --git a/qlib/model/riskmodel.py b/qlib/model/riskmodel.py index b5275213b..07a1e0c9f 100644 --- a/qlib/model/riskmodel.py +++ b/qlib/model/riskmodel.py @@ -23,9 +23,9 @@ class RiskModel(BaseModel): def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True): """ Args: - nan_option (str): nan handling option (`ignore`/`mask`/`fill`) - assume_centered (bool): whether the data is assumed to be centered - scale_return (bool): whether scale returns as percentage + nan_option (str): nan handling option (`ignore`/`mask`/`fill`). + assume_centered (bool): whether the data is assumed to be centered. + scale_return (bool): whether scale returns as percentage. """ # nan assert nan_option in [ @@ -45,11 +45,11 @@ class RiskModel(BaseModel): Args: X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance, with variables as columns and observations as rows. - return_corr (bool): whether return the correlation matrix - is_price (bool): whether `X` contains price (if not assume stock returns) + return_corr (bool): whether return the correlation matrix. + is_price (bool): whether `X` contains price (if not assume stock returns). Returns: - pd.DataFrame or np.ndarray: estimated covariance (or correlation) + pd.DataFrame or np.ndarray: estimated covariance (or correlation). """ # transform input into 2D array if not isinstance(X, (pd.Series, pd.DataFrame)): @@ -101,10 +101,10 @@ class RiskModel(BaseModel): By default, this method implements the empirical covariance estimation. Args: - X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows) + X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows). Returns: - np.ndarray: covariance matrix + np.ndarray: covariance matrix. """ xTx = np.asarray(X.T.dot(X)) N = len(X) @@ -117,7 +117,7 @@ class RiskModel(BaseModel): """handle nan and centerize data Note: - if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray` + if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`. """ # handle nan if self.nan_option == self.FILL_NAN: @@ -139,15 +139,15 @@ class ShrinkCovEstimator(RiskModel): where `alpha` is the shrink parameter and `F` is the shrinking target. The following shrinking parameters (`alpha`) are supported: - - `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter - - `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter - - float: directly specify the shrink parameter, should be between [0, 1] + - `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter. + - `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter. + - float: directly specify the shrink parameter, should be between [0, 1]. The following shrinking targets (`F`) are supported: - - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation - - `const_corr` [2][6]: assume stocks have different variance but equal correlation - - `single_factor` [3][7]: assume single factor model as the shrinking target - - np.ndarray: provide the shrinking targets directly + - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation. + - `const_corr` [2][6]: assume stocks have different variance but equal correlation. + - `single_factor` [3][7]: assume single factor model as the shrinking target. + - np.ndarray: provide the shrinking targets directly. Note: - The optimal shrinking parameter depends on the selection of the shrinking target. @@ -402,13 +402,13 @@ class POETCovEstimator(RiskModel): def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs): """ Args: - num_factors (int): number of factors (if set to zero, no factor model will be used) - thresh (float): the positive constant for thresholding + num_factors (int): number of factors (if set to zero, no factor model will be used). + thresh (float): the positive constant for thresholding. thresh_method (str): thresholding method, which can be - - 'soft': soft thresholding - - 'hard': hard thresholding - - 'scad': scad thresholding - kwargs: see `RiskModel` for more information + - 'soft': soft thresholding. + - 'hard': hard thresholding. + - 'scad': scad thresholding. + kwargs: see `RiskModel` for more information. """ super().__init__(**kwargs)