From ca14e36f7a5d9e281939d58490c23a8ae063ddd0 Mon Sep 17 00:00:00 2001 From: "wangwenxi.handsome" Date: Tue, 13 Jul 2021 20:54:58 +0800 Subject: [PATCH] initial account by position --- qlib/backtest/__init__.py | 24 ++++++++++++++++++------ qlib/backtest/account.py | 10 +++++++--- qlib/backtest/position.py | 4 ++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 99e5b8790..a171ef81e 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -8,6 +8,7 @@ from .account import Account if TYPE_CHECKING: from ..strategy.base import BaseStrategy +from .position import Position from .exchange import Exchange from .executor import BaseExecutor from .backtest import backtest_loop @@ -95,7 +96,7 @@ def get_exchange( def create_account_instance( - start_time, end_time, benchmark: str, account: float, pos_type: str = "Position" + start_time, end_time, benchmark: str, account: Union[float, int, Position], pos_type: str = "Position" ) -> Account: """ # TODO: is very strange pass benchmark_config in the account(maybe for report) @@ -109,13 +110,23 @@ def create_account_instance( end time of the benchmark benchmark : str the benchmark for reporting - account : Union[float, str] + account : Union[float, int, Position] information for describing how to creating the account - For `float` - Using Account with a normal position - For `str`: - Using account with a specific Position + For `float` or `int`: + Using Account with only initial cash + For `Position`: + Using Account with a Position """ + if(type(account) in (int, float)): + pos_kwargs = {"init_cash": account} + elif(type(account) is Position): + pos_kwargs = { + "init_cash": account.position["cash"], + "position_dict": account.position, + } + else: + raise ValueError("account must be in (int, float, Position)") + kwargs = { "init_cash": account, "benchmark_config": { @@ -125,6 +136,7 @@ def create_account_instance( }, "pos_type": pos_type, } + kwargs.update(pos_kwargs) return Account(**kwargs) diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 3ef1cdd03..fee0a98c4 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -67,6 +67,7 @@ class Account: def __init__( self, init_cash: float = 1e9, + position_dict: dict = {}, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position", @@ -74,7 +75,7 @@ class Account: ): self._pos_type = pos_type self._port_metr_enabled = port_metr_enabled - self.init_vars(init_cash, freq, benchmark_config) + self.init_vars(init_cash, position_dict, freq, benchmark_config) def is_port_metr_enabled(self): """ @@ -82,14 +83,17 @@ class Account: """ return self._port_metr_enabled and not self.current.skip_update() - def init_vars(self, init_cash, freq: str, benchmark_config: dict): + def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict): # init cash self.init_cash = init_cash self.current: BasePosition = init_instance_by_config( { "class": self._pos_type, - "kwargs": {"cash": init_cash}, + "kwargs": { + "cash": init_cash, + "position_dict": position_dict, + }, "module_path": "qlib.backtest.position", } ) diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index b1037d460..7c32edc81 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -199,13 +199,13 @@ class Position(BasePosition): } """ - def __init__(self, cash=0, position_dict={}, now_account_value=0): + def __init__(self, cash=0, position_dict={}): # NOTE: The position dict must be copied!!! # Otherwise the initial value self.init_cash = cash self.position = position_dict.copy() self.position["cash"] = cash - self.position["now_account_value"] = now_account_value + self.position["now_account_value"] = self.calculate_value() def _init_stock(self, stock_id, amount, price=None): """