1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 19:10:58 +08:00
This commit is contained in:
Default
2022-06-20 10:05:25 +08:00
parent 92d4ec4dce
commit 7535d60e99
7 changed files with 171 additions and 8 deletions

View File

@@ -349,4 +349,4 @@ def format_decisions(
return res
__all__ = ["Order", "backtest"]
__all__ = ["Order", "backtest", "BaseExecutor", "CommonInfrastructure"]

View File

@@ -5,3 +5,24 @@
Currently it supports single-asset order execution.
Multi-asset is on the way.
"""
from .interpreter import (
CategoricalActionInterpreter,
CurrentStepStateInterpreter,
FullHistoryStateInterpreter,
TwapRelativeActionInterpreter,
)
from .network import Recurrent
from .policy import PPO, AllOne
from .simulator_simple import SingleAssetOrderExecution
__all__ = [
"CategoricalActionInterpreter",
"CurrentStepStateInterpreter",
"FullHistoryStateInterpreter",
"TwapRelativeActionInterpreter",
"Recurrent",
"PPO",
"AllOne",
"SingleAssetOrderExecution",
]

View File

@@ -2,3 +2,142 @@
# Licensed under the MIT License.
"""Placeholder for qlib-based simulator."""
from dataclasses import dataclass
from pathlib import Path
from plistlib import Dict
from typing import Callable, Generator, List, Optional, Tuple, Union
import pandas as pd
from qlib.backtest import Account, BaseExecutor, CommonInfrastructure, get_exchange
from qlib.backtest.decision import Order
from qlib.backtest.executor import NestedExecutor
from qlib.config import QlibConfig
from qlib.rl.simulator import ActType, Simulator, StateType
from qlib.strategy.base import BaseStrategy
@dataclass
class ExchangeConfig:
limit_threshold: Union[float, Tuple[str, str]]
deal_price: Union[str, Tuple[str, str]]
volume_threshold: Union[float, Dict[str, Tuple[str, str]]]
open_cost: float = 0.0005
close_cost: float = 0.0015
min_cost: float = 5.
trade_unit: Optional[float] = 100.
cash_limit: Optional[Union[Path, float]] = None
generate_report: bool = False
def get_common_infra(
config: ExchangeConfig,
trade_start_time: pd.Timestamp,
trade_end_time: pd.Timestamp,
codes: List[str],
cash_limit: Optional[float] = None,
) -> CommonInfrastructure:
# need to specify a range here for acceleration
if cash_limit is None:
trade_account = Account(
init_cash=int(1e12),
benchmark_config={},
pos_type='InfPosition'
)
else:
trade_account = Account(
init_cash=cash_limit,
benchmark_config={},
pos_type='Position',
position_dict={code: {"amount": 1e12, "price": 1.} for code in codes}
)
exchange = get_exchange(
codes=codes,
freq='1min',
limit_threshold=config.limit_threshold,
deal_price=config.deal_price,
open_cost=config.open_cost,
close_cost=config.close_cost,
min_cost=config.min_cost if config.trade_unit is not None else 0,
start_time=pd.Timestamp(trade_start_time),
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
trade_unit=config.trade_unit,
volume_threshold=config.volume_threshold
)
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
class QlibSimulator(Simulator[Order, StateType, ActType]):
def __init__(
self,
qlib_config: QlibConfig,
time_per_step: str,
top_strategy: BaseStrategy,
inner_strategy_fn: Callable[[], BaseStrategy],
inner_executor_fn: Callable[[CommonInfrastructure], BaseExecutor],
exchange_config: ExchangeConfig,
) -> None:
super(QlibSimulator, self).__init__(
initial=None, # TODO
)
self.qlib_config = qlib_config
self._time_per_step = time_per_step
self._top_strategy = top_strategy
self._inner_executor_fn = inner_executor_fn
self._inner_strategy_fn = inner_strategy_fn
self._exchange_config = exchange_config
self._executor: Optional[NestedExecutor] = None
self._collect_data_loop: Optional[Generator] = None
self._done = False
def _reset(
self,
instrument: str,
date_time: pd.Timestamp,
) -> None:
# TODO: init_qlib
common_infra = get_common_infra(
self._exchange_config,
trade_start_time=date_time,
trade_end_time=date_time,
codes=[instrument],
)
self._executor = NestedExecutor(
time_per_step=self._time_per_step,
inner_executor=self._inner_executor_fn(common_infra),
inner_strategy=self._inner_strategy_fn(),
track_data=True,
)
self._executor.reset(start_time=date_time, end_time=date_time)
self._top_strategy.reset(level_infra=self._executor.get_level_infra())
self._collect_data_loop = self._executor.collect_data(self._top_strategy.generate_trade_decision(), level=0)
assert isinstance(self._collect_data_loop, Generator)
self._done = False
def step(self, action: ActType) -> None:
try:
strategy = self._collect_data_loop.send(action)
while not isinstance(strategy, BaseStrategy):
strategy = self._collect_data_loop.send(action)
assert isinstance(strategy, BaseStrategy)
# TODO: do something here
except StopIteration:
self._done = True
def get_state(self) -> StateType:
pass # TODO: Collect info from executor. Generate state.
def done(self) -> bool:
return self._done

View File

@@ -57,7 +57,7 @@ class SAOEMetrics(TypedDict):
trade_price: float
"""The average deal price for this strategy."""
trade_value: float
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
position: float
"""Position left after this "period"."""

View File

@@ -2,9 +2,9 @@
# Licensed under the MIT License.
from .data_queue import DataQueue
from .env_wrapper import EnvWrapper
from .env_wrapper import EnvWrapper, EnvWrapperStatus
from .finite_env import FiniteEnvType, vectorize_env
from .log import LogCollector, LogLevel, LogWriter
from .log import ConsoleWriter, CsvWriter, LogCollector, LogLevel, LogWriter
__all__ = [
"LogLevel",
@@ -14,4 +14,7 @@ __all__ = [
"LogCollector",
"LogWriter",
"vectorize_env",
"ConsoleWriter",
"CsvWriter",
"EnvWrapperStatus",
]

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations
from abc import abstractmethod
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
@@ -205,7 +205,7 @@ class BaseStrategy:
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
class RLStrategy(BaseStrategy):
class RLStrategy(BaseStrategy, metaclass=ABCMeta):
"""RL-based strategy"""
def __init__(
@@ -226,7 +226,7 @@ class RLStrategy(BaseStrategy):
self.policy = policy
class RLIntStrategy(RLStrategy):
class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
"""(RL)-based (Strategy) with (Int)erpreter"""
def __init__(

View File

@@ -18,7 +18,7 @@ from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.entries.test import backtest
from qlib.rl.order_execution import *
from qlib.rl.order_execution import SingleAssetOrderExecution, FullHistoryStateInterpreter, CurrentStepStateInterpreter, CategoricalActionInterpreter, TwapRelativeActionInterpreter, AllOne, Recurrent, PPO
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")