1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

new high freq struc

This commit is contained in:
wangwenxi.handsome
2021-08-26 15:54:19 +00:00
committed by you-n-g
parent d9ad8ff791
commit 25f54ddaeb
6 changed files with 151 additions and 145 deletions

View File

@@ -19,6 +19,7 @@ from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManag
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
# make import more user-friendly by enable `from qlib.backtest import STH`

View File

@@ -9,19 +9,16 @@ if TYPE_CHECKING:
from qlib.backtest.position import BasePosition, Position
import random
import logging
from typing import List, Tuple, Union, Callable, Iterable
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
from ..data.data import D
from ..data.dataset.utils import get_level_index
from ..config import C, REG_CN
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
from .high_performance_ds import PandasQuote, CN1Min_NumpyQuote
from .high_performance_ds import PandasQuote, CN1min_NumpyQuote
class Exchange:
@@ -39,7 +36,7 @@ class Exchange:
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=PandasQuote,
quote_cls=CN1min_NumpyQuote,
**kwargs,
):
"""__init__

View File

@@ -3,6 +3,7 @@
from builtins import ValueError, isinstance
from functools import lru_cache
import logging
from typing import List, Text, Union, Callable, Iterable, Dict
from collections import OrderedDict
@@ -15,7 +16,7 @@ import numpy as np
from ..utils.index_data import IndexData
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from ..utils.time import _if_single_data
from ..utils.time import if_single_data
class BaseQuote:
@@ -38,9 +39,9 @@ class BaseQuote:
stock_id: str,
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
fields: str = None,
method: Union[str, Callable] = None,
) -> Union[None, float, "IndexData"]:
fields: Union[str, None] = None,
method: Union[str, Callable, None] = None,
) -> Union[None, Union[int, float, bool], "IndexData"]:
"""get the specific fields of stock data during start time and end_time,
and apply method to the data.
@@ -62,7 +63,7 @@ class BaseQuote:
this function is used for three case:
1. Both fields and method are not None. It returns float.
1. Both fields and method are not None. It returns int/float/bool.
print(get_data(stock_id="SH600000", start_time="2010-01-04", end_time="2010-01-06", fields="$close", method="last"))
85.713585
@@ -88,15 +89,15 @@ class BaseQuote:
closed start time for backtest
end_time : Union[pd.Timestamp, str]
closed end time for backtest
fields : str
fields : Union[str, None]
the columns of data to fetch
method : Union[str, Callable]
method : Union[str, Callable, None]
the method apply to data.
e.g [None, "last", "all", "sum", "mean", qlib/utils/resam.py/ts_data_last]
Return
----------
Union[None, float, pd.Series, pd.DataFrame, IndexData]
Union[None, Union[int, float, bool], IndexData]
please refer to Example as following.
"""
@@ -115,121 +116,105 @@ class PandasQuote(BaseQuote):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
if fields is None and method is not None:
raise ValueError(f"method must be None when fields is None")
if fields is None:
return resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
elif isinstance(fields, (str, list)):
return resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
stock_data = resam_ts_data(self.data[stock_id], start_time, end_time, method=method)
elif isinstance(fields, str):
stock_data = resam_ts_data(self.data[stock_id][fields], start_time, end_time, method=method)
else:
raise ValueError(f"fields must be None, str or list")
raise ValueError(f"fields must be None, str")
if stock_data is None:
return None
elif isinstance(stock_data, (bool, np.bool_, int, float, np.signedinteger, np.floating)):
return stock_data
elif isinstance(stock_data, pd.Series):
return IndexData.Series(stock_data)
elif isinstance(stock_data, pd.DataFrame):
return stock_data.values
else:
raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame")
class CN1Min_NumpyQuote(BaseQuote):
class CN1min_NumpyQuote(BaseQuote):
def __init__(self, quote_df: pd.DataFrame):
"""CN1Min_NumpyQuote
"""CN1min_NumpyQuote
Parameters
----------
quote_df : pd.DataFrame
the init dataframe from qlib.
Variables
self.data: Dict[stock_id, np.ndarray]
each stock has one two-dimensional np.ndarray to represent data.
self.columns: Dict[str, int]
map column name to column id in self.data.
self.dt2idx: Dict[stock_id, Dict[pd.Timestap, int]]
map timestap to row id in self.data.
self.idx2dt: Dict[stock_id, List[pd.Timestap]]
the dt2idx of each stock for searching.
self.data : Dict(stock_id, IndexData.DataFrame)
"""
super().__init__(quote_df=quote_df)
# init data
columns = quote_df.columns.values
self.columns = dict(zip(columns, range(len(columns))))
self.data, self.dt2idx, self.idx2dt = self._to_numpy(quote_df)
# lru
self.multi_lru = {}
self.max_lru_len = 256
def _to_numpy(self, quote_df):
"""convert dataframe to numpy."""
quote_dict = {}
date_dict = {}
date_list = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.values
date_dict[stock_id] = stock_val.index.get_level_values("datetime")
date_list[stock_id] = list(date_dict[stock_id])
for stock_id in date_dict:
date_dict[stock_id] = dict(zip(date_dict[stock_id], range(len(date_dict[stock_id]))))
return quote_dict, date_dict, date_list
quote_dict[stock_id] = IndexData.DataFrame(stock_val.droplevel(level="instrument"))
self.data = quote_dict
self.freq = np.timedelta64(1, "m")
def get_all_stock(self):
return self.data.keys()
def get_data(self, stock_id, start_time, end_time, fields=None, method=None):
# check fields
if isinstance(fields, list) and len(fields) > 1:
raise ValueError(f"get_data in CN1Min_NumpyQuote only supports one field")
if fields is None and method is not None:
raise ValueError(f"method must be None when fields is None")
# check stock id
if stock_id not in self.get_all_stock():
return None
# get single data
# single data is only one piece of data, so it don't need to agg by method.
if _if_single_data(start_time, end_time, np.timedelta64(1, "m")):
if start_time not in self.dt2idx[stock_id]:
# single data
# If it don't consider the classification of single data, it will consume a lot of time.
if if_single_data(start_time, end_time, self.freq):
now_index_map = self.data[stock_id].index_map
now_columns_map = self.data[stock_id].columns_map
if start_time not in now_index_map:
return None
if fields is None:
# it used for check if data is None
return self.data[stock_id][self.dt2idx[stock_id][start_time]]
return self.data[stock_id].values[now_index_map[start_time]]
else:
return self.data[stock_id][self.dt2idx[stock_id][start_time]][self.columns[fields]]
# get muti row data
return self.data[stock_id].values[now_index_map[start_time], now_columns_map[fields]]
# multi data
else:
# check lru
if (stock_id, start_time, end_time, fields, method) in self.multi_lru:
return self.multi_lru[(stock_id, start_time, end_time, fields, method)]
start_id = bisect.bisect_left(self.idx2dt[stock_id], start_time)
end_id = bisect.bisect_right(self.idx2dt[stock_id], end_time)
if start_id == end_id:
return None
# it used for check if data is None
if fields is None:
return self.data[stock_id][start_id:end_id]
elif method is None:
stock_data = self.data[stock_id][start_id:end_id, self.columns[fields]]
stock_dt2idx = self.idx2dt[stock_id][start_id:end_id].to_list()
return IndexData(stock_data, stock_dt2idx)
else:
agg_stock_data = self._agg_data(self.data[stock_id][start_id:end_id, self.columns[fields]], method)
# result lru
if len(self.multi_lru) >= self.max_lru_len:
self.multi_lru.clear()
self.multi_lru[(stock_id, start_time, end_time, fields, method)] = agg_stock_data
return agg_stock_data
if fields is None and method is None:
stock_data = self.data[stock_id].loc(start_time, end_time)
if stock_data.empty:
return None
else:
return stock_data.values
elif fields is not None and method is None:
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
if stock_data.empty:
return None
else:
return stock_data
elif fields is not None and method is not None:
stock_data = self.data[stock_id].loc(start_time, end_time, fields)
if stock_data.empty:
return None
elif len(stock_data) == 1:
return stock_data[0]
else:
return self._agg_data(stock_data.values, method)
def _agg_data(self, data, method):
"""Agg data by specific method."""
valid_data = data[data != np.array(None)].copy()
if method == "sum":
return np.nansum(valid_data)
return np.nansum(data)
elif method == "mean":
return np.nanmean(valid_data)
return np.nanmean(data)
elif method == "last":
return valid_data[-1]
return data[-1]
elif method == "all":
return valid_data.all()
return data.all()
elif method == "any":
return valid_data.any()
return data.any()
elif method == ts_data_last:
valid_data = valid_data[valid_data != np.NaN]
valid_data = data[data != np.NaN]
if len(valid_data) == 0:
return None
else:
@@ -412,6 +397,7 @@ class BaseOrderIndicator:
def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, List[str]], fill_value: float = None):
"""sum indicators with the same metrics.
and assign to the order_indicator(BaseOrderIndicator).
NOTE: indicators could be a empty list when orders in lower level all fail.
Parameters
----------
@@ -601,6 +587,11 @@ class PandasOrderIndicator(BaseOrderIndicator):
class NumpyOrderIndicator(BaseOrderIndicator):
"""
The data structure is OrderedDict(str: IndexData.Series).
Each IndexData.Series is one metric.
Str is the name of metric.
"""
def __init__(self):
self.data: Dict[str, IndexData.Series] = OrderedDict()
@@ -640,4 +631,4 @@ class NumpyOrderIndicator(BaseOrderIndicator):
tmp_metric = IndexData.Series()
for indicator in indicators:
tmp_metric = tmp_metric.add(indicator.data[metric], fill_value)
order_indicator.data[metric] = tmp_metric
order_indicator.data[metric] = tmp_metric

View File

@@ -3,25 +3,19 @@
from collections import OrderedDict
from logging import warning
import pathlib
from typing import Dict, List, Tuple, Union, Callable
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from pandas.core import groupby
from pandas.core.frame import DataFrame
from qlib.backtest.exchange import Exchange
from qlib.backtest.order import BaseTradeDecision, Order, OrderDir
from qlib.backtest.utils import TradeCalendarManager
from .high_performance_ds import PandasOrderIndicator, NumpyOrderIndicator
from ..utils.index_data import IndexData, SingleData
from ..data import D
from ..utils.index_data import IndexData, SingleData
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
from ..utils.time import Freq
from .order import IdxTradeRange
@@ -391,9 +385,7 @@ class Indicator:
if price_s is None:
return None, None
if isinstance(price_s, pd.Series):
price_s = IndexData.Series(price_s)
elif isinstance(price_s, (int, float, np.floating)):
if isinstance(price_s, (int, float, np.signedinteger, np.floating)):
price_s = IndexData.Series(price_s, [trade_start_time])
elif isinstance(price_s, SingleData):
pass
@@ -479,10 +471,10 @@ class Indicator:
bv_new = IndexData.Series(bv_new)
bp_all.append(bp_new)
bv_all.append(bv_new)
bp_all = IndexData.concat(bp_all, axis = 1)
bv_all = IndexData.concat(bv_all, axis = 1)
bp_all = IndexData.concat(bp_all, axis=1)
bv_all = IndexData.concat(bv_all, axis=1)
base_volume = bv_all.sum(axis = 1)
base_volume = bv_all.sum(axis=1)
self.order_indicator.assign("base_volume", base_volume.to_dict())
self.order_indicator.assign("base_price", ((bp_all * bv_all).sum(axis=1) / base_volume).to_dict())

View File

@@ -2,16 +2,20 @@
# Licensed under the MIT License.
from typing import Union, Callable
import bisect
import numpy as np
import pandas as pd
from typing import Union, Callable
class IndexData:
"""This is a simplified version of pandas which is faster based on numpy.
"""
"""This is a simplified version of pandas which is faster based on numpy."""
@staticmethod
def Series(data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []):
def Series(
data: Union[dict, pd.Series, int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []
):
if isinstance(data, dict):
return SingleData(list(data.values()), list(data.keys()))
elif isinstance(data, pd.Series):
@@ -20,16 +24,20 @@ class IndexData:
return SingleData(data, index)
@staticmethod
def DataFrame(data: Union[pd.DataFrame, list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []):
def DataFrame(
data: Union[pd.DataFrame, list, np.ndarray] = [[]],
index: Union[list, pd.Index] = [],
columns: Union[list, pd.Index] = [],
):
if isinstance(data, pd.DataFrame):
return MultiData(data.values, data.index, data.columns)
else:
else:
return MultiData(data, index, columns)
@staticmethod
def concat(data_list, axis = 0):
def concat(data_list, axis=0):
"""concat all SingleData by index.
just for 1-dim data.
TODO: now just for SingleData.
Parameters
----------
@@ -57,15 +65,15 @@ class IndexData:
for data_id, index_data in enumerate(data_list):
assert isinstance(index_data, SingleData)
now_data_map = [all_index_map[index] for index in index_data.index]
tmp_data[now_data_map, data_id] = index_data.data
tmp_data[now_data_map, data_id] = index_data.data
return MultiData(tmp_data, all_index)
else:
raise ValueError(f"axis must be 0 or 1")
class BaseData:
"""Base data structure of SingleData and MultiData.
"""
"""Base data structure of SingleData and MultiData."""
def __init__(self):
self.index_columns = self._get_index_columns()
@@ -78,8 +86,7 @@ class BaseData:
return index_columns
def _align_index(self, other):
"""Align index before performing the four arithmetic operations.
"""
"""Align index before performing the four arithmetic operations."""
raise NotImplementedError(f"please implement _align_index func")
def __add__(self, other):
@@ -158,14 +165,12 @@ class BaseData:
return self.__class__(~self.data, *self.index_columns)
def abs(self):
"""get the abs of data except np.NaN.
"""
"""get the abs of data except np.NaN."""
tmp_data = np.absolute(self.data)
return self.__class__(tmp_data, *self.index_columns)
def astype(self, type):
"""change the type of data.
"""
"""change the type of data."""
tmp_data = self.data.astype(type)
return self.__class__(tmp_data, *self.index_columns)
@@ -178,8 +183,7 @@ class BaseData:
return self.__class__(tmp_data, *self.index_columns)
def apply(self, func: Callable):
"""apply a function to data.
"""
"""apply a function to data."""
tmp_data = func(self.data)
return self.__class__(tmp_data, *self.index_columns)
@@ -224,6 +228,10 @@ class BaseData:
def empty(self):
return len(self.data) == 0
@property
def values(self):
return self.data
class SingleData(BaseData):
def __init__(self, data: Union[int, float, np.floating, list, np.ndarray] = [], index: Union[list, pd.Index] = []):
@@ -239,7 +247,7 @@ class SingleData(BaseData):
"""
# data
if isinstance(data, (int, float, np.floating)):
self.data = np.full(len(index), fill_value=data)
self.data = np.full(len(index), fill_value=data, dtype=np.float64)
elif isinstance(data, list):
self.data = np.array(data)
elif isinstance(data, np.ndarray):
@@ -249,12 +257,12 @@ class SingleData(BaseData):
# data in SingleData must be one dim
assert self.data.ndim == 1
# replace int with float
if self.data.dtype == np.int:
if self.data.dtype == np.signedinteger:
self.data = self.data.astype(np.float64)
# replace None with np.NaN, because pd.Series does it.
if None in self.data:
self.data[self.data == None] = np.NaN
# index
if isinstance(index, list):
if index == [] and len(self.data) > 0:
@@ -265,18 +273,20 @@ class SingleData(BaseData):
else:
raise ValueError(f"index must be list or pd.Index")
assert len(self.data) == len(self.index)
# if data is not empty,
# if data is not empty,
self.index_map = dict(zip(self.index, range(len(self.index))))
super(SingleData, self).__init__()
def _align_index(self, other):
if self.index == other.index:
return self, other
return self, other
elif set(self.index) == set(other.index):
return self, other.reindex(self.index)
else:
raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations")
raise ValueError(
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
)
def reindex(self, index, fill_value=np.NaN):
"""reindex data and fill the missing value with np.NaN.
@@ -291,7 +301,7 @@ class SingleData(BaseData):
SingleData
reindex data
"""
tmp_data = np.full(len(index), fill_value, np.float64)
tmp_data = np.full(len(index), fill_value, dtype=np.float64)
for index_id, index_item in enumerate(index):
if index_item in self.index:
tmp_data[index_id] = self.data[self.index_map[index_item]]
@@ -299,8 +309,8 @@ class SingleData(BaseData):
def add(self, other, fill_value=0):
common_index = list(set(self.index) | set(other.index))
tmp_data1 = self.reindex(common_index,fill_value)
tmp_data2 = other.reindex(common_index,fill_value)
tmp_data1 = self.reindex(common_index, fill_value)
tmp_data2 = other.reindex(common_index, fill_value)
return tmp_data1 + tmp_data2
def to_dict(self):
@@ -324,7 +334,7 @@ class SingleData(BaseData):
return MultiData(self.data[:, np.newaxis], self.index)
def to_pd_series(self):
return pd.Series(self.data, index = self.index)
return pd.Series(self.data, index=self.index)
def __getitem__(self, index: Union["SingleData", int, str]):
if isinstance(index, int):
@@ -340,7 +350,12 @@ class SingleData(BaseData):
class MultiData(BaseData):
def __init__(self, data: Union[list, np.ndarray] = [[]], index: Union[list, pd.Index] = [], columns: Union[list, pd.Index] = []):
def __init__(
self,
data: Union[list, np.ndarray] = [[]],
index: Union[list, pd.Index] = [],
columns: Union[list, pd.Index] = [],
):
"""A data structure of index and numpy data.
It's used to replace pd.DataFrame due to high-speed.
@@ -363,12 +378,12 @@ class MultiData(BaseData):
# data in SingleData must be two dim
assert self.data.ndim == 2
# replace int with float
if self.data.dtype == np.int:
if self.data.dtype == np.signedinteger:
self.data = self.data.astype(np.float64)
# replace None with np.NaN, because pd.DataFrame does it.
if None in self.data:
self.data[self.data == None] = np.NaN
# index
if isinstance(index, list):
if index == [] and self.data.shape[0] > 0:
@@ -379,7 +394,7 @@ class MultiData(BaseData):
else:
raise ValueError(f"index must be list or pd.Index")
assert self.data.shape[0] == len(self.index)
# if data is not empty,
# if data is not empty,
self.index_map = dict(zip(self.index, range(len(self.index))))
# columns
@@ -392,19 +407,29 @@ class MultiData(BaseData):
else:
raise ValueError(f"columns must be list or pd.Index")
assert self.data.shape[1] == len(self.columns)
# if data is not empty,
self.columns_map = dict(zip(self.columns, range(len(self.columns))))
# if data is not empty,
self.columns_map = dict(zip(self.columns, range(len(self.columns))))
super(MultiData, self).__init__()
def _align_index(self, other):
if self.index_columns == other.index_columns:
return self, other
return self, other
else:
raise ValueError(f"The indexes of self and other do not meet the requirements of the four arithmetic operations")
raise ValueError(
f"The indexes of self and other do not meet the requirements of the four arithmetic operations"
)
def __getitem__(self, col) -> SingleData:
if col not in self.columns:
return SingleData()
else:
return SingleData(self.data[:, self.columns_map[col]], self.index)
def loc(self, start, end, col=None):
start_id = bisect.bisect_left(self.index, start)
end_id = bisect.bisect_right(self.index, end)
if col is None:
return MultiData(self.data[start_id:end_id], self.index[start_id:end_id], self.columns)
else:
return SingleData(self.data[start_id:end_id, self.columns_map[col]], self.index[start_id:end_id])

View File

@@ -38,7 +38,7 @@ def get_min_cal(shift: int = 0) -> List[time]:
return cal
def _if_single_data(start_time, end_time, freq):
def if_single_data(start_time, end_time, freq):
"""Is there only one piece of data to obtain.
Parameters