mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
add sample & base class
This commit is contained in:
@@ -6,6 +6,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import time
|
||||
import queue
|
||||
@@ -24,7 +25,7 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path, sample_calendar
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -55,7 +56,7 @@ class CalendarProvider(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
|
||||
|
||||
def locate_index(self, start_time, end_time, freq, future):
|
||||
def locate_index(self, start_time, end_time, freq, freq_sam=None, future=False):
|
||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||
|
||||
Parameters
|
||||
@@ -82,7 +83,7 @@ class CalendarProvider(abc.ABC):
|
||||
"""
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
calendar, calendar_index = self._get_calendar(freq=freq, future=future)
|
||||
calendar, calendar_index = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
|
||||
if start_time not in calendar_index:
|
||||
try:
|
||||
start_time = calendar[bisect.bisect_left(calendar, start_time)]
|
||||
@@ -96,7 +97,7 @@ class CalendarProvider(abc.ABC):
|
||||
end_index = calendar_index[end_time]
|
||||
return start_time, end_time, start_index, end_index
|
||||
|
||||
def _get_calendar(self, freq, future):
|
||||
def _get_calendar(self, freq, freq_sam=None, future=False):
|
||||
"""Load calendar using memcache.
|
||||
|
||||
Parameters
|
||||
@@ -113,14 +114,21 @@ class CalendarProvider(abc.ABC):
|
||||
dict
|
||||
dict composed by timestamp as key and index as value for fast search.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}"
|
||||
flag = f"{freq}_future_{future}_sam_{freq_sam}"
|
||||
if flag in H["c"]:
|
||||
_calendar, _calendar_index = H["c"][flag]
|
||||
else:
|
||||
flag_raw = f"{freq}_future_{future}_sam_{None}"
|
||||
_calendar = np.array(self.load_calendar(freq, future))
|
||||
_calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search
|
||||
H["c"][flag] = _calendar, _calendar_index
|
||||
return _calendar, _calendar_index
|
||||
H["c"][flag_raw] = _calendar, _calendar_index
|
||||
if freq_sam is None:
|
||||
return _calendar, _calendar_index
|
||||
else:
|
||||
_calendar_sam = sample_calendar(_calendar, freq, freq_sam)
|
||||
_calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)}
|
||||
H["c"][flag] = _calendar_sam, _calendar_sam_index
|
||||
return _calendar_sam, _calendar_sam_index
|
||||
|
||||
def _uri(self, start_time, end_time, freq, future=False):
|
||||
"""Get the uri of calendar generation task."""
|
||||
@@ -530,12 +538,13 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
with open(fname) as f:
|
||||
return [pd.Timestamp(x.strip()) for x in f]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
_calendar, _calendar_index = self._get_calendar(freq, future)
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False, freq_sam=None):
|
||||
_calendar, _ = self._get_calendar(freq=freq, future=future)
|
||||
if start_time == "None":
|
||||
start_time = None
|
||||
if end_time == "None":
|
||||
end_time = None
|
||||
|
||||
# strip
|
||||
if start_time:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
@@ -549,8 +558,15 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
return np.array([])
|
||||
else:
|
||||
end_time = _calendar[-1]
|
||||
_, _, si, ei = self.locate_index(start_time, end_time, freq, future)
|
||||
return _calendar[si : ei + 1]
|
||||
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, future=future)
|
||||
_calendar = _calendar[si : ei + 1]
|
||||
if freq_sam is None:
|
||||
return _calendar
|
||||
else:
|
||||
_calendar_sam, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
|
||||
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
|
||||
if bisect.bisect(_calendar, st, 0, len(_calendar)):
|
||||
return np.hstack()
|
||||
|
||||
|
||||
class LocalInstrumentProvider(InstrumentProvider):
|
||||
@@ -658,7 +674,7 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False)
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
# Ensure that each column type is consistent
|
||||
|
||||
9
qlib/strategy/__init__.py
Normal file
9
qlib/strategy/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .strategy import (
|
||||
TopkDropoutStrategy,
|
||||
BaseStrategy,
|
||||
WeightStrategyBase,
|
||||
)
|
||||
73
qlib/strategy/cost_control.py
Normal file
73
qlib/strategy/cost_control.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from .strategy import StrategyWrapper, WeightStrategyBase
|
||||
import copy
|
||||
|
||||
|
||||
class SoftTopkStrategy(WeightStrategyBase):
    """Top-k weight strategy that moves towards the target gradually.

    Instead of liquidating every stock that drops out of the top-k at once,
    at most ``max_sold_weight`` of weight is removed from each such stock per
    adjustment, and the released weight is redistributed to the top-k stocks.
    """

    def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"):
        """
        Parameters
        ----------
        topk : int
            top-N stocks to buy
        max_sold_weight : float
            maximum weight removed from a single held stock that is no longer
            in the top-k during one adjustment
        risk_degree : float
            position percentage of total value
        buy_method : str
            first_fill: assign the released weight to the stocks that rank
            high first (each stock is filled up to 1/topk at most)
            average_fill: split the released weight evenly among the top-k stocks
        """
        super().__init__()
        self.topk = topk
        self.max_sold_weight = max_sold_weight
        self.risk_degree = risk_degree
        self.buy_method = buy_method

    def get_risk_degree(self, date):
        """Return the proportion of the total value that is used for investment.

        A dynamic risk_degree would result in market timing; this
        implementation returns the constant configured at construction.
        """
        return self.risk_degree

    def generate_target_weight_position(self, score, current, trade_date):
        """Generate the target weight position from the scores and the current position.

        Parameters
        ----------
        score : pd.Series
            prediction score for this trade date, indexed by stock_id
        current : Position
            current position
        trade_date :
            trade date (not used by this implementation)

        Returns
        -------
        dict
            {stock_id: target weight}. The cash is not considered in the position.

        Raises
        ------
        ValueError
            if there are current holdings and ``buy_method`` is unknown
        """
        # TODO:
        # If the current stock list is more than topk(eg. The weights are modified
        # by risk control), the weight will not be handled correctly.

        # Keep the ranked order for filling and a set for membership tests.
        # Iterating a plain set would make the "first_fill" order arbitrary.
        ranked_stocks = list(score.sort_values(ascending=False).iloc[: self.topk].index)
        buy_signal_stocks = set(ranked_stocks)
        cur_stock_weight = current.get_stock_weight_dict(only_stock=True)

        if len(cur_stock_weight) == 0:
            # Nothing is held yet: spread the weight evenly over the top-k stocks.
            return {code: 1 / self.topk for code in ranked_stocks}

        final_stock_weight = copy.deepcopy(cur_stock_weight)
        sold_stock_weight = 0.0
        for stock_id in final_stock_weight:
            if stock_id not in buy_signal_stocks:
                # Sell at most `max_sold_weight` of each stock outside the top-k.
                sw = min(self.max_sold_weight, final_stock_weight[stock_id])
                sold_stock_weight += sw
                final_stock_weight[stock_id] -= sw

        if self.buy_method == "first_fill":
            # Highest ranked stocks are topped up to 1/topk first.
            for stock_id in ranked_stocks:
                add_weight = min(
                    max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0),
                    sold_stock_weight,
                )
                final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight
                sold_stock_weight -= add_weight
        elif self.buy_method == "average_fill":
            share = sold_stock_weight / len(buy_signal_stocks)
            for stock_id in ranked_stocks:
                final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + share
        else:
            raise ValueError("Buy method not found")
        return final_stock_weight
|
||||
171
qlib/strategy/order_generator.py
Normal file
171
qlib/strategy/order_generator.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ..backtest.position import Position
|
||||
from ..backtest.exchange import Exchange
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
|
||||
class OrderGenerator:
    """Interface for turning a target weight position into a list of orders."""

    def generate_order_list_from_target_weight_position(
        self,
        current: Position,
        trade_exchange: Exchange,
        target_weight_position: dict,
        risk_degree: float,
        pred_date: pd.Timestamp,
        trade_date: pd.Timestamp,
    ) -> list:
        """Build the order list that moves `current` towards the target weights.

        Subclasses must override this method.

        :param current: The current position
        :type current: Position
        :param trade_exchange: the exchange used for trading
        :type trade_exchange: Exchange
        :param target_weight_position: {stock_id : weight}
        :type target_weight_position: dict
        :param risk_degree: proportion of the total value to invest
        :type risk_degree: float
        :param pred_date: the date the score is predicted
        :type pred_date: pd.Timestamp
        :param trade_date: the date the stock is traded
        :type trade_date: pd.Timestamp

        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
class OrderGenWInteract(OrderGenerator):
    """Order Generator With Interact"""

    def generate_order_list_from_target_weight_position(
        self,
        current: Position,
        trade_exchange: Exchange,
        target_weight_position: dict,
        risk_degree: float,
        pred_date: pd.Timestamp,
        trade_date: pd.Timestamp,
    ) -> list:
        """Generate orders using the exchange information at the trade date.

        No adjustment is made for the non-tradable shares: all the tradable
        value is assigned to the tradable stocks according to the weights.

        :param current: the current position
        :type current: Position
        :param trade_exchange: exchange queried for values, tradability and costs
        :type trade_exchange: Exchange
        :param target_weight_position: {stock_id : weight}
        :type target_weight_position: dict
        :param risk_degree: proportion of the total value to invest
        :type risk_degree: float
        :param pred_date: not used by this generator
        :type pred_date: pd.Timestamp
        :param trade_date: the date the stock is traded
        :type trade_date: pd.Timestamp

        :rtype: list
        """
        holding_amounts = current.get_stock_amount_dict()
        total_stock_value = trade_exchange.calculate_amount_position_value(
            amount_dict=holding_amounts, trade_date=trade_date, only_tradable=False
        )
        tradable_stock_value = trade_exchange.calculate_amount_position_value(
            amount_dict=holding_amounts, trade_date=trade_date, only_tradable=True
        )
        cash = current.get_cash()

        # Cash that must stay uninvested, measured against the whole portfolio.
        reserved_cash = (1.0 - risk_degree) * (total_stock_value + cash)
        tradable_value = tradable_stock_value + cash - reserved_cash

        if tradable_value < 0:
            # Even selling every tradable stock cannot meet the reserved
            # value, so just sell all of them and keep the rest untouched.
            target_amounts = copy.deepcopy(holding_amounts)
            sellable = [
                stock_id for stock_id in target_amounts if trade_exchange.is_stock_tradable(stock_id, trade_date)
            ]
            for stock_id in sellable:
                del target_amounts[stock_id]
        else:
            # Leave room for the transaction cost.
            tradable_value = tradable_value / (1 + max(trade_exchange.close_cost, trade_exchange.open_cost))

            # strategy 1 : generate amount_position by weight_position
            # Use API in Exchange()
            target_amounts = trade_exchange.generate_amount_position_from_weight_position(
                weight_position=target_weight_position,
                cash=tradable_value,
                trade_date=trade_date,
            )

        return trade_exchange.generate_order_for_target_amount_position(
            target_position=target_amounts,
            current_position=holding_amounts,
            trade_date=trade_date,
        )
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
    """Order Generator Without Interact"""

    def generate_order_list_from_target_weight_position(
        self,
        current: Position,
        trade_exchange: Exchange,
        target_weight_position: dict,
        risk_degree: float,
        pred_date: pd.Timestamp,
        trade_date: pd.Timestamp,
    ) -> list:
        """Generate orders without using trade-date information.

        Turning weights into amounts needs a price for each stock, but the
        trade-date price (and tradability) is not available without
        interacting with the exchange. The close price at `pred_date` is
        used when the stock is tradable on that date, otherwise the price
        recorded in the current position is used.

        :param current: the current position
        :type current: Position
        :param trade_exchange: exchange queried for pred-date tradability and close price
        :type trade_exchange: Exchange
        :param target_weight_position: {stock_id : weight}
        :type target_weight_position: dict
        :param risk_degree: proportion of the total value to invest
        :type risk_degree: float
        :param pred_date: the date the score is predicted
        :type pred_date: pd.Timestamp
        :param trade_date: the date the stock is traded
        :type trade_date: pd.Timestamp

        :rtype: list
        """
        investable_value = risk_degree * current.calculate_value()
        held_stocks = current.get_stock_list()

        target_amounts = {}
        for stock_id, weight in target_weight_position.items():
            if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date):
                reference_price = trade_exchange.get_close(stock_id, pred_date)
            elif stock_id in held_stocks:
                reference_price = current.get_stock_price(stock_id)
            else:
                # Stocks that are neither held nor tradable at the predict date are ignored.
                continue
            target_amounts[stock_id] = investable_value * weight / reference_price

        return trade_exchange.generate_order_for_target_amount_position(
            target_position=target_amounts,
            current_position=current.get_stock_amount_dict(),
            trade_date=trade_date,
        )
|
||||
304
qlib/strategy/strategy.py
Normal file
304
qlib/strategy/strategy.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
"""
|
||||
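1. The granularity (bar) of BaseStrategy must be an integer multiple of the data granularity.
- How should the calendars be merged?
- What is `adjust_dates` actually used for?
- The label, the data freq and the strategy bar are decoupled; how should we decide between them?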
|
||||
"""
|
||||
class BaseStrategy:
    """Common state shared by all strategies: bar size and time range."""

    def __init__(self, bar, start_time, end_time):
        """Store the bar and the time range; the clock starts at `start_time`.

        Parameters
        ----------
        bar :
            granularity at which the strategy acts
        start_time :
            first time of the strategy
        end_time :
            last time of the strategy
        """
        self.bar = bar
        self.start_time, self.end_time = start_time, end_time
        self.current_time = start_time

    def generate_action(self, current):
        """Produce the action for `current`; this base version does nothing and returns None."""
        return None
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
    """Placeholder for rule-based strategies; adds nothing to BaseStrategy yet."""
|
||||
|
||||
class DLStrategy(BaseStrategy):
    """Strategy driven by the prediction scores of a model."""

    def __init__(self, bar, model, dataset: DatasetH, start_time=None, end_time=None):
        """Predict the scores for the whole dataset and prepare the prediction dates.

        Parameters
        ----------
        bar :
            granularity of the strategy, also used as the calendar frequency
        model :
            object exposing ``predict(dataset)``; the result must be indexed
            by a level named "datetime"
        dataset : DatasetH
            dataset passed to ``model.predict``
        start_time :
            start of the strategy; defaults to the earliest prediction time
        end_time :
            end of the strategy; defaults to the latest prediction time
        """
        # `D` is not imported at module level in this file, so it is imported
        # here to keep this block self-contained.
        from ..data import D

        super(DLStrategy, self).__init__(bar, start_time, end_time)
        self.model = model
        self.dataset = dataset
        self.pred_score_all = self.model.predict(dataset)
        self.pred_score = None
        # The original read an undefined name `pred`; the prediction result
        # computed just above is the only object it can refer to.
        _pred_dates = self.pred_score_all.index.get_level_values(level="datetime")
        self.start_time = _pred_dates.min() if start_time is None else start_time
        self.end_time = _pred_dates.max() if end_time is None else end_time
        self.pred_date = [
            pd.Timestamp(self.start_time),
            *D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max(), freq=bar),
            self.end_time,
        ]
        self.current_index = -1
        self.pred_length = len(self.pred_date)

    def _update_pred_score(self):
        """update pred score

        Not implemented yet: `self.pred_score` is left unchanged.
        """
        pass
|
||||
|
||||
class AdjustTimer:
    """Decides when positions may be adjusted.

    It is meant to be mixed into a strategy through multiple inheritance:

    - `is_adjust` may need access to the internal state of a strategy.
    - it can be regarded as an enhancement to an existing strategy.
    """

    def is_adjust(self, trade_date):
        """Tell whether the strategy can adjust positions on `trade_date`.

        This default allows adjusting on every trade date. It is normally
        used by strategies that trade with a trade frequency.
        """
        return True
|
||||
|
||||
|
||||
class ListAdjustTimer(AdjustTimer):
    """AdjustTimer that only allows adjusting on an explicit set of dates."""

    def __init__(self, adjust_dates=None):
        """
        :param adjust_dates: an iterable of dates on which adjusting is allowed;
            None means every date is OK for adjusting
        """
        self.adjust_dates = None if adjust_dates is None else {pd.Timestamp(dt) for dt in adjust_dates}

    def is_adjust(self, trade_date):
        """Return True when no dates were configured or `trade_date` is one of them."""
        return self.adjust_dates is None or pd.Timestamp(trade_date) in self.adjust_dates
|
||||
|
||||
class TopkDropoutStrategy(DLStrategy, ListAdjustTimer):
    """Hold about `topk` stocks and replace at most `n_drop` of them per adjustment."""

    def __init__(
        self,
        bar,
        model,
        dataset,
        trade_exchange,
        topk,
        n_drop,
        start_time=None,
        end_time=None,
        method_sell="bottom",
        method_buy="top",
        risk_degree=0.95,
        thresh=1,
        hold_thresh=1,
        only_tradable=False,
        **kwargs,
    ):
        """
        Parameters
        -----------
        bar, model, dataset, start_time, end_time :
            forwarded to DLStrategy.__init__.
        trade_exchange :
            exchange used to check tradability, get prices/factors and deal orders.
        topk : int
            the number of stocks in the portfolio.
        n_drop : int
            number of stocks to be replaced in each trading date.
        method_sell : str
            dropout method_sell, random/bottom.
        method_buy : str
            dropout method_buy, random/top.
        risk_degree : float
            position percentage of total value.
        thresh : int
            minimum holding days since the last buy signal of the stock.
        hold_thresh : int
            minimum holding days
            before selling a stock, current.get_stock_count(code) >= self.hold_thresh is checked.
        only_tradable : bool
            whether the strategy only considers the tradable stocks when choosing what to buy and sell.
            if only_tradable:
                candidates are filtered with trade_exchange.is_stock_tradable before being selected.
            else:
                candidates are selected without the tradable info
                (orders are still skipped later for stocks that are not tradable).
        **kwargs :
            only "adjust_dates" is read; it is passed to ListAdjustTimer.__init__.
        """
        super(TopkDropoutStrategy, self).__init__(bar, model, dataset, start_time, end_time)
        ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
        self.trade_exchange = trade_exchange
        self.topk = topk
        self.n_drop = n_drop
        self.method_sell = method_sell
        self.method_buy = method_buy
        self.risk_degree = risk_degree
        self.thresh = thresh
        # self.stock_count['code'] will be the days the stock has been hold
        # since last buy signal. This is designed for thresh
        self.stock_count = {}

        self.hold_thresh = hold_thresh
        self.only_tradable = only_tradable

    def get_risk_degree(self, date):
        """get_risk_degree
        Return the proportion of your total value you will used in investment.
        Dynamically risk_degree will result in Market timing.
        """
        # It will use 95% amount of your total value by default
        return self.risk_degree

    def generate_action(self, current):
        """Generate the sell and buy orders for the current step.

        :param current: current position; it is deep-copied, and the sell
            orders are dealt on the copy to estimate the cash available for buying.
        :return: list of sell orders followed by buy orders; an empty list
            when `is_adjust` rejects the date.
        """
        # NOTE(review): `trade_date` is used throughout this method but is never
        # defined here (it is neither a parameter nor a local), so the method
        # raises NameError as written. Presumably it should be derived from
        # `self.pred_date[self.current_index]` - TODO confirm.
        # NOTE(review): `self.pred_score` is set to None in DLStrategy.__init__ and
        # `_update_pred_score` is a no-op in this file, so the score lookups below
        # need `pred_score` to be populated elsewhere - TODO confirm.

        self.current_index += 1

        if not self.is_adjust(trade_date):
            return []

        if self.only_tradable:
            # If The strategy only consider tradable stock when make decision
            # It needs following actions to filter stocks
            def get_first_n(l, n, reverse=False):
                # Collect up to n tradable stocks, scanning from the front
                # (or from the back when reverse=True); original order is kept.
                cur_n = 0
                res = []
                for si in reversed(l) if reverse else l:
                    if self.trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date):
                        res.append(si)
                        cur_n += 1
                        if cur_n >= n:
                            break
                return res[::-1] if reverse else res

            def get_last_n(l, n):
                return get_first_n(l, n, reverse=True)

            def filter_stock(l):
                return [si for si in l if self.trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date)]

        else:
            # Otherwise, the strategy will make decision without the stock tradable info
            def get_first_n(l, n):
                return list(l)[:n]

            def get_last_n(l, n):
                return list(l)[-n:]

            def filter_stock(l):
                return l

        current_temp = copy.deepcopy(current)
        # generate order list for this adjust date
        sell_order_list = []
        buy_order_list = []
        # load score
        cash = current_temp.get_cash()
        current_stock_list = current_temp.get_stock_list()
        # last position (sorted by score)
        last = self.pred_score.reindex(current_stock_list).sort_values(ascending=False).index
        # The new stocks today want to buy **at most**
        if self.method_buy == "top":
            today = get_first_n(
                self.pred_score[~self.pred_score.index.isin(last)].sort_values(ascending=False).index,
                self.n_drop + self.topk - len(last),
            )
        elif self.method_buy == "random":
            topk_candi = get_first_n(self.pred_score.sort_values(ascending=False).index, self.topk)
            candi = list(filter(lambda x: x not in last, topk_candi))
            n = self.n_drop + self.topk - len(last)
            try:
                today = np.random.choice(candi, n, replace=False)
            except ValueError:
                # np.random.choice failed (e.g. fewer candidates than n): take them all
                today = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")
        # combine(new stocks + last stocks), we will drop stocks from this list
        # In case of dropping higher score stock and buying lower score stock.
        comb = self.pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index

        # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
        if self.method_sell == "bottom":
            sell = last[last.isin(get_last_n(comb, self.n_drop))]
        elif self.method_sell == "random":
            candi = filter_stock(last)
            try:
                sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
            except ValueError:  # No enough candidates
                sell = candi
        else:
            raise NotImplementedError(f"This type of input is not supported")

        # Get the stock list we really want to buy
        buy = today[: len(sell) + self.topk - len(last)]

        # buy signal: if a stock falls into topk, it appears in the buy_signal
        buy_signal = self.pred_score.sort_values(ascending=False).iloc[: self.topk].index

        for code in current_stock_list:
            if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
                continue
            if code in sell:
                # check hold limit
                if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh:
                    # can not sell this code
                    # no buy signal, but the stock is kept
                    self.stock_count[code] += 1
                    continue
                # sell order
                sell_amount = current_temp.get_stock_amount(code=code)
                sell_order = Order(
                    stock_id=code,
                    amount=sell_amount,
                    trade_date=trade_date,
                    direction=Order.SELL,  # 0 for sell, 1 for buy
                    factor=self.trade_exchange.get_factor(code, trade_date),
                )
                # is order executable
                if self.trade_exchange.check_order(sell_order):
                    sell_order_list.append(sell_order)
                    trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(sell_order, position=current_temp)
                    # update cash
                    cash += trade_val - trade_cost
                    # sold
                    del self.stock_count[code]
                else:
                    # no buy signal, but the stock is kept
                    self.stock_count[code] += 1
            elif code in buy_signal:
                # NOTE: This is different from the original version
                # get new buy signal
                # Only the stock fall in to topk will produce buy signal
                self.stock_count[code] = 1
            else:
                self.stock_count[code] += 1
        # buy new stock
        # note the current has been changed
        current_stock_list = current_temp.get_stock_list()
        # the investable cash is split evenly among the stocks to buy
        value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0

        # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
        # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
        # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit
        for code in buy:
            # check is stock suspended
            if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
                continue
            # buy order
            buy_price = self.trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date)
            buy_amount = value / buy_price
            factor = self.trade_exchange.quote[(code, trade_date)]["$factor"]
            buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor)
            buy_order = Order(
                stock_id=code,
                amount=buy_amount,
                trade_date=trade_date,
                direction=Order.BUY,  # 1 for buy
                factor=factor,
            )
            buy_order_list.append(buy_order)
            self.stock_count[code] = 1
        return sell_order_list + buy_order_list
|
||||
@@ -799,3 +799,123 @@ def fname_to_code(fname: str):
|
||||
if fname.startswith(prefix):
|
||||
fname = fname.lstrip(prefix)
|
||||
return fname
|
||||
|
||||
########################## Sample ############################
|
||||
def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
    """Sample a raw calendar to a coarser frequency, aligned to the end of each period.

    Parameters
    ----------
    calendar_raw : iterable of pd.Timestamp
        the raw calendar
    freq_raw : str
        frequency of the raw calendar, "min" or "day" (an integer prefix is allowed)
    freq_sam : str
        sample frequency: xmin/xminute, xd/xday, xw/xweek or xm/xmonth;
        a missing integer prefix means 1

    Returns
    -------
    np.ndarray
        sorted unique sampled timestamps; minute results carry second 59 and
        day/week/month results carry time 23:59:59

    Raises
    ------
    ValueError
        on an unsupported frequency, a raw frequency coarser than the sample
        frequency, or a minute timestamp outside the supported hours
    """
    digit_prefix = "^[0-9]"
    if re.match(digit_prefix, freq_raw) is None:
        freq_raw = "1" + freq_raw
    if re.match(digit_prefix, freq_sam) is None:
        freq_sam = "1" + freq_sam

    minute_suffixes = ("minute", "min")
    if freq_sam.endswith(minute_suffixes):
        sam_minutes = int(freq_sam[:-3] if freq_sam.endswith("min") else freq_sam[:-6])
        if not freq_raw.endswith(minute_suffixes):
            raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
        raw_minutes = int(freq_raw[:-3] if freq_raw.endswith("min") else freq_raw[:-6])
        if raw_minutes > sam_minutes:
            raise ValueError("raw freq must be higher than sample freq")

        def _slot_end(ts):
            # minutes_left counts down to 15:00; the span before 11:30 is
            # shifted by 120 so that both hour ranges form one countdown.
            if 9 <= ts.hour <= 11:
                minutes_left = (11 - ts.hour) * 60 + 30 - ts.minute + 120
            elif 13 <= ts.hour <= 15:
                minutes_left = (15 - ts.hour) * 60 - ts.minute
            else:
                raise ValueError("calendar hour must be in [9, 11] or [13, 15]")

            minutes_left = minutes_left // sam_minutes * sam_minutes

            if 0 <= minutes_left < 120:
                return 15 - (minutes_left + 59) // 60, (120 - minutes_left) % 60
            if 120 <= minutes_left < 240:
                return 11 - (minutes_left - 120 + 29) // 60, (240 - minutes_left + 30) % 60
            raise ValueError("calendar minute_index error")

        stamps = [pd.Timestamp(ts.year, ts.month, ts.day, *_slot_end(ts), 59) for ts in calendar_raw]
        return np.unique(stamps)

    days = np.unique([pd.Timestamp(ts.year, ts.month, ts.day, 23, 59, 59) for ts in calendar_raw])

    def _every_nth_from_end(periods, step):
        # choose the offset so that the last period is always selected
        return periods[(len(periods) + step - 1) % step :: step]

    if freq_sam.endswith(("day", "d")):
        sam_days = int(freq_sam[:-1] if freq_sam.endswith("d") else freq_sam[:-3])
        return _every_nth_from_end(days, sam_days)

    if freq_sam.endswith(("week", "w")):
        sam_weeks = int(freq_sam[:-1] if freq_sam.endswith("w") else freq_sam[:-4])
        weekdays = np.array([ts.dayofweek for ts in days])
        # a day is kept when the following day has a smaller dayofweek (or it is the last day)
        week_ends = days[np.ediff1d(weekdays[::-1], to_begin=1)[::-1] > 0]
        return _every_nth_from_end(week_ends, sam_weeks)

    if freq_sam.endswith(("month", "m")):
        sam_months = int(freq_sam[:-1] if freq_sam.endswith("m") else freq_sam[:-5])
        monthdays = np.array([ts.day for ts in days])
        # a day is kept when the following day has a smaller day-of-month (or it is the last day)
        month_ends = days[np.ediff1d(monthdays[::-1], to_begin=1)[::-1] > 0]
        return _every_nth_from_end(month_ends, sam_months)

    raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def sample_calendar(calendar_raw, freq_raw, freq_sam):
    """Sample a raw calendar to a coarser frequency, aligned to the start of each period.

    Parameters
    ----------
    calendar_raw : iterable of pd.Timestamp
        the raw calendar
    freq_raw : str
        frequency of the raw calendar, "min" or "day" (an integer prefix is allowed)
    freq_sam : str
        sample frequency: xmin/xminute, xd/xday, xw/xweek or xm/xmonth;
        a missing integer prefix means 1

    Returns
    -------
    np.ndarray
        sorted unique sampled timestamps. Minute results are floored to the
        start of their slot; day/week/month results are at midnight and
        contain the first calendar day of every sampled period.

    Raises
    ------
    ValueError
        on an unsupported frequency, a raw frequency coarser than the sample
        frequency, or a minute timestamp outside the supported hours
    """
    # a frequency without a leading multiplier means "1 x"
    freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
    freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam

    if freq_sam.endswith(("minute", "min")):
        sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6])
        if not freq_raw.endswith(("minute", "min")):
            raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
        raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6])
        if raw_minutes > sam_minutes:
            raise ValueError("raw freq must be higher than sample freq")

        def cal_next_sam_minute(x, sam_minutes):
            # minute_index counts minutes since 09:30; timestamps from 13:00
            # on continue at index 120, so both hour ranges form one axis.
            hour = x.hour
            minute = x.minute
            if 9 <= hour <= 11:
                minute_index = (hour - 9) * 60 + minute - 30
            elif 13 <= hour <= 15:
                minute_index = (hour - 13) * 60 + minute + 120
            else:
                raise ValueError("calendar hour must be in [9, 11] or [13, 15]")

            # floor to the start of the sampling slot
            minute_index = minute_index // sam_minutes * sam_minutes

            if 0 <= minute_index < 120:
                return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
            elif 120 <= minute_index < 240:
                return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
            else:
                raise ValueError("calendar minute_index error")

        return np.unique(
            [pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0) for x in calendar_raw]
        )

    _calendar_day = np.unique([pd.Timestamp(x.year, x.month, x.day, 0, 0, 0) for x in calendar_raw])

    def first_day_of_each_period(period_keys):
        # Keep every day whose period key differs from the previous day's key.
        # Comparing keys (instead of checking that dayofweek / day-of-month
        # decreases) also detects a boundary when the new period starts on a
        # later weekday / day number than the previous period ended on.
        if len(_calendar_day) == 0:
            return _calendar_day
        keys = np.asarray(period_keys)
        is_start = np.ones(len(keys), dtype=bool)
        is_start[1:] = keys[1:] != keys[:-1]
        return _calendar_day[is_start]

    if freq_sam.endswith(("day", "d")):
        sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
        return _calendar_day[::sam_days]

    elif freq_sam.endswith(("week", "w")):
        sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
        # ISO (year, week) identifies a Monday-based week across year boundaries
        iso = [x.isocalendar() for x in _calendar_day]
        _calendar_week = first_day_of_each_period([i[0] * 100 + i[1] for i in iso])
        return _calendar_week[::sam_weeks]

    elif freq_sam.endswith(("month", "m")):
        sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
        _calendar_month = first_day_of_each_period([x.year * 12 + x.month for x in _calendar_day])
        return _calendar_month[::sam_months]
    else:
        raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
def sample_feature(feature_raw, freq, start_time, end_time, method="last"):
    """Aggregate the features inside a time window, one row per instrument.

    Parameters
    ----------
    feature_raw : pd.Series or pd.DataFrame
        features whose index has the levels "datetime" and "instrument"
    freq :
        not used by this implementation
    start_time :
        exclusive lower bound of the window
    end_time :
        inclusive upper bound of the window
    method : str
        name of the groupby aggregation to apply, e.g. "last" or "first"

    Returns
    -------
    the result of the groupby aggregation, indexed by instrument
    """
    timestamps = feature_raw.index.get_level_values("datetime")
    in_window = [start_time < ts <= end_time for ts in timestamps]
    windowed = feature_raw[in_window]
    aggregate = getattr(windowed.groupby(level="instrument"), method)
    return aggregate()
|
||||
Reference in New Issue
Block a user