1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

initial account by position

This commit is contained in:
wangwenxi.handsome
2021-07-13 20:54:58 +08:00
committed by you-n-g
parent 9b38e62f21
commit ca14e36f7a
3 changed files with 27 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ from .account import Account
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .position import Position
from .exchange import Exchange
from .executor import BaseExecutor
from .backtest import backtest_loop
@@ -95,7 +96,7 @@ def get_exchange(
def create_account_instance(
start_time, end_time, benchmark: str, account: float, pos_type: str = "Position"
start_time, end_time, benchmark: str, account: Union[float, int, Position], pos_type: str = "Position"
) -> Account:
"""
# TODO: is very strange pass benchmark_config in the account(maybe for report)
@@ -109,13 +110,23 @@ def create_account_instance(
end time of the benchmark
benchmark : str
the benchmark for reporting
account : Union[float, str]
account : Union[float, int, Position]
information for describing how to creating the account
For `float`
Using Account with a normal position
For `str`:
Using account with a specific Position
For `float` or `int`:
Using Account with only initial cash
For `Position`:
Using Account with a Position
"""
if(type(account) in (int, float)):
pos_kwargs = {"init_cash": account}
elif(type(account) is Position):
pos_kwargs = {
"init_cash": account.position["cash"],
"position_dict": account.position,
}
else:
raise ValueError("account must be in (int, float, Position)")
kwargs = {
"init_cash": account,
"benchmark_config": {
@@ -125,6 +136,7 @@ def create_account_instance(
},
"pos_type": pos_type,
}
kwargs.update(pos_kwargs)
return Account(**kwargs)

View File

@@ -67,6 +67,7 @@ class Account:
def __init__(
self,
init_cash: float = 1e9,
position_dict: dict = {},
freq: str = "day",
benchmark_config: dict = {},
pos_type: str = "Position",
@@ -74,7 +75,7 @@ class Account:
):
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.init_vars(init_cash, freq, benchmark_config)
self.init_vars(init_cash, position_dict, freq, benchmark_config)
def is_port_metr_enabled(self):
"""
@@ -82,14 +83,17 @@ class Account:
"""
return self._port_metr_enabled and not self.current.skip_update()
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
# init cash
self.init_cash = init_cash
self.current: BasePosition = init_instance_by_config(
{
"class": self._pos_type,
"kwargs": {"cash": init_cash},
"kwargs": {
"cash": init_cash,
"position_dict": position_dict,
},
"module_path": "qlib.backtest.position",
}
)

View File

@@ -199,13 +199,13 @@ class Position(BasePosition):
}
"""
def __init__(self, cash=0, position_dict={}, now_account_value=0):
def __init__(self, cash=0, position_dict={}):
# NOTE: The position dict must be copied!!!
# Otherwise the initial value
self.init_cash = cash
self.position = position_dict.copy()
self.position["cash"] = cash
self.position["now_account_value"] = now_account_value
self.position["now_account_value"] = self.calculate_value()
def _init_stock(self, stock_id, amount, price=None):
"""