mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Merge pull request #456 from ultmaster/rl-dummy
Dummy RL example on nested decision framework
This commit is contained in:
@@ -8,13 +8,13 @@ from .account import Account
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure
|
||||
from .order import Order
|
||||
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
@@ -155,6 +155,7 @@ def get_strategy_executor(
|
||||
# - for avoiding recursive import
|
||||
# - typing annotations is not reliable
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
|
||||
@@ -88,17 +88,7 @@ class Account:
|
||||
|
||||
self._pos_type = pos_type
|
||||
self._port_metr_enabled = port_metr_enabled
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current: BasePosition = init_instance_by_config(
|
||||
{
|
||||
@@ -113,8 +103,19 @@ class Account:
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.report = None
|
||||
self.positions = {}
|
||||
|
||||
# in of reset ignore None values
|
||||
self.benchmark_config = benchmark_config
|
||||
self.freq = freq
|
||||
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
|
||||
@@ -15,10 +15,11 @@ if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Union, List, Set, Tuple
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
@@ -62,8 +63,8 @@ class Order:
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
|
||||
deal_amount: float = field(init=False) # `deal_amount` is a non-negative value
|
||||
factor: float = field(init=False)
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
# TODO:
|
||||
# a status field to indicate the dealing result of the order
|
||||
@@ -108,7 +109,7 @@ class Order:
|
||||
return self.direction * 2 - 1
|
||||
|
||||
@staticmethod
|
||||
def parse_dir(direction: Union[str, int, float, np.integer, np.floating, OrderDir]) -> OrderDir:
|
||||
def parse_dir(direction: Union[str, int, np.integer, OrderDir]) -> OrderDir:
|
||||
if isinstance(direction, OrderDir):
|
||||
return direction
|
||||
elif isinstance(direction, (int, float, np.integer, np.floating)):
|
||||
|
||||
@@ -82,11 +82,12 @@ class Report:
|
||||
def init_bench(self, freq=None, benchmark_config=None):
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
if benchmark_config is not None:
|
||||
self.benchmark_config = benchmark_config
|
||||
self.benchmark_config = benchmark_config
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
if benchmark_config is None:
|
||||
return None
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if benchmark is None:
|
||||
return None
|
||||
|
||||
@@ -598,7 +598,7 @@ class FileOrderStrategy(BaseStrategy):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file : Union[IO, str, Path]
|
||||
file : Union[IO, str, Path, pd.DataFrame]
|
||||
this parameters will specify the info of expected orders
|
||||
|
||||
Here is an example of the content
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
|
||||
Reference in New Issue
Block a user