1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

Merge pull request #456 from ultmaster/rl-dummy

Dummy RL example on nested decision framework
This commit is contained in:
you-n-g
2021-07-27 22:58:15 +08:00
committed by GitHub
6 changed files with 29 additions and 21 deletions

View File

@@ -8,13 +8,13 @@ from .account import Account
if TYPE_CHECKING:
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
from .position import Position
from .exchange import Exchange
from .executor import BaseExecutor
from .backtest import backtest_loop
from .backtest import collect_data_loop
from .utils import CommonInfrastructure
from .order import Order
from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
@@ -155,6 +155,7 @@ def get_strategy_executor(
# - for avoiding recursive import
# - typing annotations is not reliable
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
trade_account = create_account_instance(
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type

View File

@@ -88,17 +88,7 @@ class Account:
self._pos_type = pos_type
self._port_metr_enabled = port_metr_enabled
self.init_vars(init_cash, position_dict, freq, benchmark_config)
def is_port_metr_enabled(self):
"""
Is portfolio-based metrics enabled.
"""
return self._port_metr_enabled and not self.current.skip_update()
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
# init cash
self.init_cash = init_cash
self.current: BasePosition = init_instance_by_config(
{
@@ -113,8 +103,19 @@ class Account:
self.accum_info = AccumulatedInfo()
self.report = None
self.positions = {}
# in of reset ignore None values
self.benchmark_config = benchmark_config
self.freq = freq
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
def is_port_metr_enabled(self):
"""
Is portfolio-based metrics enabled.
"""
return self._port_metr_enabled and not self.current.skip_update()
def reset_report(self, freq, benchmark_config):
# portfolio related metrics
if self.is_port_metr_enabled():

View File

@@ -15,10 +15,11 @@ if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager
import warnings
import numpy as np
import pandas as pd
import numpy as np
from dataclasses import dataclass, field
from typing import ClassVar, Union, List, Set, Tuple
from typing import ClassVar, Optional, Union, List, Set, Tuple
class OrderDir(IntEnum):
@@ -62,8 +63,8 @@ class Order:
# - not tradable: the deal_amount == 0 , factor is None
# - the stock is suspended and the entire order fails. No cost for this order
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
deal_amount: float = field(init=False) # `deal_amount` is a non-negative value
factor: float = field(init=False)
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
factor: Optional[float] = None
# TODO:
# a status field to indicate the dealing result of the order
@@ -108,7 +109,7 @@ class Order:
return self.direction * 2 - 1
@staticmethod
def parse_dir(direction: Union[str, int, float, np.integer, np.floating, OrderDir]) -> OrderDir:
def parse_dir(direction: Union[str, int, np.integer, OrderDir]) -> OrderDir:
if isinstance(direction, OrderDir):
return direction
elif isinstance(direction, (int, float, np.integer, np.floating)):

View File

@@ -82,11 +82,12 @@ class Report:
def init_bench(self, freq=None, benchmark_config=None):
if freq is not None:
self.freq = freq
if benchmark_config is not None:
self.benchmark_config = benchmark_config
self.benchmark_config = benchmark_config
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
def _cal_benchmark(self, benchmark_config, freq):
if benchmark_config is None:
return None
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
if benchmark is None:
return None

View File

@@ -598,7 +598,7 @@ class FileOrderStrategy(BaseStrategy):
Parameters
----------
file : Union[IO, str, Path]
file : Union[IO, str, Path, pd.DataFrame]
this parameters will specify the info of expected orders
Here is an example of the content

View File

@@ -1,7 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import List, Tuple, Union
from ..model.base import BaseModel