1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00
Files
qlib/qlib/workflow/task/utils.py
2021-03-26 04:20:25 +00:00

213 lines
6.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import bisect
import pandas as pd
from qlib.data import D
from qlib.workflow import R
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
from typing import Union
def get_mongodb():
    """
    Return the MongoDB database described by ``C["mongo"]``.

    The configuration entry must provide the server address and the name of
    the database, for example:

    .. code-block:: python

        C["mongo"] = {
            "task_url" : "mongodb://localhost:27017/",
            "task_db_name" : "rolling_db"
        }

    Returns
    -------
    The database object obtained from ``MongoClient(...).get_database(...)``.

    Raises
    ------
    KeyError
        If ``C["mongo"]`` is not configured; an error is logged before the
        exception is re-raised.
    """
    try:
        mongo_conf = C["mongo"]
    except KeyError:
        # Tell the user what is missing, then let the original error propagate.
        task_logger = get_module_logger("task")
        task_logger.error("Please configure `C['mongo']` before using TaskManager")
        raise
    return MongoClient(mongo_conf["task_url"]).get_database(name=mongo_conf["task_db_name"])
def list_recorders(experiment, rec_filter_func=None):
    """Collect the recorders of an experiment that pass an optional filter.

    Args:
        experiment (str or Experiment): an experiment name or an experiment instance.
            A name is resolved through ``R.exp_manager._get_or_create_exp``.
        rec_filter_func (Callable, optional): called with each recorder; return True
            to keep it. When None, every recorder is kept. Defaults to None.

    Returns:
        dict: a new dict {rid: recorder} holding only the retained recorders.
    """
    if isinstance(experiment, str):
        # Only the experiment object of the returned pair is needed here.
        experiment = R.exp_manager._get_or_create_exp(experiment_name=experiment)[0]
    return {
        rec_id: recorder
        for rec_id, recorder in experiment.list_recorders().items()
        if rec_filter_func is None or rec_filter_func(recorder)
    }
class TimeAdjuster:
    """
    Find appropriate dates in the trading calendar and adjust dates to it.

    Every lookup is performed against ``self.cals``, the calendar returned by
    ``D.calendar(future=future)``. The calendar is searched with ``bisect``,
    so it is assumed to be sorted in ascending order.
    """

    def __init__(self, future=False):
        """
        Parameters
        ----------
        future : bool
            forwarded to ``D.calendar`` as its ``future`` argument
        """
        self.cals = D.calendar(future=future)

    def get(self, idx: int):
        """
        Get datetime by index

        Parameters
        ----------
        idx : int
            index of the calendar

        Returns
        -------
        the calendar entry at ``idx``, or None when ``idx`` is at or beyond
        the end of the calendar
        """
        if idx >= len(self.cals):
            return None
        return self.cals[idx]

    def max(self):
        """
        (Deprecated)
        Return the max calendar datetime
        """
        return max(self.cals)

    def last_date(self) -> pd.Timestamp:
        """
        Return the last datetime in the calendar
        """
        return self.cals[-1]

    def align_idx(self, time_point, tp_type="start"):
        """
        align the index of time_point in the calendar

        Parameters
        ----------
        time_point
            anything accepted by ``pd.Timestamp``
        tp_type : str
            ``"start"``: index of the first calendar entry >= time_point
            (equals ``len(self.cals)`` when there is none);
            ``"end"``: index of the last calendar entry <= time_point
            (equals ``-1`` when there is none)

        Returns
        -------
        index : int

        Raises
        ------
        NotImplementedError
            if ``tp_type`` is neither ``"start"`` nor ``"end"``
        """
        time_point = pd.Timestamp(time_point)
        if tp_type == "start":
            idx = bisect.bisect_left(self.cals, time_point)
        elif tp_type == "end":
            idx = bisect.bisect_right(self.cals, time_point) - 1
        else:
            raise NotImplementedError("This type of input is not supported")
        return idx

    def cal_interval(self, time_point_A, time_point_B):
        """
        Calculate the calendar-index interval (time_point_A - time_point_B)

        Both points are aligned with ``tp_type="start"`` before subtracting.

        Parameters
        ----------
        time_point_A
            minuend time point
        time_point_B
            subtrahend time point

        Returns
        -------
        int
            ``align_idx(time_point_A) - align_idx(time_point_B)``
        """
        return self.align_idx(time_point_A) - self.align_idx(time_point_B)

    def align_time(self, time_point, tp_type="start"):
        """
        Align time_point to trade date of calendar

        Parameters
        ----------
        time_point
            Time point
        tp_type : str
            time point type (`"start"`, `"end"`)

        Returns
        -------
        the calendar entry at ``align_idx(time_point, tp_type)``

        Notes
        -----
        NOTE(review): the aligned index is used to subscript the calendar
        directly. An ``"end"`` point earlier than the first calendar entry
        aligns to index ``-1``, which selects the last calendar entry.
        """
        return self.cals[self.align_idx(time_point, tp_type=tp_type)]

    def align_seg(self, segment: Union[dict, tuple]):
        """
        align the given date to trade date

        for example:

            .. code-block:: python

                input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}

                output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
                        'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
                        'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}

        Parameters
        ----------
        segment
            a ``(start, end)`` tuple, or a dict whose values are segments
            (dicts are aligned recursively, key by key)

        Returns
        -------
        the start and end trade date (pd.Timestamp) between the given start and end date.

        Raises
        ------
        NotImplementedError
            if ``segment`` is neither a dict nor a tuple
        """
        if isinstance(segment, dict):
            return {k: self.align_seg(seg) for k, seg in segment.items()}
        if isinstance(segment, tuple):
            return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
        raise NotImplementedError("This type of input is not supported")

    def truncate(self, segment: tuple, test_start, days: int):
        """
        truncate the segment based on the test_start date

        Each time point of the segment is moved back, if necessary, so that
        it lies at least ``days`` calendar entries before ``test_start``.
        All points are aligned with ``tp_type="start"``.

        Parameters
        ----------
        segment : tuple
            time segment
        test_start
            start time point of the test data
        days : int
            The trading days to be truncated
            the data in this segment may need 'days' data

        Returns
        -------
        tuple
            the truncated time points, in the same order as ``segment``

        Raises
        ------
        AssertionError
            if a resulting calendar index is not strictly positive
        NotImplementedError
            if ``segment`` is not a tuple
        """
        test_idx = self.align_idx(test_start)
        if not isinstance(segment, tuple):
            raise NotImplementedError("This type of input is not supported")
        new_seg = []
        for time_point in segment:
            tp_idx = min(self.align_idx(time_point), test_idx - days)
            # Raised explicitly (instead of a bare `assert`) so the check is
            # not stripped under `python -O`; the exception type is unchanged.
            if tp_idx <= 0:
                raise AssertionError(f"truncated index {tp_idx} is not a positive calendar index")
            new_seg.append(self.get(tp_idx))
        return tuple(new_seg)

    SHIFT_SD = "sliding"
    SHIFT_EX = "expanding"

    def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
        """
        shift the datetime of segment

        Parameters
        ----------
        seg :
            datetime segment, a ``(start, end)`` tuple
        step : int
            rolling step, counted in calendar entries
        rtype : str
            rolling type ("sliding" or "expanding");
            "sliding" moves both ends, "expanding" moves only the end

        Returns
        -------
        tuple
            ``(start, end)`` calendar entries after shifting; ``end`` is None
            when only the shifted end index is beyond the calendar

        Raises
        ------
        KeyError:
            shift will raise error if the shifted start index is out of self.cal
        NotImplementedError
            if ``seg`` is not a tuple or ``rtype`` is unknown
        """
        if not isinstance(seg, tuple):
            raise NotImplementedError("This type of input is not supported")
        start_idx = self.align_idx(seg[0], tp_type="start")
        end_idx = self.align_idx(seg[1], tp_type="end")
        if rtype == self.SHIFT_SD:
            start_idx += step
            end_idx += step
        elif rtype == self.SHIFT_EX:
            end_idx += step
        else:
            raise NotImplementedError("This type of input is not supported")
        # Valid start indices are 0 .. len(cals) - 1. An index equal to
        # len(cals) is already past the calendar, and a negative one would
        # silently wrap around to the end of the calendar.
        if not 0 <= start_idx < len(self.cals):
            raise KeyError("The segment is out of valid calendar")
        return self.get(start_idx), self.get(end_idx)