mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
init version of online serving and rolling
This commit is contained in:
0
qlib/model/ens/__init__.py
Normal file
0
qlib/model/ens/__init__.py
Normal file
13
qlib/workflow/task/__init__.py
Normal file
13
qlib/workflow/task/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Task related workflow is implemented in this folder
|
||||
|
||||
A typical task workflow
|
||||
|
||||
| Step | Description |
|
||||
|-----------------------+------------------------------------------------|
|
||||
| TaskGen | Generating tasks. |
|
||||
| TaskManager(optional) | Manage generated tasks |
|
||||
| run task | retrive tasks from TaskManager and run tasks. |
|
||||
"""
|
||||
52
qlib/workflow/task/collect.py
Normal file
52
qlib/workflow/task/collect.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from qlib.workflow import R
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class RollingEnsemble:
|
||||
'''
|
||||
Rolling Models Ensemble based on (R)ecord
|
||||
|
||||
This shares nothing with Ensemble
|
||||
'''
|
||||
# TODO: 这边还可以加加速
|
||||
def __init__(self, get_key_func, flt_func=None):
|
||||
self.get_key_func = get_key_func
|
||||
self.flt_func = flt_func
|
||||
|
||||
def __call__(self, exp_name) -> Union[pd.Series, dict]:
|
||||
# TODO;
|
||||
# Should we split the scripts into several sub functions?
|
||||
exp = R.get_exp(experiment_name=exp_name)
|
||||
|
||||
# filter records
|
||||
recs = exp.list_recorders()
|
||||
|
||||
recs_flt = {}
|
||||
for rid, rec in tqdm(recs.items(), desc="Loading data"):
|
||||
# rec = exp.get_recorder(recorder_id=rid)
|
||||
params = rec.load_object("param")
|
||||
if rec.status == rec.STATUS_FI:
|
||||
if self.flt_func is None or self.flt_func(params):
|
||||
rec.params = params
|
||||
recs_flt[rid] = rec
|
||||
|
||||
# group
|
||||
recs_group = {}
|
||||
for _, rec in recs_flt.items():
|
||||
params = rec.params
|
||||
group_key = self.get_key_func(params)
|
||||
recs_group.setdefault(group_key, []).append(rec)
|
||||
|
||||
# reduce group
|
||||
reduce_group = {}
|
||||
for k, rec_l in recs_group.items():
|
||||
pred_l = []
|
||||
for rec in rec_l:
|
||||
pred_l.append(rec.load_object('pred.pkl').iloc[:, 0])
|
||||
pred = pd.concat(pred_l).sort_index()
|
||||
reduce_group[k] = pred
|
||||
|
||||
return reduce_group
|
||||
|
||||
133
qlib/workflow/task/gen.py
Normal file
133
qlib/workflow/task/gen.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
'''
|
||||
this is a task generator
|
||||
'''
|
||||
import abc
|
||||
import copy
|
||||
import typing
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> typing.List[dict]:
|
||||
"""
|
||||
generate
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args, kwargs:
|
||||
The info for generating tasks
|
||||
Example 1):
|
||||
input: a specific task template
|
||||
output: rolling version of the tasks
|
||||
Example 2):
|
||||
input: a specific task template
|
||||
output: a set of tasks with different losses
|
||||
|
||||
Returns
|
||||
-------
|
||||
typing.List[dict]:
|
||||
A list of tasks
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX
|
||||
ROLL_SD = TimeAdjuster.SHIFT_SD
|
||||
|
||||
def __init__(self, step: int = 40, rtype: str = ROLL_EX):
|
||||
"""
|
||||
Generate tasks for rolling
|
||||
|
||||
Parameters
|
||||
----------
|
||||
step : int
|
||||
step to rolling
|
||||
rtype : str
|
||||
rolling type (expanding, rolling)
|
||||
"""
|
||||
self.step = step
|
||||
self.rtype = rtype
|
||||
self.ta = TimeAdjuster(future=True) # 为了保证test最后的日期不是None, 所以这边要改一改
|
||||
|
||||
self.test_key = 'test'
|
||||
self.train_key = 'train'
|
||||
|
||||
def __call__(self, task: dict):
|
||||
"""
|
||||
Converting the task into a rolling task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
A dict describing a task. For example.
|
||||
|
||||
DEFAULT_TASK = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
"""
|
||||
res = []
|
||||
|
||||
prev_seg = None
|
||||
test_end = None
|
||||
while True:
|
||||
t = copy.deepcopy(task)
|
||||
|
||||
# calculate segments
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end porint
|
||||
segments = copy.deepcopy(self.ta.align_seg(t['dataset']['kwargs']['segments']))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and the init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
else:
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# 决定怎么shift
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# 整段数据做shift
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
|
||||
t['dataset']['kwargs']['segments'] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
|
||||
290
qlib/workflow/task/manage.py
Normal file
290
qlib/workflow/task/manage.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
A task consists of 2 parts
|
||||
- tasks description: the desc will define the task
|
||||
- tasks status: the status of the task
|
||||
- tasks result information : A user can get the task with the task description and task result.
|
||||
|
||||
"""
|
||||
from bson.binary import Binary
|
||||
import pickle
|
||||
from pymongo.errors import InvalidDocument
|
||||
from fire import Fire
|
||||
from bson.objectid import ObjectId
|
||||
from contextlib import contextmanager
|
||||
from loguru import logger
|
||||
from tqdm.cli import tqdm
|
||||
import time
|
||||
import concurrent
|
||||
import pymongo
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""TaskManager
|
||||
here is the what will a task looks like
|
||||
{
|
||||
'def': pickle serialized task definition. using pickle will make it easier
|
||||
'filter': json-like data. This is for filtering the tasks.
|
||||
'status': 'waiting' | 'running' | 'done'
|
||||
'res': pickle serialized task result,
|
||||
}
|
||||
|
||||
The tasks manager assume that you will only update the tasks you fetched.
|
||||
The mongo fetch one and update will make it date updating secure.
|
||||
|
||||
Usage Examples from the CLI.
|
||||
python -m blocks.tasks.__init__ task_stat --task_pool meta_task_rule
|
||||
|
||||
|
||||
NOTE:
|
||||
- 假设: 存储在db里面的都是encode过的, 拿出来的都是decode过的
|
||||
"""
|
||||
STATUS_WAITING = 'waiting'
|
||||
STATUS_RUNNING = 'running'
|
||||
STATUS_DONE = 'done'
|
||||
STATUS_PART_DONE = 'part_done'
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ['def', 'res']
|
||||
|
||||
def __init__(self, task_pool=None):
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = task_pool
|
||||
|
||||
def list(self):
|
||||
return self.mdb.list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = Binary(pickle.dumps(task[k]))
|
||||
return task
|
||||
|
||||
def _decode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = pickle.loads(task[k])
|
||||
return task
|
||||
|
||||
def _get_task_pool(self, task_pool=None):
|
||||
if task_pool is None:
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
raise ValueError('You must specify a task pool.')
|
||||
if isinstance(task_pool, str):
|
||||
return getattr(self.mdb, task_pool)
|
||||
return task_pool
|
||||
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task, task_pool=None):
|
||||
# 这里的假设是从接口拿出来的都是decode过的,在接口内部的都是 encode过的
|
||||
new_task = self._encode_task(new_task)
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
query = {'_id': ObjectId(task['_id'])}
|
||||
try:
|
||||
task_pool.replace_one(query, new_task)
|
||||
except InvalidDocument:
|
||||
task['filter'] = self._dict_to_str(task['filter'])
|
||||
task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
try:
|
||||
task_pool.insert_one(task)
|
||||
except InvalidDocument:
|
||||
task['filter'] = self._dict_to_str(task['filter'])
|
||||
task_pool.insert_one(task)
|
||||
|
||||
def insert_task_def(self, task_def, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
task = self._encode_task({
|
||||
'def': task_def,
|
||||
'filter': task_def, # FIXME: catch the raised error
|
||||
'status': self.STATUS_WAITING,
|
||||
})
|
||||
self.insert_task(task, task_pool)
|
||||
|
||||
def create_task(self, task_def_l, task_pool=None, dry_run=False, print_nt=False):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
new_tasks = []
|
||||
for t in task_def_l:
|
||||
try:
|
||||
r = task_pool.find_one({'filter': t})
|
||||
except InvalidDocument:
|
||||
r = task_pool.find_one({'filter': self._dict_to_str(t)})
|
||||
if r is None:
|
||||
new_tasks.append(t)
|
||||
print("Total Tasks, New Tasks:", len(task_def_l), len(new_tasks))
|
||||
|
||||
if print_nt: # print new task
|
||||
for t in new_tasks:
|
||||
print(t)
|
||||
|
||||
if dry_run:
|
||||
return
|
||||
|
||||
for t in new_tasks:
|
||||
self.insert_task_def(t, task_pool)
|
||||
|
||||
def fetch_task(self, query={}, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
query.update({'status': self.STATUS_WAITING})
|
||||
task = task_pool.find_one_and_update(query, {'$set': {
|
||||
'status': self.STATUS_RUNNING
|
||||
}},
|
||||
sort=[('priority', pymongo.DESCENDING)])
|
||||
# 这里我的 priority 必须是 高数优先级更高,因为 null会被在 ASCENDING时被排在最前面
|
||||
if task is None:
|
||||
return None
|
||||
task['status'] = self.STATUS_RUNNING
|
||||
return self._decode_task(task)
|
||||
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}, task_pool=None):
|
||||
task = self.fetch_task(query=query, task_pool=task_pool)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
if task is not None:
|
||||
logger.info('Returning task before raising error')
|
||||
self.return_task(task)
|
||||
logger.info('Task returned')
|
||||
raise
|
||||
|
||||
def task_fetcher_iter(self, query={}, task_pool=None):
|
||||
while True:
|
||||
with self.safe_fetch_task(query=query, task_pool=task_pool) as task:
|
||||
if task is None:
|
||||
break
|
||||
yield task
|
||||
|
||||
def query(self, query={}, decode=True, task_pool=None):
|
||||
"""query
|
||||
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
|
||||
|
||||
:param query:
|
||||
:param decode:
|
||||
:param task_pool:
|
||||
"""
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
for t in task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def commit_task_res(self, task, res, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
task_pool.update_one({"_id": task['_id']}, {'$set': {'status': status, 'res': Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=None, task_pool=None):
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {'$set': {'status': status}}
|
||||
task_pool.update_one({"_id": task['_id']}, update_dict)
|
||||
|
||||
def remove(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
tasks = self.query(task_pool=task_pool, query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
status_stat[t['status']] = status_stat.get(t['status'], 0) + 1
|
||||
return status_stat
|
||||
|
||||
def reset_waiting(self, query={}, task_pool=None):
|
||||
query = query.copy()
|
||||
# default query
|
||||
if 'status' not in query:
|
||||
query['status'] = self.STATUS_RUNNING
|
||||
return self.reset_status(query=query, status=self.STATUS_WAITING, task_pool=task_pool)
|
||||
|
||||
def reset_status(self, query, status, task_pool=None):
|
||||
query = query.copy()
|
||||
task_pool = self._get_task_pool(task_pool)
|
||||
if '_id' in query:
|
||||
query['_id'] = ObjectId(query['_id'])
|
||||
print(task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}, task_pool=None):
|
||||
task_stat = self.task_stat(query, task_pool)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
undone_n = self._get_undone_n(self.task_stat(query, task_pool))
|
||||
pbar.update(last_undone_n - undone_n)
|
||||
last_undone_n = undone_n
|
||||
if undone_n == 0:
|
||||
break
|
||||
|
||||
def __str__(self):
|
||||
return f"TaskManager({self.task_pool})"
|
||||
|
||||
|
||||
def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
"""run_task.
|
||||
While task pool is not empty, use task_func to fetch and run tasks in task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
|
||||
the function to run the task
|
||||
task_pool :
|
||||
The name of the task pool
|
||||
force_release :
|
||||
will the program force to release the resource
|
||||
args :
|
||||
args
|
||||
kwargs :
|
||||
kwargs
|
||||
"""
|
||||
tm = TaskManager(task_pool)
|
||||
|
||||
ever_run = False
|
||||
|
||||
while True:
|
||||
with tm.safe_fetch_task() as task:
|
||||
if task is None:
|
||||
break
|
||||
logger.info(task['def'])
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, task['def'], *args, **kwargs).result()
|
||||
else:
|
||||
res = task_func(task['def'], *args, **kwargs)
|
||||
tm.commit_task_res(task, res)
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Fire(TaskManager)
|
||||
134
qlib/workflow/task/utils.py
Normal file
134
qlib/workflow/task/utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import bisect
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
def get_mongodb():
|
||||
try:
|
||||
cfg = C['mongo']
|
||||
except KeyError:
|
||||
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
|
||||
raise
|
||||
|
||||
client = MongoClient(cfg['task_url'])
|
||||
return client.get_database(name=cfg['task_db_name'])
|
||||
|
||||
|
||||
class TimeAdjuster:
|
||||
'''找到合适的日期,然后adjust date'''
|
||||
def __init__(self, future=False):
|
||||
self.cals = D.calendar(future=future)
|
||||
|
||||
def get(self, idx: int):
|
||||
"""
|
||||
Get datetime by index
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : int
|
||||
index of the calendar
|
||||
"""
|
||||
if idx >= len(self.cals):
|
||||
return None
|
||||
return self.cals[idx]
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
Return return the max calendar date
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
def align_idx(self, time_point, tp_type="start"):
|
||||
time_point = pd.Timestamp(time_point)
|
||||
if tp_type == 'start':
|
||||
idx = bisect.bisect_left(self.cals, time_point)
|
||||
elif tp_type == 'end':
|
||||
idx = bisect.bisect_right(self.cals, time_point) - 1
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return idx
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
"""
|
||||
Align a timepoint to calendar weekdays
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_point :
|
||||
Time point
|
||||
tp_type : str
|
||||
time point type (`"start"`, `"end"`)
|
||||
"""
|
||||
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
|
||||
|
||||
def align_seg(self, segment):
|
||||
if isinstance(segment, dict):
|
||||
return {k: self.align_seg(seg) for k, seg in segment.items()}
|
||||
elif isinstance(segment, tuple):
|
||||
return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def truncate(self, segment, test_start, days: int):
|
||||
"""
|
||||
truncate the segment based on the test_start date
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segment :
|
||||
time segment
|
||||
days : int
|
||||
The trading days to be truncated
|
||||
大部分情况是因为这个时间段的数据(一般是特征)会用到 `days` 天的数据
|
||||
"""
|
||||
test_idx = self.align_idx(test_start)
|
||||
if isinstance(segment, tuple):
|
||||
new_seg = []
|
||||
for time_point in segment:
|
||||
tp_idx = min(self.align_idx(time_point), test_idx - days)
|
||||
assert (tp_idx > 0)
|
||||
new_seg.append(self.get(tp_idx))
|
||||
return tuple(new_seg)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
SHIFT_SD = "sliding"
|
||||
SHIFT_EX = "expanding"
|
||||
|
||||
def shift(self, seg, step: int, rtype=SHIFT_SD):
|
||||
"""
|
||||
shift the datatiem of segment
|
||||
|
||||
Parameters
|
||||
----------
|
||||
seg :
|
||||
datetime segment
|
||||
step : int
|
||||
rolling step
|
||||
rtype : str
|
||||
rolling type ("sliding" or "expanding")
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError:
|
||||
shift will raise error if the index(both start and end) is out of self.cal
|
||||
"""
|
||||
if isinstance(seg, tuple):
|
||||
start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
|
||||
if rtype == self.SHIFT_SD:
|
||||
start_idx += step
|
||||
end_idx += step
|
||||
elif rtype == self.SHIFT_EX:
|
||||
end_idx += step
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
if start_idx > len(self.cals):
|
||||
raise KeyError("The segment is out of valid calendar")
|
||||
return self.get(start_idx), self.get(end_idx)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
Reference in New Issue
Block a user