mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 11:30:57 +08:00
new high freq struc
This commit is contained in:
committed by
you-n-g
parent
d9ad8ff791
commit
25f54ddaeb
@@ -19,6 +19,7 @@ from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManag
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
# make import more user-friendly by enable `from qlib.backtest import STH`
|
||||
|
||||
|
||||
|
||||
@@ -9,19 +9,16 @@ if TYPE_CHECKING:
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
import logging
|
||||
from typing import List, Tuple, Union, Callable, Iterable
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..config import C, REG_CN
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from .order import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import PandasQuote, CN1Min_NumpyQuote
|
||||
from .high_performance_ds import PandasQuote, CN1min_NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
@@ -39,7 +36,7 @@ class Exchange:
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
extra_quote=None,
|
||||
quote_cls=PandasQuote,
|
||||
quote_cls=CN1min_NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
from builtins import ValueError, isinstance
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
@@ -15,7 +16,7 @@ import numpy as np
|
||||
from ..utils.index_data import IndexData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import _if_single_data
|
||||
from ..utils.time import if_single_data
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
@@ -38,9 +39,9 @@ class BaseQuote:
|
||||
stock_id: str,
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
fields: str = None,
|
||||
method: Union[str, Callable] = None,
|
||||
) -> Union[None, float, "IndexData"]:
|
||||
fields: Union[str, None] = None,
|
||||
method: Union[str, Callable, None] = None,
|
||||
) -> Union[None, Union[int, float, bool], "IndexData"]:
|
||||
"""get the specific fields of stock data during start time and end_time,
|
||||
and apply method to the data.
|
||||
|
||||
@@ -62,7 +63,7 @@ class BaseQuote:
|
||||
|
||||
this function is used for three case:
|
||||
|
||||
1. Both fields and method are not None. It returns float.
|
||||
1. Both fields and method are not None. It returns int/float/bool.
|
||||
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method="last"))
|
||||
|
||||
85.713585
|
||||
@@ -88,15 +89,15 @@ class BaseQuote:
|
||||
closed start time for backtest
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
fields : str
|
||||
fields : Union[str, None]
|
||||
the columns of data to fetch
|
||||
method : Union[str, Callable]
|
||||
method : Union[str, Callable, None]
|
||||
the method apply to data.
|
||||
e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
|
||||
|
||||
Return
|
||||
----------
|
||||
Union[None, float, pd.Series, pd.DataFrame, IndexData]
|
||||
Union[None, Union[int, float, bool], IndexData]
|
||||
please refer to Example as following.
|
||||
"""
|
||||
|
||||
@@ -115,121 +116,105 @@ class PandasQuote(BaseQuote):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
|
||||
if fields is None and method is not None:
|
||||
raise ValueError(f"method must be None when fields is None")
|
||||
|
||||
if fields is None:
|
||||
return resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
|
||||
elif isinstance(fields, (str, list)):
|
||||
return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
|
||||
stock_data = resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
|
||||
elif isinstance(fields, str):
|
||||
stock_data = resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
|
||||
else:
|
||||
raise ValueError(f"fields must be None, str or list")
|
||||
raise ValueError(f"fields must be None, str")
|
||||
|
||||
if stock_data is None:
|
||||
return None
|
||||
elif isinstance(stock_data, (bool, np.bool_, int, float, np.signedinteger, np.floating)):
|
||||
return stock_data
|
||||
elif isinstance(stock_data, pd.Series):
|
||||
return IndexData.Series(stock_data)
|
||||
elif isinstance(stock_data, pd.DataFrame):
|
||||
return stock_data.values
|
||||
else:
|
||||
raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
|
||||
|
||||
|
||||
class CN1Min_NumpyQuote(BaseQuote):
|
||||
class CN1min_NumpyQuote(BaseQuote):
|
||||
def __init__(self, quote_df: pd.DataFrame):
|
||||
"""CN1Min_NumpyQuote
|
||||
"""CN1min_NumpyQuote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
quote_df : pd.DataFrame
|
||||
the init dataframe from qlib.
|
||||
|
||||
Variables
|
||||
self.data: Dict[stock_id, np.ndarray]
|
||||
each stock has one two-dimensional np.ndarray to represent data.
|
||||
self.columns: Dict[str, int]
|
||||
map column name to column id in self.data.
|
||||
self.dt2idx: Dict[stock_id, Dict[pd.Timestap, int]]
|
||||
map timestap to row id in self.data.
|
||||
self.idx2dt: Dict[stock_id, List[pd.Timestap]]
|
||||
the dt2idx of each stock for searching.
|
||||
self.data : Dict(stock_id, IndexData.DataFrame)
|
||||
"""
|
||||
|
||||
super().__init__(quote_df=quote_df)
|
||||
# init data
|
||||
columns = quote_df.columns.values
|
||||
self.columns = dict(zip(columns, range(len(columns))))
|
||||
self.data, self.dt2idx, self.idx2dt = self._to_numpy(quote_df)
|
||||
|
||||
# lru
|
||||
self.multi_lru = {}
|
||||
self.max_lru_len = 256
|
||||
|
||||
def _to_numpy(self, quote_df):
|
||||
"""convert dataframe to numpy."""
|
||||
|
||||
quote_dict = {}
|
||||
date_dict = {}
|
||||
date_list = {}
|
||||
for stock_id, stock_val in quote_df.groupby(level="instrument"):
|
||||
quote_dict[stock_id] = stock_val.values
|
||||
date_dict[stock_id] = stock_val.index.get_level_values("datetime")
|
||||
date_list[stock_id] = list(date_dict[stock_id])
|
||||
for stock_id in date_dict:
|
||||
date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id]))))
|
||||
return quote_dict, date_dict, date_list
|
||||
quote_dict[stock_id] = IndexData.DataFrame(stock_val.droplevel(level="instrument"))
|
||||
self.data = quote_dict
|
||||
self.freq = np.timedelta64(1, "m")
|
||||
|
||||
def get_all_stock(self):
|
||||
return self.data.keys()
|
||||
|
||||
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
|
||||
# check fields
|
||||
if isinstance(fields, list) and len(fields) > 1:
|
||||
raise ValueError(f"get_data in CN1Min_NumpyQuote only supports one field")
|
||||
if fields is None and method is not None:
|
||||
raise ValueError(f"method must be None when fields is None")
|
||||
|
||||
# check stock id
|
||||
if stock_id not in self.get_all_stock():
|
||||
return None
|
||||
|
||||
# get single data
|
||||
# single data is only one piece of data, so it don't need to agg by method.
|
||||
if _if_single_data(start_time, end_time, np.timedelta64(1, "m")):
|
||||
if start_time not in self.dt2idx[stock_id]:
|
||||
# single data
|
||||
# If it don't consider the classification of single data, it will consume a lot of time.
|
||||
if if_single_data(start_time, end_time, self.freq):
|
||||
now_index_map = self.data[stock_id].index_map
|
||||
now_columns_map = self.data[stock_id].columns_map
|
||||
if start_time not in now_index_map:
|
||||
return None
|
||||
if fields is None:
|
||||
# it used for check if data is None
|
||||
return self.data[stock_id][self.dt2idx[stock_id][start_time]]
|
||||
return self.data[stock_id].values[now_index_map[start_time]]
|
||||
else:
|
||||
return self.data[stock_id][self.dt2idx[stock_id][start_time]][self.columns[fields]]
|
||||
# get muti row data
|
||||
return self.data[stock_id].values[now_index_map[start_time], now_columns_map[fields]]
|
||||
|
||||
# multi data
|
||||
else:
|
||||
# check lru
|
||||
if (stock_id, start_time, end_time, fields, method) in self.multi_lru:
|
||||
return self.multi_lru[(stock_id, start_time, end_time, fields, method)]
|
||||
|
||||
start_id = bisect.bisect_left(self.idx2dt[stock_id], start_time)
|
||||
end_id = bisect.bisect_right(self.idx2dt[stock_id], end_time)
|
||||
if start_id == end_id:
|
||||
return None
|
||||
# it used for check if data is None
|
||||
if fields is None:
|
||||
return self.data[stock_id][start_id:end_id]
|
||||
elif method is None:
|
||||
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
|
||||
stock_dt2idx = self.idx2dt[stock_id][start_id:end_id].to_list()
|
||||
return IndexData(stock_data, stock_dt2idx)
|
||||
else:
|
||||
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
|
||||
|
||||
# result lru
|
||||
if len(self.multi_lru) >= self.max_lru_len:
|
||||
self.multi_lru.clear()
|
||||
self.multi_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data
|
||||
return agg_stock_data
|
||||
if fields is None and method is None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
else:
|
||||
return stock_data.values
|
||||
elif fields is not None and method is None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
else:
|
||||
return stock_data
|
||||
elif fields is not None and method is not None:
|
||||
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
|
||||
if stock_data.empty:
|
||||
return None
|
||||
elif len(stock_data) == 1:
|
||||
return stock_data[0]
|
||||
else:
|
||||
return self._agg_data(stock_data.values, method)
|
||||
|
||||
def _agg_data(self, data, method):
|
||||
"""Agg data by specific method."""
|
||||
valid_data = data[data != np.array(None)].copy()
|
||||
if method == "sum":
|
||||
return np.nansum(valid_data)
|
||||
return np.nansum(data)
|
||||
elif method == "mean":
|
||||
return np.nanmean(valid_data)
|
||||
return np.nanmean(data)
|
||||
elif method == "last":
|
||||
return valid_data[-1]
|
||||
return data[-1]
|
||||
elif method == "all":
|
||||
return valid_data.all()
|
||||
return data.all()
|
||||
elif method == "any":
|
||||
return valid_data.any()
|
||||
return data.any()
|
||||
elif method == ts_data_last:
|
||||
valid_data = valid_data[valid_data != np.NaN]
|
||||
valid_data = data[data != np.NaN]
|
||||
if len(valid_data) == 0:
|
||||
return None
|
||||
else:
|
||||
@@ -412,6 +397,7 @@ class BaseOrderIndicator:
|
||||
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
|
||||
"""sum indicators with the same metrics.
|
||||
and assign to the order_indicator(BaseOrderIndicator).
|
||||
NOTE: indicators could be a empty list when orders in lower level all fail.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -601,6 +587,11 @@ class PandasOrderIndicator(BaseOrderIndicator):
|
||||
|
||||
|
||||
class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
"""
|
||||
The data structure is OrderedDict(str: IndexData.Series).
|
||||
Each IndexData.Series is one metric.
|
||||
Str is the name of metric.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data: Dict[str, IndexData.Series] = OrderedDict()
|
||||
@@ -640,4 +631,4 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
tmp_metric = IndexData.Series()
|
||||
for indicator in indicators:
|
||||
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
order_indicator.data[metric] = tmp_metric
|
||||
|
||||
@@ -3,25 +3,19 @@
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
from logging import warning
|
||||
import pathlib
|
||||
from typing import Dict, List, Tuple, Union, Callable
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.core import groupby
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..data import D
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
from ..utils.time import Freq
|
||||
from .order import IdxTradeRange
|
||||
|
||||
|
||||
@@ -391,9 +385,7 @@ class Indicator:
|
||||
if price_s is None:
|
||||
return None, None
|
||||
|
||||
if isinstance(price_s, pd.Series):
|
||||
price_s = IndexData.Series(price_s)
|
||||
elif isinstance(price_s, (int, float, np.floating)):
|
||||
if isinstance(price_s, (int, float, np.signedinteger, np.floating)):
|
||||
price_s = IndexData.Series(price_s, [trade_start_time])
|
||||
elif isinstance(price_s, SingleData):
|
||||
pass
|
||||
@@ -479,10 +471,10 @@ class Indicator:
|
||||
bv_new = IndexData.Series(bv_new)
|
||||
bp_all.append(bp_new)
|
||||
bv_all.append(bv_new)
|
||||
bp_all = IndexData.concat(bp_all, axis = 1)
|
||||
bv_all = IndexData.concat(bv_all, axis = 1)
|
||||
bp_all = IndexData.concat(bp_all, axis=1)
|
||||
bv_all = IndexData.concat(bv_all, axis=1)
|
||||
|
||||
base_volume = bv_all.sum(axis = 1)
|
||||
base_volume = bv_all.sum(axis=1)
|
||||
self.order_indicator.assign("base_volume", base_volume.to_dict())
|
||||
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user