1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 03:50:57 +08:00

update TimeSeriesDataset

This commit is contained in:
Young
2020-12-03 14:51:21 +00:00
committed by you-n-g
parent d093afd684
commit 5d5f8c8868
8 changed files with 347 additions and 118 deletions

View File

@@ -5,6 +5,10 @@ from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from inspect import getfullargspec
import pandas as pd
import numpy as np
import bisect
from ...utils import lazy_sort_index
from .utils import get_level_index
class Dataset(Serializable):
@@ -115,6 +119,16 @@ class DatasetH(Dataset):
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
self._segments = segments.copy()
def _prepare_seg(self, slc: slice, **kwargs):
"""
Give a slice, retrieve the according data
Parameters
----------
slc : slice
"""
return self._handler.fetch(slc, **kwargs)
def prepare(
self,
segments: Union[List[str], Tuple[str], str, slice],
@@ -157,9 +171,157 @@ class DatasetH(Dataset):
else:
logger.info(f"data_key[{data_key}] is ignored.")
# Handle all kinds of segments format
if isinstance(segments, (list, tuple)):
return [self._handler.fetch(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
elif isinstance(segments, str):
return self._handler.fetch(slice(*self._segments[segments]), **fetch_kwargs)
return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs)
elif isinstance(segments, slice):
return self._prepare_seg(segments, **fetch_kwargs)
else:
raise NotImplementedError(f"This type of input is not supported")
class TSDataSampler:
"""
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process
"""
def __init__(self, data, start, end, step_len):
self.start = start
self.end = end
self.step_len = step_len
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
# self.index_link = self.build_link(self.data)
self.idx_df, self.idx_map = self.build_index(self.data)
@staticmethod
def build_index(data: pd.DataFrame) -> dict:
"""
The relation of the data
Parameters
----------
data : pd.DataFrame
The dataframe with <datetime, DataFrame>
Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
"""
# object incase of pandas converting int to flaot
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
idx_df = lazy_sort_index(idx_df.unstack())
# NOTE: the correctness of `__getitem__` depends on columns sorted here
idx_df = lazy_sort_index(idx_df, axis=1)
idx_map = {}
for i, (_, row) in enumerate(idx_df.iterrows()):
for j, real_idx in enumerate(row):
if not np.isnan(real_idx):
idx_map[real_idx] = (i, j)
return idx_df, idx_map
def __getitem__(self, idx: Union[int, Tuple[object, str]]):
"""
# We have two method to get the time-series of a sample
tsds is a instance of TSDataSampler
# 1) sample by int index directly
tsds[len(tsds) - 1]
# 2) sample by <datetime,instrument> index
tsds['2016-12-31', "SZ300315"]
# The return value will be similar to the data retrieved by following code
df.loc(axis=0)['2015-01-01':'2016-12-31', "SZ300315"].iloc[-30:]
Parameters
----------
idx : Union[int, Tuple[object, str]]
"""
# The the right row number `i` and col number `j` in idx_df
if isinstance(idx, int):
real_idx = self.start_idx + idx
if self.start_idx <= real_idx < self.end_idx:
i, j = self.idx_map[real_idx]
elif isinstance(idx, tuple):
# <TSDataSampler object>["datetime", "instruments"]
date, inst = idx
date = pd.Timestamp(date)
i = bisect.bisect_right(self.idx_df.index, date) - 1
# NOTE: This relies on the idx_df columns sorted in `__init__`
j = bisect.bisect_left(self.idx_df.columns, inst)
else:
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
data_l = []
indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values
indices = indices.reshape(-1)
if len(indices) < self.step_len:
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
for idx in indices:
if np.isnan(idx):
data_l.append(np.full((self.data.shape[1],), np.nan))
else:
data_l.append(self.data.iloc[idx])
return np.array(data_l)
def __len__(self):
return self.end_idx - self.start_idx
class TSDatasetH(DatasetH):
"""
(T)ime-(S)eries Dataset (H)andler
Covnert the tabular data to Time-Series data
Requirements analysis
The typical workflow of a user to get time-series data for an sample
- process features
- slice proper data from data handler: dimension of sample <feature, >
- Build relation of samples by <time, instrument> index
- Be able to sample times series of data <timestep, feature>
- It will be better if the interface is like "torch.utils.data.Dataset"
- User could build customized batch based on the data
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, *args, **kwargs):
self.step_len = step_len
super().__init__(*args, **kwargs)
def setup_data(self, *args, **kwargs):
super().setup_data(*args, **kwargs)
cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
# Get the datatime index for building timestamp
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
pad_start_idx = max(0, start_idx - self.step_len)
pad_start = self.cal[pad_start_idx]
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
return tsds

View File

@@ -155,6 +155,9 @@ class DataHandler(Serializable):
select a set of meaningful columns.(e.g. features, columns)
if cal_set == CS_RAW:
the raw dataset will be returned.
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels

20
qlib/tests/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
import sys
import unittest
from ..utils import exists_qlib_data
from .data import GetData
from .. import init
from ..config import REG_CN
class TestAutoData(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
)
init(provider_uri=provider_uri, region=REG_CN)

98
qlib/tests/data.py Normal file
View File

@@ -0,0 +1,98 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import zipfile
import requests
from tqdm import tqdm
from pathlib import Path
from loguru import logger
class GetData:
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
def __init__(self, delete_zip_file=False):
"""
Parameters
----------
delete_zip_file : bool, optional
Whether to delete the zip file, value from True or False, by default False
"""
self.delete_zip_file = delete_zip_file
def _download_data(self, file_name: str, target_dir: [Path, str]):
target_dir = Path(target_dir).expanduser()
target_dir.mkdir(exist_ok=True, parents=True)
url = f"{self.REMOTE_URL}/{file_name}"
target_path = target_dir.joinpath(file_name)
resp = requests.get(url, stream=True)
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chuck_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{file_name} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chuck in resp.iter_content(chunk_size=chuck_size):
fp.write(chuck)
p_bar.update(chuck_size)
self._unzip(target_path, target_dir)
if self.delete_zip_file:
target_path.unlike()
@staticmethod
def _unzip(file_path: Path, target_dir: Path):
logger.info(f"{file_path} unzipping......")
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
def qlib_data(
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
):
"""download cn qlib data from remote
Parameters
----------
target_dir: str
data save directory
name: str
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
version: str
data version, value from [v0, v1, ..., latest], by default latest
interval: str
data freq, value from [1d], by default 1d
region: str
data region, value from [cn, us], by default cn
Examples
---------
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
-------
"""
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote
Parameters
----------
target_dir: str
data save directory
Examples
---------
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
-------
"""
file_name = "csv_data_cn.zip"
self._download_data(file_name, target_dir)

View File

@@ -622,9 +622,9 @@ def exists_qlib_data(qlib_dir):
return True
def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
"""
make the df index lexsorted
make the df index sorted
df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`
This function could avoid such case
@@ -638,10 +638,11 @@ def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
pd.DataFrame:
sorted dataframe
"""
if df.index.is_lexsorted():
idx = df.index if axis == 0 else df.columns
if idx.is_monotonic_increasing:
return df
else:
return df.sort_index()
return df.sort_index(axis=axis)
def flatten_dict(d, parent_key="", sep="."):

View File

@@ -1,103 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import fire
import zipfile
import requests
from tqdm import tqdm
from pathlib import Path
from loguru import logger
class GetData:
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
def __init__(self, delete_zip_file=False):
"""
Parameters
----------
delete_zip_file : bool, optional
Whether to delete the zip file, value from True or False, by default False
"""
self.delete_zip_file = delete_zip_file
def _download_data(self, file_name: str, target_dir: [Path, str]):
target_dir = Path(target_dir).expanduser()
target_dir.mkdir(exist_ok=True, parents=True)
url = f"{self.REMOTE_URL}/{file_name}"
target_path = target_dir.joinpath(file_name)
resp = requests.get(url, stream=True)
if resp.status_code != 200:
raise requests.exceptions.HTTPError()
chuck_size = 1024
logger.warning(
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
)
logger.info(f"{file_name} downloading......")
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
with target_path.open("wb") as fp:
for chuck in resp.iter_content(chunk_size=chuck_size):
fp.write(chuck)
p_bar.update(chuck_size)
self._unzip(target_path, target_dir)
if self.delete_zip_file:
target_path.unlike()
@staticmethod
def _unzip(file_path: Path, target_dir: Path):
logger.info(f"{file_path} unzipping......")
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
def qlib_data(
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
):
"""download cn qlib data from remote
Parameters
----------
target_dir: str
data save directory
name: str
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
version: str
data version, value from [v0, v1, ..., latest], by default latest
interval: str
data freq, value from [1d], by default 1d
region: str
data region, value from [cn, us], by default cn
Examples
---------
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
-------
"""
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
"""download cn csv data from remote
Parameters
----------
target_dir: str
data save directory
Examples
---------
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
-------
"""
file_name = "csv_data_cn.zip"
self._download_data(file_name, target_dir)
from qlib.tests.data import GetData
if __name__ == "__main__":

View File

@@ -22,6 +22,8 @@ from qlib.contrib.evaluate import (
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
from qlib.tests.get_data import GetData
from qlib.tests import TestAutoData
market = "csi300"
@@ -156,26 +158,12 @@ def backtest_analysis(pred, rid):
return analysis_df
class TestAllFlow(unittest.TestCase):
class TestAllFlow(TestAutoData):
PRED_SCORE = None
REPORT_NORMAL = None
POSITIONS = None
RID = None
@classmethod
def setUpClass(cls) -> None:
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
)
qlib.init(provider_uri=provider_uri, region=REG_CN)
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))

53
tests/test_dataset.py Normal file
View File

@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import sys
from qlib.tests import TestAutoData
from qlib.data.dataset import TSDatasetH
import numpy as np
class TestDataset(TestAutoData):
def testTSDataset(self):
tsdh = TSDatasetH(
handler={
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi300",
},
},
segments={
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
)
_ = tsdh.prepare("train") # Test the correctness
tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series
# The dimension of sample is same as tabular data, but it will return timeseries data of the sample
# We have two method to get the time-series of a sample
# 1) sample by int index directly
tsds[len(tsds) - 1]
# 2) sample by <datetime,instrument> index
data_from_ds = tsds["2016-12-31", "SZ300315"]
# Check the data
# Get data from DataFrame Directly
data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values
equal = np.isclose(data_from_df, data_from_ds)
self.assertTrue(equal[~np.isnan(data_from_df)].all())
if __name__ == "__main__":
unittest.main(verbosity=10)