From d3a1e03a113127bf65464c0b53e2fdd213d8dd2e Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 20 Mar 2021 00:11:19 +0800 Subject: [PATCH] add sample & base class --- qlib/data/data.py | 40 ++-- qlib/strategy/__init__.py | 9 + qlib/strategy/cost_control.py | 73 ++++++++ qlib/strategy/order_generator.py | 171 +++++++++++++++++ qlib/strategy/strategy.py | 304 +++++++++++++++++++++++++++++++ qlib/utils/__init__.py | 120 ++++++++++++ 6 files changed, 705 insertions(+), 12 deletions(-) create mode 100644 qlib/strategy/__init__.py create mode 100644 qlib/strategy/cost_control.py create mode 100644 qlib/strategy/order_generator.py create mode 100644 qlib/strategy/strategy.py diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..68e1a69d2 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function import os +import re import abc import time import queue @@ -24,7 +25,7 @@ from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path +from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path, sample_calendar class CalendarProvider(abc.ABC): @@ -55,7 +56,7 @@ class CalendarProvider(abc.ABC): """ raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method") - def locate_index(self, start_time, end_time, freq, future): + def locate_index(self, start_time, end_time, freq, freq_sam=None, future=False): """Locate the start time index and end time index in a calendar under certain frequency. Parameters @@ -82,7 +83,7 @@ class CalendarProvider(abc.ABC): """ start_time = pd.Timestamp(start_time) end_time = pd.Timestamp(end_time) - calendar, calendar_index = self._get_calendar(freq=freq, future=future) + calendar, calendar_index = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future) if start_time not in calendar_index: try: start_time = calendar[bisect.bisect_left(calendar, start_time)] @@ -96,7 +97,7 @@ class CalendarProvider(abc.ABC): end_index = calendar_index[end_time] return start_time, end_time, start_index, end_index - def _get_calendar(self, freq, future): + def _get_calendar(self, freq, freq_sam=None, future=False): """Load calendar using memcache. Parameters @@ -113,14 +114,21 @@ class CalendarProvider(abc.ABC): dict dict composed by timestamp as key and index as value for fast search. """ - flag = f"{freq}_future_{future}" + flag = f"{freq}_future_{future}_sam_{freq_sam}" if flag in H["c"]: _calendar, _calendar_index = H["c"][flag] else: + flag_raw = f"{freq}_future_{future}_sam_{None}" _calendar = np.array(self.load_calendar(freq, future)) _calendar_index = {x: i for i, x in enumerate(_calendar)} # for fast search - H["c"][flag] = _calendar, _calendar_index - return _calendar, _calendar_index + H["c"][flag_raw] = _calendar, _calendar_index + if freq_sam is None: + return _calendar, _calendar_index + else: + _calendar_sam = sample_calendar(_calendar, freq, freq_sam) + _calendar_sam_index = {x: i for i, x in enumerate(_calendar_sam)} + H["c"][flag] = _calendar_sam, _calendar_sam_index + return _calendar_sam, _calendar_sam_index def _uri(self, start_time, end_time, freq, future=False): """Get the uri of calendar generation task.""" @@ -530,12 +538,13 @@ class LocalCalendarProvider(CalendarProvider): with open(fname) as f: return [pd.Timestamp(x.strip()) for x in f] - def calendar(self, start_time=None, end_time=None, freq="day", future=False): - _calendar, _calendar_index = self._get_calendar(freq, future) + def calendar(self, start_time=None, end_time=None, freq="day", future=False, freq_sam=None): + _calendar, _ = self._get_calendar(freq=freq, future=future) if start_time == "None": start_time = None if end_time == "None": end_time = None + # strip if start_time: start_time = pd.Timestamp(start_time) @@ -549,8 +558,15 @@ class LocalCalendarProvider(CalendarProvider): return np.array([]) else: end_time = _calendar[-1] - _, _, si, ei = self.locate_index(start_time, end_time, freq, future) - return _calendar[si : ei + 1] + st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, future=future) + _calendar = _calendar[si : ei + 1] + if freq_sam is None: + return _calendar + else: + _calendar_sam, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future) + st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future) + if bisect.bisect(_calendar, st, 0, len(_calendar)): + return np.hstack() class LocalInstrumentProvider(InstrumentProvider): @@ -658,7 +674,7 @@ class LocalExpressionProvider(ExpressionProvider): expression = self.get_expression_instance(field) start_time = pd.Timestamp(start_time) end_time = pd.Timestamp(end_time) - _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) + _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False) lft_etd, rght_etd = expression.get_extended_window_size() series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) # Ensure that each column type is consistent diff --git a/qlib/strategy/__init__.py b/qlib/strategy/__init__.py new file mode 100644 index 000000000..6c2e4ceed --- /dev/null +++ b/qlib/strategy/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from .strategy import ( + TopkDropoutStrategy, + BaseStrategy, + WeightStrategyBase, +) diff --git a/qlib/strategy/cost_control.py b/qlib/strategy/cost_control.py new file mode 100644 index 000000000..dd90437b0 --- /dev/null +++ b/qlib/strategy/cost_control.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from .strategy import StrategyWrapper, WeightStrategyBase +import copy + + +class SoftTopkStrategy(WeightStrategyBase): + def __init__(self, topk, max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill"): + """Parameter + topk : int + top-N stocks to buy + risk_degree : float + position percentage of total value + buy_method : + rank_fill: assign the weight stocks that rank high first(1/topk max) + average_fill: assign the weight to the stocks rank high averagely. + """ + super().__init__() + self.topk = topk + self.max_sold_weight = max_sold_weight + self.risk_degree = risk_degree + self.buy_method = buy_method + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def generate_target_weight_position(self, score, current, trade_date): + """Parameter: + score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column + current : current position, use Position() class + trade_date : trade date + generate target position from score for this date and the current position + The cache is not considered in the position + """ + # TODO: + # If the current stock list is more than topk(eg. The weights are modified + # by risk control), the weight will not be handled correctly. + buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index) + cur_stock_weight = current.get_stock_weight_dict(only_stock=True) + + if len(cur_stock_weight) == 0: + final_stock_weight = {code: 1 / self.topk for code in buy_signal_stocks} + else: + final_stock_weight = copy.deepcopy(cur_stock_weight) + sold_stock_weight = 0.0 + for stock_id in final_stock_weight: + if stock_id not in buy_signal_stocks: + sw = min(self.max_sold_weight, final_stock_weight[stock_id]) + sold_stock_weight += sw + final_stock_weight[stock_id] -= sw + if self.buy_method == "first_fill": + for stock_id in buy_signal_stocks: + add_weight = min( + max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0), + sold_stock_weight, + ) + final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight + sold_stock_weight -= add_weight + elif self.buy_method == "average_fill": + for stock_id in buy_signal_stocks: + final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len( + buy_signal_stocks + ) + else: + raise ValueError("Buy method not found") + return final_stock_weight diff --git a/qlib/strategy/order_generator.py b/qlib/strategy/order_generator.py new file mode 100644 index 000000000..494981ecc --- /dev/null +++ b/qlib/strategy/order_generator.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This order generator is for strategies based on WeightStrategyBase +""" +from ..backtest.position import Position +from ..backtest.exchange import Exchange +import pandas as pd +import copy + + +class OrderGenerator: + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + :param current: The current position + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: {stock_id : weight} + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: the date the score is predicted + :type pred_date: pd.Timestamp + :param trade_date: the date the stock is traded + :type trade_date: pd.Timestamp + + :rtype: list + """ + raise NotImplementedError() + + +class OrderGenWInteract(OrderGenerator): + """Order Generator With Interact""" + + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + No adjustment for for the nontradable share. + All the tadable value is assigned to the tadable stock according to the weight. + if interact == True, will use the price at trade date to generate order list + else, will only use the price before the trade date to generate order list + + :param current: + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: + :type pred_date: pd.Timestamp + :param trade_date: + :type trade_date: pd.Timestamp + + :rtype: list + """ + # calculate current_tradable_value + current_amount_dict = current.get_stock_amount_dict() + current_total_value = trade_exchange.calculate_amount_position_value( + amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=False + ) + current_tradable_value = trade_exchange.calculate_amount_position_value( + amount_dict=current_amount_dict, trade_date=trade_date, only_tradable=True + ) + # add cash + current_tradable_value += current.get_cash() + + reserved_cash = (1.0 - risk_degree) * (current_total_value + current.get_cash()) + current_tradable_value -= reserved_cash + + if current_tradable_value < 0: + # if you sell all the tradable stock can not meet the reserved + # value. Then just sell all the stocks + target_amount_dict = copy.deepcopy(current_amount_dict.copy()) + for stock_id in list(target_amount_dict.keys()): + if trade_exchange.is_stock_tradable(stock_id, trade_date): + del target_amount_dict[stock_id] + else: + # consider cost rate + current_tradable_value /= 1 + max(trade_exchange.close_cost, trade_exchange.open_cost) + + # strategy 1 : generate amount_position by weight_position + # Use API in Exchange() + target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( + weight_position=target_weight_position, + cash=current_tradable_value, + trade_date=trade_date, + ) + order_list = trade_exchange.generate_order_for_target_amount_position( + target_position=target_amount_dict, + current_position=current_amount_dict, + trade_date=trade_date, + ) + return order_list + + +class OrderGenWOInteract(OrderGenerator): + """Order Generator Without Interact""" + + def generate_order_list_from_target_weight_position( + self, + current: Position, + trade_exchange: Exchange, + target_weight_position: dict, + risk_degree: float, + pred_date: pd.Timestamp, + trade_date: pd.Timestamp, + ) -> list: + """generate_order_list_from_target_weight_position + + generate order list directly not using the information (e.g. whether can be traded, the accurate trade price) at trade date. + In target weight position, generating order list need to know the price of objective stock in trade date, but we cannot get that + value when do not interact with exchange, so we check the %close price at pred_date or price recorded in current position. + + :param current: + :type current: Position + :param trade_exchange: + :type trade_exchange: Exchange + :param target_weight_position: + :type target_weight_position: dict + :param risk_degree: + :type risk_degree: float + :param pred_date: + :type pred_date: pd.Timestamp + :param trade_date: + :type trade_date: pd.Timestamp + + :rtype: list + """ + risk_total_value = risk_degree * current.calculate_value() + + current_stock = current.get_stock_list() + amount_dict = {} + for stock_id in target_weight_position: + # Current rule will ignore the stock that not hold and cannot be traded at predict date + if trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=pred_date): + amount_dict[stock_id] = ( + risk_total_value * target_weight_position[stock_id] / trade_exchange.get_close(stock_id, pred_date) + ) + elif stock_id in current_stock: + amount_dict[stock_id] = ( + risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id) + ) + else: + continue + order_list = trade_exchange.generate_order_for_target_amount_position( + target_position=amount_dict, + current_position=current.get_stock_amount_dict(), + trade_date=trade_date, + ) + return order_list diff --git a/qlib/strategy/strategy.py b/qlib/strategy/strategy.py new file mode 100644 index 000000000..0476f7d72 --- /dev/null +++ b/qlib/strategy/strategy.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +import copy +import numpy as np +import pandas as pd + +from ..data.dataset import DatasetH +from ..backtest.order import Order +from .order_generator import OrderGenWInteract + +""" +1. BaseStrategy 的粒度一定是数据粒度的整数倍 +- 关于calendar的合并咋整 +- adjust_dates这个东西啥用 +- label和freq和strategy的bar分离,这个如何决策呢 +""" +class BaseStrategy: + def __init__(self, bar, start_time, end_time): + self.bar = bar + self.start_time = start_time + self.end_time = end_time + self.current_time = start_time + + def generate_action(self, current): + pass + + +class RuleStrategy(BaseStrategy): + pass + +class DLStrategy(BaseStrategy): + def __init__(self, bar, model, dataset:DatasetH, start_time=None, end_time=None): + super(DLStrategy, self).__init__(bar, start_time, end_time) + self.model = model + self.dataset = dataset + self.pred_score_all = self.model.predict(dataset) + self.pred_score = None + _pred_dates = pred.index.get_level_values(level="datetime") + self.start_time = _pred_dates.min() if start_time is None else start_time + self.end_time = _pred_dates.max() if end_time is None else end_time + self.pred_date = [pd.Timestamp(self.start_time), *D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max(), freq=bar), self.end_time] + self.current_index = -1 + self.pred_length = len(self.pred_date) + + def _update_pred_score(self): + """update pred score + """ + pass + +class AdjustTimer: + """AdjustTimer + Responsible for timing of position adjusting + + This is designed as multiple inheritance mechanism due to: + - the is_adjust may need access to the internel state of a strategy. + + - it can be reguard as a enhancement to the existing strategy. + """ + + # adjust position in each trade date + def is_adjust(self, trade_date): + """is_adjust + Return if the strategy can adjust positions on `trade_date` + Will normally be used in strategy do trading with trade frequency + """ + return True + + +class ListAdjustTimer(AdjustTimer): + def __init__(self, adjust_dates=None): + """__init__ + + :param adjust_dates: an iterable object, it will return a timelist for trading dates + """ + if adjust_dates is None: + # None indicates that all dates is OK for adjusting + self.adjust_dates = None + else: + self.adjust_dates = {pd.Timestamp(dt) for dt in adjust_dates} + + def is_adjust(self, trade_date): + if self.adjust_dates is None: + return True + return pd.Timestamp(trade_date) in self.adjust_dates + +class TopkDropoutStrategy(DLStrategy, ListAdjustTimer): + def __init__( + self, + bar, + model, + dataset, + trade_exchange, + topk, + n_drop, + start_time=None, + end_time=None, + method_sell="bottom", + method_buy="top", + risk_degree=0.95, + thresh=1, + hold_thresh=1, + only_tradable=False, + **kwargs, + ): + """ + Parameters + ----------- + topk : int + the number of stocks in the portfolio. + n_drop : int + number of stocks to be replaced in each trading date. + method_sell : str + dropout method_sell, random/bottom. + method_buy : str + dropout method_buy, random/top. + risk_degree : float + position percentage of total value. + thresh : int + minimun holding days since last buy singal of the stock. + hold_thresh : int + minimum holding days + before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh. + only_tradable : bool + will the strategy only consider the tradable stock when buying and selling. + if only_tradable: + strategy will make buy sell decision without checking the tradable state of the stock. + else: + strategy will make decision with the tradable state of the stock info and avoid buy and sell them. + """ + super(TopkDropoutStrategy, self).__init__(bar, model, dataset, start_time, end_time) + ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None)) + self.trade_exchange = trade_exchange + self.topk = topk + self.n_drop = n_drop + self.method_sell = method_sell + self.method_buy = method_buy + self.risk_degree = risk_degree + self.thresh = thresh + # self.stock_count['code'] will be the days the stock has been hold + # since last buy signal. This is designed for thresh + self.stock_count = {} + + self.hold_thresh = hold_thresh + self.only_tradable = only_tradable + + def get_risk_degree(self, date): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing. + """ + # It will use 95% amoutn of your total value by default + return self.risk_degree + + def generate_action(self, current): + + self.current_index += 1 + + if not self.is_adjust(trade_date): + return [] + + if self.only_tradable: + # If The strategy only consider tradable stock when make decision + # It needs following actions to filter stocks + def get_first_n(l, n, reverse=False): + cur_n = 0 + res = [] + for si in reversed(l) if reverse else l: + if self.trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date): + res.append(si) + cur_n += 1 + if cur_n >= n: + break + return res[::-1] if reverse else res + + def get_last_n(l, n): + return get_first_n(l, n, reverse=True) + + def filter_stock(l): + return [si for si in l if self.trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date)] + + else: + # Otherwise, the stock will make decision with out the stock tradable info + def get_first_n(l, n): + return list(l)[:n] + + def get_last_n(l, n): + return list(l)[-n:] + + def filter_stock(l): + return l + + current_temp = copy.deepcopy(current) + # generate order list for this adjust date + sell_order_list = [] + buy_order_list = [] + # load score + cash = current_temp.get_cash() + current_stock_list = current_temp.get_stock_list() + # last position (sorted by score) + last = self.pred_score.reindex(current_stock_list).sort_values(ascending=False).index + # The new stocks today want to buy **at most** + if self.method_buy == "top": + today = get_first_n( + self.pred_score[~self.pred_score.index.isin(last)].sort_values(ascending=False).index, + self.n_drop + self.topk - len(last), + ) + elif self.method_buy == "random": + topk_candi = get_first_n(self.pred_score.sort_values(ascending=False).index, self.topk) + candi = list(filter(lambda x: x not in last, topk_candi)) + n = self.n_drop + self.topk - len(last) + try: + today = np.random.choice(candi, n, replace=False) + except ValueError: + today = candi + else: + raise NotImplementedError(f"This type of input is not supported") + # combine(new stocks + last stocks), we will drop stocks from this list + # In case of dropping higher score stock and buying lower score stock. + comb = self.pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index + + # Get the stock list we really want to sell (After filtering the case that we sell high and buy low) + if self.method_sell == "bottom": + sell = last[last.isin(get_last_n(comb, self.n_drop))] + elif self.method_sell == "random": + candi = filter_stock(last) + try: + sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else []) + except ValueError: # No enough candidates + sell = candi + else: + raise NotImplementedError(f"This type of input is not supported") + + # Get the stock list we really want to buy + buy = today[: len(sell) + self.topk - len(last)] + + # buy singal: if a stock falls into topk, it appear in the buy_sinal + buy_signal = self.pred_score.sort_values(ascending=False).iloc[: self.topk].index + + for code in current_stock_list: + if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + if code in sell: + # check hold limit + if self.stock_count[code] < self.thresh or current_temp.get_stock_count(code) < self.hold_thresh: + # can not sell this code + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + continue + # sell order + sell_amount = current_temp.get_stock_amount(code=code) + sell_order = Order( + stock_id=code, + amount=sell_amount, + trade_date=trade_date, + direction=Order.SELL, # 0 for sell, 1 for buy + factor=self.trade_exchange.get_factor(code, trade_date), + ) + # is order executable + if self.trade_exchange.check_order(sell_order): + sell_order_list.append(sell_order) + trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(sell_order, position=current_temp) + # update cash + cash += trade_val - trade_cost + # sold + del self.stock_count[code] + else: + # no buy signal, but the stock is kept + self.stock_count[code] += 1 + elif code in buy_signal: + # NOTE: This is different from the original version + # get new buy signal + # Only the stock fall in to topk will produce buy signal + self.stock_count[code] = 1 + else: + self.stock_count[code] += 1 + # buy new stock + # note the current has been changed + current_stock_list = current_temp.get_stock_list() + value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0 + + # open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not + # consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line + # value = value / (1+self.trade_exchange.open_cost) # set open_cost limit + for code in buy: + # check is stock suspended + if not self.trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date): + continue + # buy order + buy_price = self.trade_exchange.get_deal_price(stock_id=code, trade_date=trade_date) + buy_amount = value / buy_price + factor = self.trade_exchange.quote[(code, trade_date)]["$factor"] + buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor) + buy_order = Order( + stock_id=code, + amount=buy_amount, + trade_date=trade_date, + direction=Order.BUY, # 1 for buy + factor=factor, + ) + buy_order_list.append(buy_order) + self.stock_count[code] = 1 + return sell_order_list + buy_order_list diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 1ee6f07a1..28982bc3a 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -799,3 +799,123 @@ def fname_to_code(fname: str): if fname.startswith(prefix): fname = fname.lstrip(prefix) return fname + +########################## Sample ############################ +def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): + """ + freq_raw : "min" or "day" + """ + freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw + freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam + + if freq_sam.endswith(("minute", "min")): + def cal_next_sam_minute(x, sam_minutes): + hour = x.hour + minute = x.minute + if 9 <= hour <= 11: + minute_index = (11 - hour)*60 + 30 - minute + 120 + elif 13 <= hour <= 15: + minute_index = (15 - hour)*60 - minute + else: + raise ValueError("calendar hour must be in [9, 11] or [13, 15]") + + minute_index = minute_index // sam_minutes * sam_minutes + + if 0 <= minute_index < 120: + return 15 - (minute_index + 59) // 60, (120 - minute_index) % 60 + elif 120 <= minute_index < 240: + return 11 - (minute_index - 120 + 29) // 60, (240 - minute_index + 30) % 60 + else: + raise ValueError("calendar minute_index error") + + sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6]) + + if not freq_raw.endswith(("minute", "min")): + raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") + else: + raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6]) + if raw_minutes > sam_minutes: + raise ValueError("raw freq must be higher than sample freq") + + _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), calendar_raw))) + return _calendar_minute + else: + + _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw))) + if freq_sam.endswith(("day", "d")): + sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3]) + return _calendar_day[(len(_calendar_day) + sam_days - 1)%sam_days::sam_days] + + elif freq_sam.endswith(("week", "w")): + sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4]) + _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) + _calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0] + return _calendar_week[(len(_calendar_week) + sam_weeks - 1)%sam_weeks::sam_weeks] + + elif freq_sam.endswith(("month", "m")): + sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5]) + _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) + _calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0] + return _calendar_month[(len(_calendar_month) + sam_months - 1)%sam_months::sam_months] + else: + raise ValueError("sample freq must be xmin, xd, xw, xm") + +def sample_calendar(calendar_raw, freq_raw, freq_sam): + """ + freq_raw : "min" or "day" + """ + freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw + freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam + + if freq_sam.endswith(("minute", "min")): + def cal_next_sam_minute(x, sam_minutes): + hour = x.hour + minute = x.minute + if 9 <= hour <= 11: + minute_index = (hour - 9)*60 + minute - 30 + elif 13 <= hour <= 15: + minute_index = (hour - 13)*60 + minute + 120 + else: + raise ValueError("calendar hour must be in [9, 11] or [13, 15]") + + minute_index = minute_index // sam_minutes * sam_minutes + + if 0 <= minute_index < 120: + return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60 + elif 120 <= minute_index < 240: + return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60 + else: + raise ValueError("calendar minute_index error") + sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6]) + if not freq_raw.endswith(("minute", "min")): + raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") + else: + raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6]) + if raw_minutes > sam_minutes: + raise ValueError("raw freq must be higher than sample freq") + _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw))) + return _calendar_minute + else: + _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) + if freq_sam.endswith(("day", "d")): + sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3]) + return _calendar_day[::sam_days] + + elif freq_sam.endswith(("week", "w")): + sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4]) + _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) + _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0] + return _calendar_week[::sam_weeks] + + elif freq_sam.endswith(("month", "m")): + sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5]) + _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) + _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0] + return _calendar_month[::sam_months] + else: + raise ValueError("sample freq must be xmin, xd, xw, xm") + +def sample_feature(feature_raw, freq, start_time, end_time, method="last"): + datetime_raw = feature_raw.index.get_level_values("datetime") + feature_sample = feature_raw[list(map(lambda x: start_time < x <= end_time, datetime_raw))] + return getattr(feature_sample.groupby(level="instrument"), method)() \ No newline at end of file