1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Supporting Arctic Backend Provider & Orderbook, Tick Data Example (#744)

* change weight_decay & batchsize

* del weight_decay

* big weight_decay

* mid weight_decay

* small layer

* 2 layer

* full layer

* no weight decay

* divide into two data source

* change parse field

* delete some debug

* add Toperator

* new format of arctic

* fix cache bug to arctic read

* fix connection problem

* add some operator

* final version for arcitc

* clear HZ cache

* remove not used function

* add topswrappers

* successfully import data and run first test

* A simpler version to support arctic

* Successfully run all high-freq expressions

* Black format and fix add docs

* Add docs for download and test data

* update scripts and docs

* Add docs

* fix bug

* Refine docs

* fix test bug

* fix CI error

* clean code

Co-authored-by: bxdd <bxddream@gmail.com>
Co-authored-by: wangwenxi.handsome <wangwenxi.handsome@gmail.com>
Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
luocy16
2022-01-18 09:13:11 +08:00
committed by GitHub
parent 7f274b1e4e
commit 2bb8a4ce0e
16 changed files with 923 additions and 90 deletions

View File

@@ -15,6 +15,7 @@ from .data import (
LocalCalendarProvider,
LocalInstrumentProvider,
LocalFeatureProvider,
ArcticFeatureProvider,
LocalExpressionProvider,
LocalDatasetProvider,
ClientCalendarProvider,

View File

@@ -150,7 +150,7 @@ class Expression(abc.ABC):
args = str(self), instrument, start_index, end_index, freq
if args in H["f"]:
return H["f"][args]
if start_index is None or end_index is None or start_index > end_index:
if start_index is not None and end_index is not None and start_index > end_index:
raise ValueError("Invalid index range: {} {}".format(start_index, end_index))
try:
series = self._load_internal(instrument, start_index, end_index, freq)

View File

@@ -147,6 +147,7 @@ class MemCache:
"""
size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit
limit_type = C.mem_cache_limit_type if limit_type is None else limit_type
if limit_type == "length":
klass = MemCacheLengthUnit
@@ -1198,7 +1199,4 @@ class MemoryCalendarCache(CalendarCache):
return result
# MemCache sizeof
HZ = MemCache(C.mem_cache_space_limit, limit_type="sizeof")
# MemCache length
H = MemCache(limit_type="length")
H = MemCache()

View File

@@ -5,8 +5,10 @@
from __future__ import division
from __future__ import print_function
import os
import re
import abc
import time
import copy
import queue
import bisect
@@ -15,9 +17,11 @@ import pandas as pd
from multiprocessing import Pool
from typing import Iterable, Union
from typing import List, Union
from arctic import Arctic
# For supporting multiprocessing in outer code, joblib is used
from joblib import delayed
import pymongo
from .cache import H
from ..config import C
@@ -38,11 +42,17 @@ from ..utils import (
normalize_cache_fields,
code_to_fname,
set_log_with_config,
time_to_slc_point,
)
from ..utils.paral import ParallelExt
class ProviderBackendMixin:
"""
This helper class tries to make the provider based on storage backend more convenient
It is not necessary to inherent this class if that provider don't rely on the backend storage
"""
def get_default_backend(self):
backend = {}
provider_name: str = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[-2]
@@ -59,15 +69,12 @@ class ProviderBackendMixin:
return init_instance_by_config(backend)
class CalendarProvider(abc.ABC, ProviderBackendMixin):
class CalendarProvider(abc.ABC):
"""Calendar provider base class
Provide calendar data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
"""Get calendar of certain market in given time range.
@@ -194,15 +201,12 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method")
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
class InstrumentProvider(abc.ABC):
"""Instrument provider base class
Provide instrument data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@staticmethod
def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None):
"""Get the general config dictionary for a base market adding several dynamic filters.
@@ -304,15 +308,12 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
raise ValueError(f"Unknown instrument type {inst}")
class FeatureProvider(abc.ABC, ProviderBackendMixin):
class FeatureProvider(abc.ABC):
"""Feature provider class
Provide feature data.
"""
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
@abc.abstractmethod
def feature(self, instrument, field, start_time, end_time, freq):
"""Get feature data.
@@ -365,9 +366,13 @@ class ExpressionProvider(abc.ABC):
return expression
@abc.abstractmethod
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
def expression(self, instrument, field, start_time=None, end_time=None, freq="day") -> pd.Series:
"""Get Expression data.
The responsibility of `expression`
- parse the `field` and `load` the according data.
- When loading the data, it should handle the time dependency of the data. `get_expression_instance` is commonly used in this method
Parameters
----------
instrument : str
@@ -385,6 +390,11 @@ class ExpressionProvider(abc.ABC):
-------
pd.Series
data of a certain expression
The data has two types of format
1) expression with datetime index
2) expression with integer index
- because the datetime is not as good as
"""
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
@@ -500,7 +510,7 @@ class DatasetProvider(abc.ABC):
"""
normalize_column_names = normalize_cache_fields(column_names)
# One process for one task, so that the memory will be freed quicker.
workers = max(min(C.kernels, len(instruments_d)), 1)
workers = max(min(C.get_kernels(freq), len(instruments_d)), 1)
# create iterator
if isinstance(instruments_d, dict):
@@ -513,7 +523,7 @@ class DatasetProvider(abc.ABC):
for inst, spans in it:
inst_l.append(inst)
task_l.append(
delayed(DatasetProvider.expression_calculator)(
delayed(DatasetProvider.inst_calculator)(
inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors
)
)
@@ -536,17 +546,17 @@ class DatasetProvider(abc.ABC):
data = DiskDatasetCache.cache_to_origin_data(data, column_names)
else:
data = pd.DataFrame(
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")),
columns=column_names,
dtype=np.float32,
)
return data
@staticmethod
def expression_calculator(
inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]
):
def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]):
"""
Calculate the expressions for one instrument, return a df result.
Calculate the expressions for **one** instrument, return a df result.
If the expression has been calculated before, load from cache.
return value: A data frame with index 'datetime' and other data columns.
@@ -566,8 +576,10 @@ class DatasetProvider(abc.ABC):
obj[field] = ExpressionD.expression(inst, field, start_time, end_time, freq)
data = pd.DataFrame(obj)
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
if not data.empty and not np.issubdtype(data.index.dtype, np.dtype("M")):
# If the underlaying provides the data not in datatime formmat, we'll convert it into datetime format
_calendar = Cal.calendar(freq=freq)
data.index = _calendar[data.index.values.astype(int)]
data.index.names = ["datetime"]
if spans is not None:
@@ -583,15 +595,16 @@ class DatasetProvider(abc.ABC):
return data
class LocalCalendarProvider(CalendarProvider):
class LocalCalendarProvider(CalendarProvider, ProviderBackendMixin):
"""Local calendar data provider class
Provide calendar data from local data source.
"""
def __init__(self, **kwargs):
super(LocalCalendarProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
def __init__(self, remote=False, backend={}):
super().__init__()
self.remote = remote
self.backend = backend
def load_calendar(self, freq, future):
"""Load original calendar timestamp from file.
@@ -623,12 +636,16 @@ class LocalCalendarProvider(CalendarProvider):
return [pd.Timestamp(x) for x in backend_obj]
class LocalInstrumentProvider(InstrumentProvider):
class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
"""Local instrument data provider class
Provide instrument data from local data source.
"""
def __init__(self, backend={}) -> None:
super().__init__()
self.backend = backend
def _load_instruments(self, market, freq):
return self.backend_obj(market=market, freq=freq).data
@@ -667,15 +684,16 @@ class LocalInstrumentProvider(InstrumentProvider):
return _instruments_filtered
class LocalFeatureProvider(FeatureProvider):
class LocalFeatureProvider(FeatureProvider, ProviderBackendMixin):
"""Local feature data provider class
Provide feature data from local data source.
"""
def __init__(self, **kwargs):
super(LocalFeatureProvider, self).__init__(**kwargs)
self.remote = kwargs.get("remote", False)
def __init__(self, remote=False, backend={}):
super().__init__()
self.remote = remote
self.backend = backend
def feature(self, instrument, field, start_index, end_index, freq):
# validate
@@ -684,20 +702,72 @@ class LocalFeatureProvider(FeatureProvider):
return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1]
class ArcticFeatureProvider(FeatureProvider):
def __init__(
self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")]
):
super().__init__()
self.uri = uri
# TODO:
# retry connecting if error occurs
# does it real matters?
self.retry_time = retry_time
# NOTE: this is especially important for TResample operator
self.market_transaction_time_list = market_transaction_time_list
def feature(self, instrument, field, start_index, end_index, freq):
field = str(field)[1:]
with pymongo.MongoClient(self.uri) as client:
# TODO: this will result in frequently connecting the server and performance issue
arctic = Arctic(client)
if freq not in arctic.list_libraries():
raise ValueError("lib {} not in arctic".format(freq))
if instrument not in arctic[freq].list_symbols():
# instruments does not exist
return pd.Series()
else:
df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index))
s = df[field]
if not s.empty:
s = pd.concat(
[
s.between_time(time_tuple[0], time_tuple[1])
for time_tuple in self.market_transaction_time_list
]
)
return s
class LocalExpressionProvider(ExpressionProvider):
"""Local expression data provider class
Provide expression data from local data source.
"""
def __init__(self, time2idx=True):
super().__init__()
self.time2idx = time2idx
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
expression = self.get_expression_instance(field)
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
lft_etd, rght_etd = expression.get_extended_window_size()
start_time = time_to_slc_point(start_time)
end_time = time_to_slc_point(end_time)
# Two kinds of queries are supported
# - Index-based expression: this may save a lot of memory because the datetime index is not saved on the disk
# - Data with datetime index expression: this will make it more convenient to integrating with some existing databases
if self.time2idx:
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
lft_etd, rght_etd = expression.get_extended_window_size()
query_start, query_end = max(0, start_index - lft_etd), end_index + rght_etd
else:
start_index, end_index = query_start, query_end = start_time, end_time
try:
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
series = expression.load(instrument, query_start, query_end, freq)
except Exception as e:
get_module_logger("data").debug(
f"Loading expression error: "
@@ -726,8 +796,18 @@ class LocalDatasetProvider(DatasetProvider):
Provide dataset data from local data source.
"""
def __init__(self):
pass
def __init__(self, align_time: bool = True):
"""
Parameters
----------
align_time : bool
Will we align the time to calendar
the frequency is flexible in some dataset and can't be aligned.
For the data with fixed frequency with a shared calendar, the align data to the calendar will provides following benefits
- Align queries to the same parameters, so the cache can be shared.
"""
super().__init__()
self.align_time = align_time
def dataset(
self,
@@ -740,14 +820,16 @@ class LocalDatasetProvider(DatasetProvider):
):
instruments_d = self.get_instruments_d(instruments, freq)
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
return pd.DataFrame(
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
)
start_time = cal[0]
end_time = cal[-1]
if self.align_time:
# NOTE: if the frequency is a fixed value.
# align the data to fixed calendar point
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
return pd.DataFrame(
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
)
start_time = cal[0]
end_time = cal[-1]
data = self.dataset_processor(
instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors
)

View File

@@ -721,9 +721,9 @@ class Rolling(ExpressionOps):
# NOTE: remove all null check,
# now it's user's responsibility to decide whether use features in null days
# isnull = series.isnull() # NOTE: isnull = NaN, inf is not null
if self.N == 0:
if isinstance(self.N, int) and self.N == 0:
series = getattr(series.expanding(min_periods=1), self.func)()
elif 0 < self.N < 1:
elif isinstance(self.N, int) and 0 < self.N < 1:
series = series.ewm(alpha=self.N, min_periods=1).mean()
else:
series = getattr(series.rolling(self.N, min_periods=1), self.func)()
@@ -1380,6 +1380,7 @@ class PairRolling(ExpressionOps):
"""
def __init__(self, feature_left, feature_right, N, func):
# TODO: in what case will a const be passed into `__init__` as `feature_left` or `feature_right`
self.feature_left = feature_left
self.feature_right = feature_right
self.N = N
@@ -1389,8 +1390,19 @@ class PairRolling(ExpressionOps):
return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N)
def _load_internal(self, instrument, start_index, end_index, freq):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
assert any(
[isinstance(self.feature_left, Expression), self.feature_right, Expression]
), "at least one of two inputs is Expression instance"
if isinstance(self.feature_left, Expression):
series_left = self.feature_left.load(instrument, start_index, end_index, freq)
else:
series_left = self.feature_left # numeric value
if isinstance(self.feature_right, Expression):
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
else:
series_right = self.feature_right
if self.N == 0:
series = getattr(series_left.expanding(min_periods=1), self.func)(series_right)
else:
@@ -1400,21 +1412,33 @@ class PairRolling(ExpressionOps):
def get_longest_back_rolling(self):
if self.N == 0:
return np.inf
return (
max(self.feature_left.get_longest_back_rolling(), self.feature_right.get_longest_back_rolling())
+ self.N
- 1
)
if isinstance(self.feature_left, Expression):
left_br = self.feature_left.get_longest_back_rolling()
else:
left_br = 0
if isinstance(self.feature_right, Expression):
right_br = self.feature_right.get_longest_back_rolling()
else:
right_br = 0
return max(left_br, right_br)
def get_extended_window_size(self):
ll, lr = self.feature_left.get_extended_window_size()
rl, rr = self.feature_right.get_extended_window_size()
if self.N == 0:
get_module_logger(self.__class__.__name__).warning(
"The PairRolling(ATTR, 0) will not be accurately calculated"
)
return -np.inf, max(lr, rr)
else:
if isinstance(self.feature_left, Expression):
ll, lr = self.feature_left.get_extended_window_size()
else:
ll, lr = 0, 0
if isinstance(self.feature_right, Expression):
rl, rr = self.feature_right.get_extended_window_size()
else:
rl, rr = 0, 0
return max(ll, rl) + self.N - 1, max(lr, rr)
@@ -1474,7 +1498,50 @@ class Cov(PairRolling):
super(Cov, self).__init__(feature_left, feature_right, N, "cov")
#################### Operator which only support data with time index ####################
# Convention
# - The name of the operators in this section will start with "T"
class TResample(ElemOperator):
def __init__(self, feature, freq, func):
"""
Resampling the data to target frequency.
The resample function of pandas is used.
- the timestamp will be at the start of the time span after resample.
Parameters
----------
feature : Expression
An expression for calculating the feature
freq : str
It will be passed into the resample method for resampling basedn on given frequency
func : method
The method to get the resampled values
Some expression are high frequently used
"""
self.feature = feature
self.freq = freq
self.func = func
def __str__(self):
return "{}({},{})".format(type(self).__name__, self.feature, self.freq)
def _load_internal(self, instrument, start_index, end_index, freq):
series = self.feature.load(instrument, start_index, end_index, freq)
if series.empty:
return series
else:
if self.func == "sum":
return getattr(series.resample(self.freq), self.func)(min_count=1)
else:
return getattr(series.resample(self.freq), self.func)()
TOpsList = [TResample]
OpsList = [
Rolling,
Ref,
Max,
Min,
@@ -1521,7 +1588,7 @@ OpsList = [
IdxMin,
If,
Feature,
]
] + [TResample]
class OpsWrapper: