mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge pull request #650 from microsoft/backtest_improve
Improve the backtest design and APIs
This commit is contained in:
17
CHANGES.rst
17
CHANGES.rst
@@ -159,6 +159,21 @@ Version 0.5.0
|
||||
- Add baselines
|
||||
- public data crawler
|
||||
|
||||
Version greater than Version 0.5.0
|
||||
|
||||
Version 0.8.0
|
||||
--------------------
|
||||
- The backtest is greatly refactored.
|
||||
- Nested decision execution framework is supported
|
||||
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
|
||||
- The trading limitation is more accurate;
|
||||
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
|
||||
- In `current verison <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between loging and shorting action.
|
||||
- The constant is different when calculating annualized metrics.
|
||||
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
|
||||
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
|
||||
- Users could chec kout the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
|
||||
|
||||
|
||||
Other Versions
|
||||
----------------------------------
|
||||
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_
|
||||
|
||||
@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -86,4 +87,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -100,4 +101,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -94,4 +95,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
|
||||
@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
@@ -80,4 +83,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -76,4 +77,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -31,8 +31,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -98,4 +99,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -13,6 +13,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
|
||||
|
||||
> NOTE:
|
||||
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
|
||||
|
||||
|
||||
## Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -30,8 +30,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -94,4 +95,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -151,10 +151,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -189,10 +188,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
backtest_config["benchmark"] = self.benchmark
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
|
||||
@@ -31,10 +31,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
|
||||
@@ -152,8 +152,11 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
:param conf_path: A path to the qlib config in yml format
|
||||
"""
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
if conf_path is None:
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
@@ -216,7 +219,7 @@ def auto_init(**kwargs):
|
||||
.. code-block:: yaml
|
||||
|
||||
conf_type: ref
|
||||
qlib_cfg: '<shared_yaml_config_path>'
|
||||
qlib_cfg: '<shared_yaml_config_path>' # this could be null reference no config from other files
|
||||
# following configs in `qlib_cfg_update` is project=specific
|
||||
qlib_cfg_update:
|
||||
exp_manager:
|
||||
@@ -246,6 +249,7 @@ def auto_init(**kwargs):
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
@@ -259,8 +263,14 @@ def auto_init(**kwargs):
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
qlib_conf_path = conf.get("qlib_cfg", None)
|
||||
|
||||
# merge the arguments
|
||||
qlib_conf_update = conf.get("qlib_cfg_update", {})
|
||||
for k, v in kwargs.items():
|
||||
if k in qlib_conf_update:
|
||||
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
|
||||
qlib_conf_update.update(kwargs)
|
||||
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -34,6 +34,7 @@ class Exchange:
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
**kwargs,
|
||||
@@ -95,6 +96,7 @@ class Exchange:
|
||||
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
|
||||
distinguish `not set` and `disable trade_unit`
|
||||
:param min_cost: min cost, default 5
|
||||
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
|
||||
The limit indicates that the etf is tradable on a specific day.
|
||||
@@ -164,9 +166,12 @@ class Exchange:
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
self.open_cost = open_cost
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.impact_cost = impact_cost
|
||||
|
||||
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
|
||||
self.volume_threshold = volume_threshold
|
||||
self.extra_quote = extra_quote
|
||||
@@ -685,12 +690,14 @@ class Exchange:
|
||||
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
|
||||
)
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
@@ -699,10 +706,10 @@ class Exchange:
|
||||
max_trade_amount = 0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / self.open_cost + self.min_cost
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
if cash >= critical_price:
|
||||
# the service fee is equal to open_cost * trade_amount
|
||||
max_trade_amount = cash / (1 + self.open_cost) / trade_price
|
||||
# the service fee is equal to cost_ratio * trade_amount
|
||||
max_trade_amount = cash / (1 + cost_ratio) / trade_price
|
||||
else:
|
||||
# the service fee is equal to min_cost
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
@@ -718,6 +725,7 @@ class Exchange:
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
@@ -726,8 +734,12 @@ class Exchange:
|
||||
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
|
||||
self._clip_amount_by_volume(order, dealt_order_amount)
|
||||
|
||||
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
|
||||
trade_val = order.deal_amount * trade_price
|
||||
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
cost_ratio = self.close_cost
|
||||
cost_ratio = self.close_cost + adj_cost_ratio
|
||||
# sell
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
@@ -750,14 +762,18 @@ class Exchange:
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
|
||||
elif order.direction == Order.BUY:
|
||||
cost_ratio = self.open_cost
|
||||
cost_ratio = self.open_cost + adj_cost_ratio
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
if cash < max(trade_val * cost_ratio, self.min_cost):
|
||||
# cash cannot cover cost
|
||||
order.deal_amount = 0
|
||||
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
|
||||
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
)
|
||||
|
||||
@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
|
||||
if is_single_value(start_time, end_time, self.freq, self.region):
|
||||
# this is a very special case.
|
||||
# skip aggregating function to speed-up the query calculation
|
||||
|
||||
# FIXME:
|
||||
# it will go to the else logic when it comes to the
|
||||
# 1) the day before holiday when daily trading
|
||||
# 2) the last minute of the day when intraday trading
|
||||
try:
|
||||
return self.data[stock_id].loc[start_time, field]
|
||||
except KeyError:
|
||||
|
||||
102
qlib/backtest/signal.py
Normal file
102
qlib/backtest/signal.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
Some trading strategy make decisions based on other prediction signals
|
||||
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
|
||||
|
||||
This interface is tries to provide unified interface for those different sources
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
|
||||
"""
|
||||
get the signal at the end of the decision step(from `start_time` to `end_time`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[pd.Series, pd.DataFrame, None]:
|
||||
returns None if no signal in the specific day
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class SignalWCache(Signal):
|
||||
"""
|
||||
Signal With pandas with based Cache
|
||||
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
|
||||
"""
|
||||
|
||||
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : Union[pd.Series, pd.DataFrame]
|
||||
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
|
||||
|
||||
instrument datetime
|
||||
SH600000 2008-01-02 0.079704
|
||||
2008-01-03 0.120125
|
||||
2008-01-04 0.878860
|
||||
2008-01-07 0.505539
|
||||
2008-01-08 0.395004
|
||||
"""
|
||||
self.signal_cache = convert_index_format(signal, level="datetime")
|
||||
|
||||
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
|
||||
# the frequency of the signal may not algin with the decision frequency of strategy
|
||||
# so resampling from the data is necessary
|
||||
# the latest signal leverage more recent data and therefore is used in trading.
|
||||
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
return signal
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
|
||||
def __init__(self, model: BaseModel, dataset: Dataset):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
pred_scores = self.model.predict(dataset)
|
||||
if isinstance(pred_scores, pd.DataFrame):
|
||||
pred_scores = pred_scores.iloc[:, 0]
|
||||
super().__init__(pred_scores)
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, update model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
# TODO: this method is not included in the framework and could be refactor later
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
def create_signal_from(
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
|
||||
) -> Signal:
|
||||
"""
|
||||
create signal from diverse information
|
||||
This method will choose the right method to create a signal based on `obj`
|
||||
Please refer to the code below.
|
||||
"""
|
||||
if isinstance(obj, Signal):
|
||||
return obj
|
||||
elif isinstance(obj, (tuple, list)):
|
||||
return ModelSignal(*obj)
|
||||
elif isinstance(obj, (dict, str)):
|
||||
return init_instance_by_config(obj)
|
||||
elif isinstance(obj, (pd.DataFrame, pd.Series)):
|
||||
return SignalWCache(signal=obj)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of signal is not supported")
|
||||
@@ -70,7 +70,7 @@ class TradeCalendarManager:
|
||||
- If self.trade_step >= self.self.trade_len, it means the trading is finished
|
||||
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
return self.trade_step >= self.trade_len - 1
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
|
||||
0
qlib/contrib/data/utils/__init__.py
Normal file
0
qlib/contrib/data/utils/__init__.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
def align_index(df_dict, join):
|
||||
res = {}
|
||||
for k, df in df_dict.items():
|
||||
if join is not None and k != join:
|
||||
df = df.reindex(df_dict[join].index)
|
||||
res[k] = df
|
||||
return res
|
||||
|
||||
|
||||
# Mocking the pd.DataFrame class
|
||||
class SepDataFrame:
|
||||
"""
|
||||
(Sep)erate DataFrame
|
||||
We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter).
|
||||
However, they are usally be used seperately at last.
|
||||
This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive)
|
||||
|
||||
SepDataFrame tries to act like a DataFrame whose column with multiindex
|
||||
"""
|
||||
|
||||
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
|
||||
"""
|
||||
initialize the data based on the dataframe dictionary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df_dict : Dict[str, pd.DataFrame]
|
||||
dataframe dictionary
|
||||
join : str
|
||||
how to join the data
|
||||
It will reindex the dataframe based on the join key.
|
||||
If join is None, the reindex step will be skipped
|
||||
|
||||
skip_align :
|
||||
for some cases, we can improve performance by skipping aligning index
|
||||
"""
|
||||
self.join = join
|
||||
|
||||
if skip_align:
|
||||
self._df_dict = df_dict
|
||||
else:
|
||||
self._df_dict = align_index(df_dict, join)
|
||||
|
||||
@property
|
||||
def loc(self):
|
||||
return SDFLoc(self, join=self.join)
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
return self._df_dict[self.join].index
|
||||
|
||||
def apply_each(self, method: str, skip_align=True, *args, **kwargs):
|
||||
"""
|
||||
Assumptions:
|
||||
- inplace methods will return None
|
||||
"""
|
||||
inplace = False
|
||||
df_dict = {}
|
||||
for k, df in self._df_dict.items():
|
||||
df_dict[k] = getattr(df, method)(*args, **kwargs)
|
||||
if df_dict[k] is None:
|
||||
inplace = True
|
||||
if not inplace:
|
||||
return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)
|
||||
|
||||
def sort_index(self, *args, **kwargs):
|
||||
return self.apply_each("sort_index", True, *args, **kwargs)
|
||||
|
||||
def copy(self, *args, **kwargs):
|
||||
return self.apply_each("copy", True, *args, **kwargs)
|
||||
|
||||
def _update_join(self):
|
||||
if self.join not in self:
|
||||
self.join = next(iter(self._df_dict.keys()))
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self._df_dict[item]
|
||||
|
||||
def __setitem__(self, item: str, df: pd.DataFrame):
|
||||
# TODO: consider the join behavior
|
||||
self._df_dict[item] = df
|
||||
|
||||
def __delitem__(self, item: str):
|
||||
del self._df_dict[item]
|
||||
self._update_join()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._df_dict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df_dict[self.join])
|
||||
|
||||
def droplevel(self, *args, **kwargs):
|
||||
raise NotImplementedError(f"Please implement the `droplevel` method")
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
dfs = []
|
||||
for k, df in self._df_dict.items():
|
||||
df = df.head(0)
|
||||
df.columns = pd.MultiIndex.from_product([[k], df.columns])
|
||||
dfs.append(df)
|
||||
return pd.concat(dfs, axis=1).columns
|
||||
|
||||
# Useless methods
|
||||
@staticmethod
|
||||
def merge(df_dict: Dict[str, pd.DataFrame], join: str):
|
||||
all_df = df_dict[join]
|
||||
for k, df in df_dict.items():
|
||||
if k != join:
|
||||
all_df = all_df.join(df)
|
||||
return all_df
|
||||
|
||||
|
||||
class SDFLoc:
|
||||
"""Mock Class"""
|
||||
|
||||
def __init__(self, sdf: SepDataFrame, join):
|
||||
self._sdf = sdf
|
||||
self.axis = None
|
||||
self.join = join
|
||||
|
||||
def __call__(self, axis):
|
||||
self.axis = axis
|
||||
return self
|
||||
|
||||
def __getitem__(self, args):
|
||||
if self.axis == 1:
|
||||
if isinstance(args, str):
|
||||
return self._sdf[args]
|
||||
elif isinstance(args, (tuple, list)):
|
||||
new_df_dict = {k: self._sdf[k] for k in args}
|
||||
return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
elif self.axis == 0:
|
||||
return SepDataFrame(
|
||||
{k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True
|
||||
)
|
||||
else:
|
||||
df = self._sdf
|
||||
if isinstance(args, tuple):
|
||||
ax0, *ax1 = args
|
||||
if len(ax1) == 0:
|
||||
ax1 = None
|
||||
if ax1 is not None:
|
||||
df = df.loc(axis=1)[ax1]
|
||||
if ax0 is not None:
|
||||
df = df.loc(axis=0)[ax0]
|
||||
return df
|
||||
else:
|
||||
return df.loc(axis=0)[args]
|
||||
|
||||
|
||||
# Patch pandas DataFrame
|
||||
# Tricking isinstance to accept SepDataFrame as its subclass
|
||||
import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
|
||||
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
|
||||
if isinstance(cls, Iterable):
|
||||
for c in cls:
|
||||
if c is pd.DataFrame:
|
||||
return True
|
||||
elif cls is pd.DataFrame:
|
||||
return True
|
||||
return isinstance_orig(instance, cls) # pylint: disable=E0602
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
builtins.isinstance = _isinstance
|
||||
|
||||
if __name__ == "__main__":
|
||||
sdf = SepDataFrame({}, join=None)
|
||||
print(isinstance(sdf, (pd.DataFrame,)))
|
||||
print(isinstance(sdf, pd.DataFrame))
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .model_strategy import (
|
||||
from .signal_strategy import (
|
||||
TopkDropoutStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ This strategy is not well maintained
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
from .model_strategy import WeightStrategyBase
|
||||
from .signal_strategy import WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
|
||||
@@ -80,18 +80,22 @@ class OrderGenWInteract(OrderGenerator):
|
||||
|
||||
:rtype: list
|
||||
"""
|
||||
if target_weight_position is None:
|
||||
return []
|
||||
|
||||
# calculate current_tradable_value
|
||||
current_amount_dict = current.get_stock_amount_dict()
|
||||
|
||||
current_total_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
only_tradable=False,
|
||||
)
|
||||
current_tradable_value = trade_exchange.calculate_amount_position_value(
|
||||
amount_dict=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
only_tradable=True,
|
||||
)
|
||||
# add cash
|
||||
@@ -105,9 +109,7 @@ class OrderGenWInteract(OrderGenerator):
|
||||
# value. Then just sell all the stocks
|
||||
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
|
||||
for stock_id in list(target_amount_dict.keys()):
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
):
|
||||
if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time):
|
||||
del target_amount_dict[stock_id]
|
||||
else:
|
||||
# consider cost rate
|
||||
@@ -118,16 +120,16 @@ class OrderGenWInteract(OrderGenerator):
|
||||
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
|
||||
weight_position=target_weight_position,
|
||||
cash=current_tradable_value,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
order_list = trade_exchange.generate_order_for_target_amount_position(
|
||||
target_position=target_amount_dict,
|
||||
current_position=current_amount_dict,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecisionWO(order_list, self)
|
||||
return order_list
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
|
||||
@@ -163,8 +165,11 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
:param trade_date:
|
||||
:type trade_date: pd.Timestamp
|
||||
|
||||
:rtype: list
|
||||
:rtype: list of generated orders
|
||||
"""
|
||||
if target_weight_position is None:
|
||||
return []
|
||||
|
||||
risk_total_value = risk_degree * current.calculate_value()
|
||||
|
||||
current_stock = current.get_stock_list()
|
||||
@@ -172,13 +177,17 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
for stock_id in target_weight_position:
|
||||
# Current rule will ignore the stock that not hold and cannot be traded at predict date
|
||||
if trade_exchange.is_stock_tradable(
|
||||
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
) and trade_exchange.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value
|
||||
* target_weight_position[stock_id]
|
||||
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
|
||||
/ trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time)
|
||||
)
|
||||
# TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price.
|
||||
# Maybe a close price with forward fill will be a better solution.
|
||||
elif stock_id in current_stock:
|
||||
amount_dict[stock_id] = (
|
||||
risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)
|
||||
@@ -188,7 +197,7 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
order_list = trade_exchange.generate_order_for_target_amount_position(
|
||||
target_position=amount_dict,
|
||||
current_position=current.get_stock_amount_dict(),
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
return TradeDecisionWO(order_list, self)
|
||||
return order_list
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import copy
|
||||
from qlib.backtest.signal import Signal, create_signal_from
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import BaseModel
|
||||
from qlib.backtest.position import Position
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
class TopkDropoutStrategy(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
topk,
|
||||
n_drop,
|
||||
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
|
||||
method_sell="bottom",
|
||||
method_buy="top",
|
||||
risk_degree=0.95,
|
||||
@@ -30,6 +36,8 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
common_infra=None,
|
||||
model=None,
|
||||
dataset=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -39,6 +47,9 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
the number of stocks in the portfolio.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
signal :
|
||||
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
|
||||
the decision of the strategy will base on the given signal
|
||||
method_sell : str
|
||||
dropout method_sell, random/bottom.
|
||||
method_buy : str
|
||||
@@ -64,7 +75,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
self.topk = topk
|
||||
self.n_drop = n_drop
|
||||
@@ -74,6 +85,13 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
# This is trying to be compatible with previous version of qlib task config
|
||||
if model is not None and dataset is not None:
|
||||
warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
|
||||
signal = model, dataset
|
||||
|
||||
self.signal: Signal = create_signal_from(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -87,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
if self.only_tradable:
|
||||
@@ -235,15 +253,15 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
return TradeDecisionWO(sell_order_list + buy_order_list, self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
class WeightStrategyBase(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
dataset,
|
||||
*,
|
||||
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
level_infra=None,
|
||||
@@ -251,6 +269,9 @@ class WeightStrategyBase(ModelStrategy):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
signal :
|
||||
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
|
||||
the decision of the strategy will base on the given signal
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
@@ -260,13 +281,15 @@ class WeightStrategyBase(ModelStrategy):
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
self.signal: Signal = create_signal_from(signal)
|
||||
|
||||
def get_risk_degree(self, trade_step=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
@@ -298,7 +321,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
|
||||
if pred_score is None:
|
||||
return TradeDecisionWO([], self)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
|
||||
|
||||
if save:
|
||||
save_name = "results-{:}.pkl".format(key)
|
||||
self.recorder.save_objects(**{save_name: results})
|
||||
self.save(**{save_name: results})
|
||||
logger.info(
|
||||
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
|
||||
save_name, self.recorder.experiment_id
|
||||
@@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp):
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
return ["mse.pkl", "rmse.pkl"]
|
||||
|
||||
@@ -320,6 +320,7 @@ class TSDataSampler:
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
self.idx_map = self.idx_map2arr(self.idx_map)
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
@@ -328,6 +329,25 @@ class TSDataSampler:
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
def idx_map2arr(idx_map):
|
||||
# pytorch data sampler will have better memory control without large dict or list
|
||||
# - https://github.com/pytorch/pytorch/issues/13243
|
||||
# - https://github.com/airctic/icevision/issues/613
|
||||
# So we convert the dict into int array.
|
||||
# The arr_map is expected to behave the same as idx_map
|
||||
|
||||
dtype = np.int32
|
||||
# set a index out of bound to indicate the none existing
|
||||
no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)
|
||||
|
||||
max_idx = max(idx_map.keys())
|
||||
arr_map = []
|
||||
for i in range(max_idx + 1):
|
||||
arr_map.append(idx_map.get(i, no_existing_idx))
|
||||
arr_map = np.array(arr_map, dtype=dtype)
|
||||
return arr_map
|
||||
|
||||
@staticmethod
|
||||
def flt_idx_map(flt_data, idx_map):
|
||||
idx = 0
|
||||
@@ -524,20 +544,18 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
# make sure the calendar is updated to latest when loading data from new config
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
self.cal = cal
|
||||
self.cal = sorted(cal)
|
||||
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
@staticmethod
|
||||
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - self.step_len)
|
||||
pad_start = self.cal[pad_start_idx]
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - step_len)
|
||||
pad_start = cal[pad_start_idx]
|
||||
return slice(pad_start, end)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
@@ -546,13 +564,15 @@ class TSDatasetH(DatasetH):
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
|
||||
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
|
||||
data = super()._prepare_seg(ext_slice, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
@@ -82,8 +82,6 @@ class DataHandler(Serializable):
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# Setup data loader
|
||||
assert data_loader is not None # to make start_time end_time could have None default value
|
||||
@@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler):
|
||||
DK_R = "raw"
|
||||
DK_I = "infer"
|
||||
DK_L = "learn"
|
||||
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
@@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler):
|
||||
raise AttributeError(
|
||||
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
|
||||
)
|
||||
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
|
||||
df = getattr(self, self.ATTR_MAP[data_key])
|
||||
return df
|
||||
|
||||
def fetch(
|
||||
@@ -624,3 +623,33 @@ class DataHandlerLP(DataHandler):
|
||||
df = self._get_df_by_key(data_key).head()
|
||||
df = fetch_df_by_col(df, col_set)
|
||||
return df.columns.to_list()
|
||||
|
||||
@classmethod
|
||||
def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP":
|
||||
"""
|
||||
Motivation
|
||||
- A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic.
|
||||
- This class make it possible by casting the class to DataHandlerLP and only keep the processed data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler : DataHandlerLP
|
||||
A subclass of DataHandlerLP
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataHandlerLP:
|
||||
the converted processed data
|
||||
"""
|
||||
new_hd: DataHandlerLP = object.__new__(DataHandlerLP)
|
||||
new_hd.from_cast = True # add a mark for the casted instance
|
||||
|
||||
for key in list(DataHandlerLP.ATTR_MAP.values()) + [
|
||||
"instruments",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"fetch_orig",
|
||||
"drop_raw",
|
||||
]:
|
||||
setattr(new_hd, key, getattr(handler, key, None))
|
||||
return new_hd
|
||||
|
||||
@@ -17,7 +17,7 @@ from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
|
||||
class BaseStrategy:
|
||||
@@ -194,45 +194,6 @@ class BaseStrategy:
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : BaseModel
|
||||
the model used in when making predictions
|
||||
dataset : DatasetH
|
||||
provide test data for model
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
if isinstance(self.pred_scores, pd.DataFrame):
|
||||
self.pred_scores = self.pred_scores.iloc[:, 0]
|
||||
|
||||
def _update_model(self):
|
||||
"""
|
||||
When using online data, pdate model in each bar as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""RL-based strategy"""
|
||||
|
||||
|
||||
@@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
|
||||
----------
|
||||
config : [dict, str]
|
||||
similar to config
|
||||
please refer to the doc of init_instance_by_config
|
||||
|
||||
default_module : Python module or str
|
||||
It should be a python module to load the class type
|
||||
@@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
|
||||
_callable = config["class"] # the class type itself is passed in
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
# a.b.c.ClassName
|
||||
*m_path, cls = config.split(".")
|
||||
m_path = ".".join(m_path)
|
||||
module = get_module_by_module_path(default_module if m_path == "" else m_path)
|
||||
|
||||
_callable = getattr(module, config)
|
||||
_callable = getattr(module, cls)
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -260,7 +264,9 @@ def init_instance_by_config(
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, config)() will be used.
|
||||
- "ClassName": getattr(module, "ClassName")() will be used.
|
||||
3) specify module path with class name
|
||||
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
default_module : Python module
|
||||
|
||||
@@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
def columns(self):
|
||||
return self.indices[1]
|
||||
|
||||
def __getitem__(self, args):
|
||||
# NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean
|
||||
return self.iloc[args]
|
||||
|
||||
def _align_indices(self, other: "IndexData") -> "IndexData":
|
||||
"""
|
||||
Align all indices of `other` to `self` before performing the arithmetic operations.
|
||||
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
Parameters
|
||||
----------
|
||||
other : "IndexData"
|
||||
the index in `other` is to be chagned
|
||||
the index in `other` is to be changed
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
"""
|
||||
return len(self.data)
|
||||
|
||||
def sum(self, axis=None):
|
||||
def sum(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nansum(self.data)
|
||||
@@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator):
|
||||
else:
|
||||
raise ValueError(f"axis must be None, 0 or 1")
|
||||
|
||||
def mean(self, axis=None):
|
||||
def mean(self, axis=None, dtype=None, out=None):
|
||||
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
|
||||
# FIXME: weird logic and not general
|
||||
if axis is None:
|
||||
return np.nanmean(self.data)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
|
||||
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
return pd.concat(dfs, axis=axis).sort_index()
|
||||
else:
|
||||
return _naive_group_apply(df)
|
||||
|
||||
|
||||
class AsyncCaller:
|
||||
"""
|
||||
This AsyncCaller tries to make it easier to async call
|
||||
|
||||
Currently, it is used in MLflowRecorder to make functions like `log_params` async
|
||||
|
||||
NOTE:
|
||||
- This caller didn't consider the return value
|
||||
"""
|
||||
|
||||
STOP_MARK = "__STOP"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._q = Queue()
|
||||
self._stop = False
|
||||
self._t = Thread(target=self.run)
|
||||
self._t.start()
|
||||
|
||||
def close(self):
|
||||
self._q.put(self.STOP_MARK)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
data = self._q.get()
|
||||
if data == self.STOP_MARK:
|
||||
break
|
||||
else:
|
||||
data()
|
||||
|
||||
def __call__(self, func, *args, **kwargs):
|
||||
self._q.put(partial(func, *args, **kwargs))
|
||||
|
||||
def wait(self, close=True):
|
||||
if close:
|
||||
self.close()
|
||||
self._t.join()
|
||||
|
||||
@staticmethod
|
||||
def async_dec(ac_attr):
|
||||
def decorator_func(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if isinstance(getattr(self, ac_attr, None), Callable):
|
||||
return getattr(self, ac_attr)(func, self, *args, **kwargs)
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator_func
|
||||
|
||||
@@ -21,19 +21,65 @@ Situations Description
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It makes user can train all tasks in
|
||||
the end of `routine` or `first_train`.
|
||||
Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by
|
||||
different strategies. It makes users can parallelly train all tasks at the end of
|
||||
`routine` or `first_train`. Otherwise, these functions will get stuck when each
|
||||
strategy prepare tasks.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
every routine and prepare signals for every routine.
|
||||
Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it
|
||||
is for simulation/backtesting instead of online trading
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
|
||||
Here is some pseudo code the demonstrate the workflow of each situation
|
||||
|
||||
For simplicity
|
||||
- Only one strategy is used in the strategy
|
||||
- `update_online_pred` is only called in the online mode and is ignored
|
||||
|
||||
1) `Online + Trainer`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tasks = first_train()
|
||||
models = trainer.train(tasks)
|
||||
trainer.end_train(models)
|
||||
for day in online_trading_days:
|
||||
# OnlineManager.routine
|
||||
models = trainer.train(strategy.prepare_tasks()) # for each strategy
|
||||
strategy.prepare_online_models(models) # for each strategy
|
||||
|
||||
trainer.end_train(models)
|
||||
prepare_signals() # prepare trading signals daily
|
||||
|
||||
|
||||
`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`.
|
||||
|
||||
|
||||
2) `Simulation + DelayTrainer`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# simulate
|
||||
tasks = first_train()
|
||||
models = trainer.train(tasks)
|
||||
for day in historical_calendars:
|
||||
# OnlineManager.routine
|
||||
models = trainer.train(strategy.prepare_tasks()) # for each strategy
|
||||
strategy.prepare_online_models(models) # for each strategy
|
||||
# delay_prepare()
|
||||
# FIXME: Currently the delay_prepare is not implemented in a proper way.
|
||||
trainer.end_train(<for all previous models>)
|
||||
prepare_signals()
|
||||
|
||||
|
||||
# Can we simplify current workflow?
|
||||
- Can reduce the number of state of tasks?
|
||||
- For each task, we have three phases (i.e. task, partly trained task, final trained task)
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -58,7 +104,7 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
STATUS_ONLINE = "online" # the normal status. It is used when online trading
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -87,12 +133,24 @@ class OnlineManager(Serializable):
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
|
||||
# It records the online servnig models of each strategy for each day.
|
||||
self.history = {}
|
||||
if trainer is None:
|
||||
trainer = TrainerR()
|
||||
self.trainer = trainer
|
||||
self.signals = None
|
||||
self.status = self.STATUS_NORMAL
|
||||
self.status = self.STATUS_ONLINE
|
||||
|
||||
def _postpone_action(self):
|
||||
"""
|
||||
Should the workflow to postpone the following actions to the end (in delay_prepare)
|
||||
- trainer.end_train
|
||||
- prepare_signals
|
||||
|
||||
Postpone these actions is to support simulating/backtest online strategies without time dependencies.
|
||||
All the actions can be done parallelly at the end.
|
||||
"""
|
||||
return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()
|
||||
|
||||
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
|
||||
"""
|
||||
@@ -113,12 +171,12 @@ class OnlineManager(Serializable):
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
|
||||
# FIXME: Train multiple online models at `first_train` will result in getting too much online models at the
|
||||
# start.
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
if not self._postpone_action():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
@@ -160,10 +218,10 @@ class OnlineManager(Serializable):
|
||||
|
||||
# The online model may changes in the above processes
|
||||
# So updating the predictions of online models should be the last step
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
if self.status == self.STATUS_ONLINE:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
if not self._postpone_action():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
@@ -278,13 +336,13 @@ class OnlineManager(Serializable):
|
||||
signal_kwargs=signal_kwargs,
|
||||
)
|
||||
# delay prepare the models and signals
|
||||
if self.trainer.is_delay():
|
||||
if self._postpone_action():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
|
||||
# FIXME: get logging level firstly and restore it here
|
||||
set_global_logger_level(logging.DEBUG)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
self.status = self.STATUS_NORMAL
|
||||
self.status = self.STATUS_ONLINE
|
||||
return self.get_signals()
|
||||
|
||||
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
|
||||
@@ -295,6 +353,8 @@ class OnlineManager(Serializable):
|
||||
model_kwargs: the params for `end_train`
|
||||
signal_kwargs: the params for `prepare_signals`
|
||||
"""
|
||||
# FIXME:
|
||||
# This method is not implemented in the proper way!!!
|
||||
last_models = {}
|
||||
signals_time = D.calendar()[0]
|
||||
need_prepare = False
|
||||
|
||||
@@ -9,6 +9,9 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Union, List
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
@@ -45,6 +48,16 @@ class RecordTemp:
|
||||
|
||||
return "/".join(names)
|
||||
|
||||
def save(self, **kwargs):
|
||||
"""
|
||||
It behaves the same as self.recorder.save_objects.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
"""
|
||||
art_path = self.get_path()
|
||||
if art_path == "":
|
||||
art_path = None
|
||||
self.recorder.save_objects(artifact_path=art_path, **kwargs)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self._recorder = recorder
|
||||
|
||||
@@ -67,31 +80,37 @@ class RecordTemp:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `generate` method.")
|
||||
|
||||
def load(self, name):
|
||||
def load(self, name: str, parents: bool = True):
|
||||
"""
|
||||
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
|
||||
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
|
||||
in the future::
|
||||
|
||||
sar = SigAnaRecord(recorder)
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
It behaves the same as self.recorder.load_object.
|
||||
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
the name for the file to be load.
|
||||
|
||||
parents : bool
|
||||
Each recorder has different `artifact_path`.
|
||||
So parents recursively find the path in parents
|
||||
Sub classes has higher priority
|
||||
|
||||
Return
|
||||
------
|
||||
The stored records.
|
||||
"""
|
||||
# try to load the saved object
|
||||
obj = self.recorder.load_object(name)
|
||||
return obj
|
||||
try:
|
||||
return self.recorder.load_object(self.get_path(name))
|
||||
except LoadObjectError:
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
return self.load(name, parents=True)
|
||||
|
||||
def list(self):
|
||||
"""
|
||||
List the supported artifacts.
|
||||
Users don't have to consider self.get_path
|
||||
|
||||
Return
|
||||
------
|
||||
@@ -99,7 +118,7 @@ class RecordTemp:
|
||||
"""
|
||||
return []
|
||||
|
||||
def check(self, include_self: bool = False):
|
||||
def check(self, include_self: bool = False, parents: bool = True):
|
||||
"""
|
||||
Check if the records is properly generated and saved.
|
||||
It is useful in following examples
|
||||
@@ -110,19 +129,34 @@ class RecordTemp:
|
||||
----------
|
||||
include_self : bool
|
||||
is the file generated by self included
|
||||
parents : bool
|
||||
will we check parents
|
||||
|
||||
Raise
|
||||
------
|
||||
FileExistsError: whether the records are stored properly.
|
||||
FileNotFoundError
|
||||
: whether the records are stored properly.
|
||||
"""
|
||||
artifacts = set(self.recorder.list_artifacts())
|
||||
if include_self:
|
||||
|
||||
# Some mlflow backend will not list the directly recursively.
|
||||
# So we force to the directly
|
||||
artifacts = {}
|
||||
|
||||
def _get_arts(dirn):
|
||||
if dirn not in artifacts:
|
||||
artifacts[dirn] = self.recorder.list_artifacts(dirn)
|
||||
return artifacts[dirn]
|
||||
|
||||
for item in self.list():
|
||||
if item not in artifacts:
|
||||
raise FileExistsError(item)
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
ps = self.get_path(item).split("/")
|
||||
dirn, fn = "/".join(ps[:-1]), ps[-1]
|
||||
if self.get_path(item) not in _get_arts(dirn):
|
||||
raise FileNotFoundError
|
||||
if parents:
|
||||
if self.depend_cls is not None:
|
||||
with class_casting(self, self.depend_cls):
|
||||
self.check(include_self=True)
|
||||
|
||||
|
||||
class SignalRecord(RecordTemp):
|
||||
@@ -158,7 +192,7 @@ class SignalRecord(RecordTemp):
|
||||
pred = self.model.predict(self.dataset)
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
self.recorder.save_objects(**{"pred.pkl": pred})
|
||||
self.save(**{"pred.pkl": pred})
|
||||
|
||||
logger.info(
|
||||
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
@@ -169,15 +203,11 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
if isinstance(self.dataset, DatasetH):
|
||||
raw_label = self.generate_label(self.dataset)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.save(**{"label.pkl": raw_label})
|
||||
|
||||
@staticmethod
|
||||
def list():
|
||||
def list(self):
|
||||
return ["pred.pkl", "label.pkl"]
|
||||
|
||||
def load(self, name="pred.pkl"):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
@@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
self.get_path("long_short_r.pkl"),
|
||||
self.get_path("long_avg_r.pkl"),
|
||||
]
|
||||
return paths
|
||||
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
|
||||
|
||||
|
||||
class SigAnaRecord(RecordTemp):
|
||||
@@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp):
|
||||
artifact_path = "sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
|
||||
super().__init__(recorder=recorder)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
self.label_col = label_col
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
def generate(self, **kwargs):
|
||||
if self.skip_existing:
|
||||
try:
|
||||
self.check(include_self=True, parents=False)
|
||||
except FileNotFoundError:
|
||||
pass # continue to generating metrics
|
||||
else:
|
||||
logger.info("The results has previously generated, generation skipped.")
|
||||
return
|
||||
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
@@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp):
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
self.save(**objects)
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
|
||||
paths = ["ic.pkl", "ric.pkl"]
|
||||
if self.ana_long_short:
|
||||
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
|
||||
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
|
||||
return paths
|
||||
|
||||
|
||||
@@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp):
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
|
||||
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
|
||||
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in portfolio_metric_dict:
|
||||
@@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp):
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
@@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp):
|
||||
for _freq in self.all_freq:
|
||||
list_path.extend(
|
||||
[
|
||||
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
|
||||
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
|
||||
f"report_normal_{_freq}.pkl",
|
||||
f"positions_normal_{_freq}.pkl",
|
||||
]
|
||||
)
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
|
||||
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
return list_path
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from qlib.utils.serial import Serializable
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
@@ -8,8 +9,10 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from qlib.utils.paral import AsyncCaller
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
from ..log import TimeInspector, get_module_logger
|
||||
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
@@ -227,6 +230,7 @@ class MLflowRecorder(Recorder):
|
||||
if mlflow_run.info.end_time is not None
|
||||
else None
|
||||
)
|
||||
self.async_log = None
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
@@ -285,6 +289,10 @@ class MLflowRecorder(Recorder):
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
|
||||
# NOTE: making logging async.
|
||||
# - This may cause delay when uploading results
|
||||
# - The logging time may not be accurate
|
||||
self.async_log = AsyncCaller()
|
||||
return run
|
||||
|
||||
def end_run(self, status: str = Recorder.STATUS_S):
|
||||
@@ -298,6 +306,9 @@ class MLflowRecorder(Recorder):
|
||||
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.status != Recorder.STATUS_S:
|
||||
self.status = status
|
||||
with TimeInspector.logt("waiting `async_log`"):
|
||||
self.async_log.wait()
|
||||
self.async_log = None
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
@@ -333,18 +344,27 @@ class MLflowRecorder(Recorder):
|
||||
try:
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
data = pickle.load(f)
|
||||
ar = self.client._tracking_client._get_artifact_repo(self.id)
|
||||
if isinstance(ar, AzureBlobArtifactRepository):
|
||||
# for saving disk space
|
||||
# For safety, only remove redundant file for specific ArtifactRepository
|
||||
shutil.rmtree(Path(path).absolute().parent)
|
||||
return data
|
||||
except Exception as e:
|
||||
raise LoadObjectError(message=str(e))
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_params(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_param(self.id, name, data)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def log_metrics(self, step=None, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.log_metric(self.id, name, data, step=step)
|
||||
|
||||
@AsyncCaller.async_dec(ac_attr="async_log")
|
||||
def set_tags(self, **kwargs):
|
||||
for name, data in kwargs.items():
|
||||
self.client.set_tag(self.id, name, data)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from libs.qlib.qlib.log import TimeInspector
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
@@ -190,7 +191,9 @@ class RecorderCollector(Collector):
|
||||
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
recs = self.experiment.list_recorders(**self.list_kwargs)
|
||||
|
||||
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
|
||||
recs = self.experiment.list_recorders(**self.list_kwargs)
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
|
||||
@@ -94,6 +94,11 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
def handler_mod(task: dict, rolling_gen):
|
||||
"""
|
||||
Help to modify the handler end time when using RollingGen
|
||||
It try to handle the following case
|
||||
- Hander's data end_time is earlier than dataset's test_data's segments.
|
||||
- To handle this, handler's data's end_time is extended.
|
||||
|
||||
If the handler's end_time is None, then it is not necessary to change it's end time.
|
||||
|
||||
Args:
|
||||
task (dict): a task template
|
||||
@@ -112,6 +117,9 @@ def handler_mod(task: dict, rolling_gen):
|
||||
except KeyError:
|
||||
# Maybe dataset do not have handler, then do nothing.
|
||||
pass
|
||||
except TypeError:
|
||||
# May be the handler is a string. `"handler.pkl"["kwargs"]` will raise TypeError
|
||||
pass
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
@@ -259,3 +267,56 @@ class RollingGen(TaskGen):
|
||||
# Update the following rolling
|
||||
res.extend(self.gen_following_tasks(t, test_end))
|
||||
return res
|
||||
|
||||
|
||||
class MultiHorizonGenBase(TaskGen):
|
||||
def __init__(self, horizon: List[int] = [5], label_leak_n=2):
|
||||
"""
|
||||
This task generator tries to genrate tasks for different horizons based on an existing task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
horizon : List[int]
|
||||
the possible horizons of the tasks
|
||||
label_leak_n : int
|
||||
How many future days it will take to get complete label after the day making prediction
|
||||
For example:
|
||||
- User make prediction on day `T`(after getting the close price on `T`)
|
||||
- The label is the return of buying stock on `T + 1` and selling it on `T + 2`
|
||||
- the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample)
|
||||
"""
|
||||
self.horizon = list(horizon)
|
||||
self.label_leak_n = label_leak_n
|
||||
self.ta = TimeAdjuster()
|
||||
self.test_key = "test"
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_horizon(self, task: dict, hr: int):
|
||||
"""
|
||||
This method is designed to change the task **in place**
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
Qlib's task
|
||||
hr : int
|
||||
the horizon of task
|
||||
"""
|
||||
|
||||
def generate(self, task: dict):
|
||||
res = []
|
||||
for hr in self.horizon:
|
||||
|
||||
# Add horizon
|
||||
t = copy.deepcopy(task)
|
||||
self.set_horizon(t, hr)
|
||||
|
||||
# adjust segment
|
||||
segments = self.ta.align_seg(t["dataset"]["kwargs"]["segments"])
|
||||
test_start = min(t for t in segments[self.test_key] if t is not None)
|
||||
for k in list(segments.keys()):
|
||||
if k != self.test_key:
|
||||
segments[k] = self.ta.truncate(segments[k], test_start, hr + self.label_leak_n)
|
||||
t["dataset"]["kwargs"]["segments"] = segments
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
import unittest
|
||||
@@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase):
|
||||
# sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
# 2 * sd2
|
||||
|
||||
def test_squeeze(self):
|
||||
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
|
||||
# automatically squeezing
|
||||
self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData))
|
||||
self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData))
|
||||
self.assertTrue(not isinstance(sd1.sum(), idd.IndexData))
|
||||
self.assertEqual(np.nansum(sd1), 10)
|
||||
self.assertEqual(np.sum(sd1), 10)
|
||||
self.assertEqual(sd1.sum(), 10)
|
||||
self.assertEqual(np.nanmean(sd1), 2.5)
|
||||
self.assertEqual(np.mean(sd1), 2.5)
|
||||
self.assertEqual(sd1.mean(), 2.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -47,13 +47,13 @@ def train(uri_path: str = None):
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
pred_score = sr.load("pred.pkl")
|
||||
|
||||
# calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
ic = sar.load("ic.pkl")
|
||||
ric = sar.load("ric.pkl")
|
||||
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
@@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None):
|
||||
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load(sr.get_path("pred.pkl"))
|
||||
pred_score = sr.load("pred.pkl")
|
||||
|
||||
# predict and calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
ic = sar.load("ic.pkl")
|
||||
ric = sar.load("ric.pkl")
|
||||
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
@@ -144,10 +144,9 @@ def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -170,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
|
||||
par.generate()
|
||||
analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
|
||||
analysis_df = par.load("port_analysis_1day.pkl")
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
|
||||
|
||||
Reference in New Issue
Block a user