diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index a97841da7..1babd08c7 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -8,13 +8,13 @@ from .account import Account if TYPE_CHECKING: from ..strategy.base import BaseStrategy + from .executor import BaseExecutor from .position import Position from .exchange import Exchange -from .executor import BaseExecutor from .backtest import backtest_loop from .backtest import collect_data_loop -from .utils import CommonInfrastructure from .order import Order +from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..utils import init_instance_by_config from ..log import get_module_logger from ..config import C @@ -155,6 +155,7 @@ def get_strategy_executor( # - for avoiding recursive import # - typing annotations is not reliable from ..strategy.base import BaseStrategy + from .executor import BaseExecutor trade_account = create_account_instance( start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 773e1a037..542c0fba2 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -88,17 +88,7 @@ class Account: self._pos_type = pos_type self._port_metr_enabled = port_metr_enabled - self.init_vars(init_cash, position_dict, freq, benchmark_config) - def is_port_metr_enabled(self): - """ - Is portfolio-based metrics enabled. - """ - return self._port_metr_enabled and not self.current.skip_update() - - def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict): - - # init cash self.init_cash = init_cash self.current: BasePosition = init_instance_by_config( { @@ -113,8 +103,19 @@ class Account: self.accum_info = AccumulatedInfo() self.report = None self.positions = {} + + # in of reset ignore None values + self.benchmark_config = benchmark_config + self.freq = freq + self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) + def is_port_metr_enabled(self): + """ + Is portfolio-based metrics enabled. + """ + return self._port_metr_enabled and not self.current.skip_update() + def reset_report(self, freq, benchmark_config): # portfolio related metrics if self.is_port_metr_enabled(): diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index bb615dc06..abd02554a 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -15,10 +15,11 @@ if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.utils import TradeCalendarManager import warnings +import numpy as np import pandas as pd import numpy as np from dataclasses import dataclass, field -from typing import ClassVar, Union, List, Set, Tuple +from typing import ClassVar, Optional, Union, List, Set, Tuple class OrderDir(IntEnum): @@ -62,8 +63,8 @@ class Order: # - not tradable: the deal_amount == 0 , factor is None # - the stock is suspended and the entire order fails. No cost for this order # - dealed or partially dealed: deal_amount >= 0 and factor is not None - deal_amount: float = field(init=False) # `deal_amount` is a non-negative value - factor: float = field(init=False) + deal_amount: Optional[float] = None # `deal_amount` is a non-negative value + factor: Optional[float] = None # TODO: # a status field to indicate the dealing result of the order @@ -108,7 +109,7 @@ class Order: return self.direction * 2 - 1 @staticmethod - def parse_dir(direction: Union[str, int, float, np.integer, np.floating, OrderDir]) -> OrderDir: + def parse_dir(direction: Union[str, int, np.integer, OrderDir]) -> OrderDir: if isinstance(direction, OrderDir): return direction elif isinstance(direction, (int, float, np.integer, np.floating)): diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index fb1eeedfa..2d188dd18 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -82,11 +82,12 @@ class Report: def init_bench(self, freq=None, benchmark_config=None): if freq is not None: self.freq = freq - if benchmark_config is not None: - self.benchmark_config = benchmark_config + self.benchmark_config = benchmark_config self.bench = self._cal_benchmark(self.benchmark_config, self.freq) def _cal_benchmark(self, benchmark_config, freq): + if benchmark_config is None: + return None benchmark = benchmark_config.get("benchmark", CSI300_BENCH) if benchmark is None: return None diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index eabbe357b..7f04f444e 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -598,7 +598,7 @@ class FileOrderStrategy(BaseStrategy): Parameters ---------- - file : Union[IO, str, Path] + file : Union[IO, str, Path, pd.DataFrame] this parameters will specify the info of expected orders Here is an example of the content diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index fa21fae5f..c47d2494f 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.backtest.exchange import Exchange -from qlib.backtest.position import BasePosition +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from qlib.backtest.exchange import Exchange + from qlib.backtest.position import BasePosition from typing import List, Tuple, Union from ..model.base import BaseModel