mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Supporting Arctic Backend Provider & Orderbook, Tick Data Example (#744)
* change weight_decay & batchsize * del weight_decay * big weight_decay * mid weight_decay * small layer * 2 layer * full layer * no weight decay * divide into two data source * change parse field * delete some debug * add Toperator * new format of arctic * fix cache bug to arctic read * fix connection problem * add some operator * final version for arcitc * clear HZ cache * remove not used function * add topswrappers * successfully import data and run first test * A simpler version to support arctic * Successfully run all high-freq expressions * Black format and fix add docs * Add docs for download and test data * update scripts and docs * Add docs * fix bug * Refine docs * fix test bug * fix CI error * clean code Co-authored-by: bxdd <bxddream@gmail.com> Co-authored-by: wangwenxi.handsome <wangwenxi.handsome@gmail.com> Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from .data import (
|
||||
LocalCalendarProvider,
|
||||
LocalInstrumentProvider,
|
||||
LocalFeatureProvider,
|
||||
ArcticFeatureProvider,
|
||||
LocalExpressionProvider,
|
||||
LocalDatasetProvider,
|
||||
ClientCalendarProvider,
|
||||
|
||||
@@ -150,7 +150,7 @@ class Expression(abc.ABC):
|
||||
args = str(self), instrument, start_index, end_index, freq
|
||||
if args in H["f"]:
|
||||
return H["f"][args]
|
||||
if start_index is None or end_index is None or start_index > end_index:
|
||||
if start_index is not None and end_index is not None and start_index > end_index:
|
||||
raise ValueError("Invalid index range: {} {}".format(start_index, end_index))
|
||||
try:
|
||||
series = self._load_internal(instrument, start_index, end_index, freq)
|
||||
|
||||
@@ -147,6 +147,7 @@ class MemCache:
|
||||
"""
|
||||
|
||||
size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit
|
||||
limit_type = C.mem_cache_limit_type if limit_type is None else limit_type
|
||||
|
||||
if limit_type == "length":
|
||||
klass = MemCacheLengthUnit
|
||||
@@ -1198,7 +1199,4 @@ class MemoryCalendarCache(CalendarCache):
|
||||
return result
|
||||
|
||||
|
||||
# MemCache sizeof
|
||||
HZ = MemCache(C.mem_cache_space_limit, limit_type="sizeof")
|
||||
# MemCache length
|
||||
H = MemCache(limit_type="length")
|
||||
H = MemCache()
|
||||
|
||||
@@ -5,8 +5,10 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import time
|
||||
import copy
|
||||
import queue
|
||||
import bisect
|
||||
@@ -15,9 +17,11 @@ import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, Union
|
||||
from typing import List, Union
|
||||
from arctic import Arctic
|
||||
|
||||
# For supporting multiprocessing in outer code, joblib is used
|
||||
from joblib import delayed
|
||||
import pymongo
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
@@ -38,11 +42,17 @@ from ..utils import (
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
set_log_with_config,
|
||||
time_to_slc_point,
|
||||
)
|
||||
from ..utils.paral import ParallelExt
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
"""
|
||||
This helper class tries to make the provider based on storage backend more convenient
|
||||
It is not necessary to inherent this class if that provider don't rely on the backend storage
|
||||
"""
|
||||
|
||||
def get_default_backend(self):
|
||||
backend = {}
|
||||
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
|
||||
@@ -59,15 +69,12 @@ class ProviderBackendMixin:
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
class CalendarProvider(abc.ABC):
|
||||
"""Calendar provider base class
|
||||
|
||||
Provide calendar data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
|
||||
@@ -194,15 +201,12 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method")
|
||||
|
||||
|
||||
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
class InstrumentProvider(abc.ABC):
|
||||
"""Instrument provider base class
|
||||
|
||||
Provide instrument data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@staticmethod
|
||||
def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None):
|
||||
"""Get the general config dictionary for a base market adding several dynamic filters.
|
||||
@@ -304,15 +308,12 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
raise ValueError(f"Unknown instrument type {inst}")
|
||||
|
||||
|
||||
class FeatureProvider(abc.ABC, ProviderBackendMixin):
|
||||
class FeatureProvider(abc.ABC):
|
||||
"""Feature provider class
|
||||
|
||||
Provide feature data.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature(self, instrument, field, start_time, end_time, freq):
|
||||
"""Get feature data.
|
||||
@@ -365,9 +366,13 @@ class ExpressionProvider(abc.ABC):
|
||||
return expression
|
||||
|
||||
@abc.abstractmethod
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day") -> pd.Series:
|
||||
"""Get Expression data.
|
||||
|
||||
The responsibility of `expression`
|
||||
- parse the `field` and `load` the according data.
|
||||
- When loading the data, it should handle the time dependency of the data. `get_expression_instance` is commonly used in this method
|
||||
|
||||
Parameters
|
||||
----------
|
||||
instrument : str
|
||||
@@ -385,6 +390,11 @@ class ExpressionProvider(abc.ABC):
|
||||
-------
|
||||
pd.Series
|
||||
data of a certain expression
|
||||
|
||||
The data has two types of format
|
||||
1) expression with datetime index
|
||||
2) expression with integer index
|
||||
- because the datetime is not as good as
|
||||
"""
|
||||
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
|
||||
|
||||
@@ -500,7 +510,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
normalize_column_names = normalize_cache_fields(column_names)
|
||||
# One process for one task, so that the memory will be freed quicker.
|
||||
workers = max(min(C.kernels, len(instruments_d)), 1)
|
||||
workers = max(min(C.get_kernels(freq), len(instruments_d)), 1)
|
||||
|
||||
# create iterator
|
||||
if isinstance(instruments_d, dict):
|
||||
@@ -513,7 +523,7 @@ class DatasetProvider(abc.ABC):
|
||||
for inst, spans in it:
|
||||
inst_l.append(inst)
|
||||
task_l.append(
|
||||
delayed(DatasetProvider.expression_calculator)(
|
||||
delayed(DatasetProvider.inst_calculator)(
|
||||
inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors
|
||||
)
|
||||
)
|
||||
@@ -536,17 +546,17 @@ class DatasetProvider(abc.ABC):
|
||||
data = DiskDatasetCache.cache_to_origin_data(data, column_names)
|
||||
else:
|
||||
data = pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")),
|
||||
columns=column_names,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def expression_calculator(
|
||||
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
|
||||
):
|
||||
def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]):
|
||||
"""
|
||||
Calculate the expressions for one instrument, return a df result.
|
||||
Calculate the expressions for **one** instrument, return a df result.
|
||||
If the expression has been calculated before, load from cache.
|
||||
|
||||
return value: A data frame with index 'datetime' and other data columns.
|
||||
@@ -566,8 +576,10 @@ class DatasetProvider(abc.ABC):
|
||||
obj[field] = ExpressionD.expression(inst, field, start_time, end_time, freq)
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
|
||||
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is not None:
|
||||
@@ -583,15 +595,16 @@ class DatasetProvider(abc.ABC):
|
||||
return data
|
||||
|
||||
|
||||
class LocalCalendarProvider(CalendarProvider):
|
||||
class LocalCalendarProvider(CalendarProvider, ProviderBackendMixin):
|
||||
"""Local calendar data provider class
|
||||
|
||||
Provide calendar data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalCalendarProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
def __init__(self, remote=False, backend={}):
|
||||
super().__init__()
|
||||
self.remote = remote
|
||||
self.backend = backend
|
||||
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
@@ -623,12 +636,16 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
return [pd.Timestamp(x) for x in backend_obj]
|
||||
|
||||
|
||||
class LocalInstrumentProvider(InstrumentProvider):
|
||||
class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
|
||||
"""Local instrument data provider class
|
||||
|
||||
Provide instrument data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self, backend={}) -> None:
|
||||
super().__init__()
|
||||
self.backend = backend
|
||||
|
||||
def _load_instruments(self, market, freq):
|
||||
return self.backend_obj(market=market, freq=freq).data
|
||||
|
||||
@@ -667,15 +684,16 @@ class LocalInstrumentProvider(InstrumentProvider):
|
||||
return _instruments_filtered
|
||||
|
||||
|
||||
class LocalFeatureProvider(FeatureProvider):
|
||||
class LocalFeatureProvider(FeatureProvider, ProviderBackendMixin):
|
||||
"""Local feature data provider class
|
||||
|
||||
Provide feature data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LocalFeatureProvider, self).__init__(**kwargs)
|
||||
self.remote = kwargs.get("remote", False)
|
||||
def __init__(self, remote=False, backend={}):
|
||||
super().__init__()
|
||||
self.remote = remote
|
||||
self.backend = backend
|
||||
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
# validate
|
||||
@@ -684,20 +702,72 @@ class LocalFeatureProvider(FeatureProvider):
|
||||
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
|
||||
|
||||
|
||||
class ArcticFeatureProvider(FeatureProvider):
|
||||
def __init__(
|
||||
self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")]
|
||||
):
|
||||
super().__init__()
|
||||
self.uri = uri
|
||||
# TODO:
|
||||
# retry connecting if error occurs
|
||||
# does it real matters?
|
||||
self.retry_time = retry_time
|
||||
# NOTE: this is especially important for TResample operator
|
||||
self.market_transaction_time_list = market_transaction_time_list
|
||||
|
||||
def feature(self, instrument, field, start_index, end_index, freq):
|
||||
field = str(field)[1:]
|
||||
with pymongo.MongoClient(self.uri) as client:
|
||||
# TODO: this will result in frequently connecting the server and performance issue
|
||||
arctic = Arctic(client)
|
||||
|
||||
if freq not in arctic.list_libraries():
|
||||
raise ValueError("lib {} not in arctic".format(freq))
|
||||
|
||||
if instrument not in arctic[freq].list_symbols():
|
||||
# instruments does not exist
|
||||
return pd.Series()
|
||||
else:
|
||||
df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index))
|
||||
s = df[field]
|
||||
|
||||
if not s.empty:
|
||||
s = pd.concat(
|
||||
[
|
||||
s.between_time(time_tuple[0], time_tuple[1])
|
||||
for time_tuple in self.market_transaction_time_list
|
||||
]
|
||||
)
|
||||
return s
|
||||
|
||||
|
||||
class LocalExpressionProvider(ExpressionProvider):
|
||||
"""Local expression data provider class
|
||||
|
||||
Provide expression data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self, time2idx=True):
|
||||
super().__init__()
|
||||
self.time2idx = time2idx
|
||||
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
start_time = time_to_slc_point(start_time)
|
||||
end_time = time_to_slc_point(end_time)
|
||||
|
||||
# Two kinds of queries are supported
|
||||
# - Index-based expression: this may save a lot of memory because the datetime index is not saved on the disk
|
||||
# - Data with datetime index expression: this will make it more convenient to integrating with some existing databases
|
||||
if self.time2idx:
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
query_start, query_end = max(0, start_index - lft_etd), end_index + rght_etd
|
||||
else:
|
||||
start_index, end_index = query_start, query_end = start_time, end_time
|
||||
|
||||
try:
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
series = expression.load(instrument, query_start, query_end, freq)
|
||||
except Exception as e:
|
||||
get_module_logger("data").debug(
|
||||
f"Loading expression error: "
|
||||
@@ -726,8 +796,18 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
Provide dataset data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, align_time: bool = True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
align_time : bool
|
||||
Will we align the time to calendar
|
||||
the frequency is flexible in some dataset and can't be aligned.
|
||||
For the data with fixed frequency with a shared calendar, the align data to the calendar will provides following benefits
|
||||
- Align queries to the same parameters, so the cache can be shared.
|
||||
"""
|
||||
super().__init__()
|
||||
self.align_time = align_time
|
||||
|
||||
def dataset(
|
||||
self,
|
||||
@@ -740,14 +820,16 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
):
|
||||
instruments_d = self.get_instruments_d(instruments, freq)
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
if len(cal) == 0:
|
||||
return pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
|
||||
)
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
if self.align_time:
|
||||
# NOTE: if the frequency is a fixed value.
|
||||
# align the data to fixed calendar point
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
if len(cal) == 0:
|
||||
return pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
|
||||
)
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
data = self.dataset_processor(
|
||||
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
|
||||
)
|
||||
|
||||
@@ -721,9 +721,9 @@ class Rolling(ExpressionOps):
|
||||
# NOTE: remove all null check,
|
||||
# now it's user's responsibility to decide whether use features in null days
|
||||
# isnull = series.isnull() # NOTE: isnull = NaN, inf is not null
|
||||
if self.N == 0:
|
||||
if isinstance(self.N, int) and self.N == 0:
|
||||
series = getattr(series.expanding(min_periods=1), self.func)()
|
||||
elif 0 < self.N < 1:
|
||||
elif isinstance(self.N, int) and 0 < self.N < 1:
|
||||
series = series.ewm(alpha=self.N, min_periods=1).mean()
|
||||
else:
|
||||
series = getattr(series.rolling(self.N, min_periods=1), self.func)()
|
||||
@@ -1380,6 +1380,7 @@ class PairRolling(ExpressionOps):
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, N, func):
|
||||
# TODO: in what case will a const be passed into `__init__` as `feature_left` or `feature_right`
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.N = N
|
||||
@@ -1389,8 +1390,19 @@ class PairRolling(ExpressionOps):
|
||||
return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
assert any(
|
||||
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
|
||||
), "at least one of two inputs is Expression instance"
|
||||
|
||||
if isinstance(self.feature_left, Expression):
|
||||
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_left = self.feature_left # numeric value
|
||||
if isinstance(self.feature_right, Expression):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
|
||||
if self.N == 0:
|
||||
series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)
|
||||
else:
|
||||
@@ -1400,21 +1412,33 @@ class PairRolling(ExpressionOps):
|
||||
def get_longest_back_rolling(self):
|
||||
if self.N == 0:
|
||||
return np.inf
|
||||
return (
|
||||
max(self.feature_left.get_longest_back_rolling(), self.feature_right.get_longest_back_rolling())
|
||||
+ self.N
|
||||
- 1
|
||||
)
|
||||
if isinstance(self.feature_left, Expression):
|
||||
left_br = self.feature_left.get_longest_back_rolling()
|
||||
else:
|
||||
left_br = 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
right_br = self.feature_right.get_longest_back_rolling()
|
||||
else:
|
||||
right_br = 0
|
||||
return max(left_br, right_br)
|
||||
|
||||
def get_extended_window_size(self):
|
||||
ll, lr = self.feature_left.get_extended_window_size()
|
||||
rl, rr = self.feature_right.get_extended_window_size()
|
||||
if self.N == 0:
|
||||
get_module_logger(self.__class__.__name__).warning(
|
||||
"The PairRolling(ATTR, 0) will not be accurately calculated"
|
||||
)
|
||||
return -np.inf, max(lr, rr)
|
||||
else:
|
||||
if isinstance(self.feature_left, Expression):
|
||||
ll, lr = self.feature_left.get_extended_window_size()
|
||||
else:
|
||||
ll, lr = 0, 0
|
||||
|
||||
if isinstance(self.feature_right, Expression):
|
||||
rl, rr = self.feature_right.get_extended_window_size()
|
||||
else:
|
||||
rl, rr = 0, 0
|
||||
return max(ll, rl) + self.N - 1, max(lr, rr)
|
||||
|
||||
|
||||
@@ -1474,7 +1498,50 @@ class Cov(PairRolling):
|
||||
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
|
||||
|
||||
|
||||
#################### Operator which only support data with time index ####################
|
||||
# Convention
|
||||
# - The name of the operators in this section will start with "T"
|
||||
|
||||
|
||||
class TResample(ElemOperator):
|
||||
def __init__(self, feature, freq, func):
|
||||
"""
|
||||
Resampling the data to target frequency.
|
||||
The resample function of pandas is used.
|
||||
- the timestamp will be at the start of the time span after resample.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Expression
|
||||
An expression for calculating the feature
|
||||
freq : str
|
||||
It will be passed into the resample method for resampling basedn on given frequency
|
||||
func : method
|
||||
The method to get the resampled values
|
||||
Some expression are high frequently used
|
||||
"""
|
||||
self.feature = feature
|
||||
self.freq = freq
|
||||
self.func = func
|
||||
|
||||
def __str__(self):
|
||||
return "{}({},{})".format(type(self).__name__, self.feature, self.freq)
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
series = self.feature.load(instrument, start_index, end_index, freq)
|
||||
|
||||
if series.empty:
|
||||
return series
|
||||
else:
|
||||
if self.func == "sum":
|
||||
return getattr(series.resample(self.freq), self.func)(min_count=1)
|
||||
else:
|
||||
return getattr(series.resample(self.freq), self.func)()
|
||||
|
||||
|
||||
TOpsList = [TResample]
|
||||
OpsList = [
|
||||
Rolling,
|
||||
Ref,
|
||||
Max,
|
||||
Min,
|
||||
@@ -1521,7 +1588,7 @@ OpsList = [
|
||||
IdxMin,
|
||||
If,
|
||||
Feature,
|
||||
]
|
||||
] + [TResample]
|
||||
|
||||
|
||||
class OpsWrapper:
|
||||
|
||||
Reference in New Issue
Block a user