From 5d5f8c88685efb0e502b929e28e3815937a9daeb Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 3 Dec 2020 14:51:21 +0000 Subject: [PATCH] update TimeSeriesDataset --- qlib/data/dataset/__init__.py | 166 +++++++++++++++++++++++++++++++++- qlib/data/dataset/handler.py | 3 + qlib/tests/__init__.py | 20 ++++ qlib/tests/data.py | 98 ++++++++++++++++++++ qlib/utils/__init__.py | 9 +- scripts/get_data.py | 98 +------------------- tests/test_all_pipeline.py | 18 +--- tests/test_dataset.py | 53 +++++++++++ 8 files changed, 347 insertions(+), 118 deletions(-) create mode 100644 qlib/tests/__init__.py create mode 100644 qlib/tests/data.py create mode 100644 tests/test_dataset.py diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index e7d296d73..dab91bdcd 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -5,6 +5,10 @@ from ...log import get_module_logger from .handler import DataHandler, DataHandlerLP from inspect import getfullargspec import pandas as pd +import numpy as np +import bisect +from ...utils import lazy_sort_index +from .utils import get_level_index class Dataset(Serializable): @@ -115,6 +119,16 @@ class DatasetH(Dataset): self._handler = init_instance_by_config(handler, accept_types=DataHandler) self._segments = segments.copy() + def _prepare_seg(self, slc: slice, **kwargs): + """ + Give a slice, retrieve the according data + + Parameters + ---------- + slc : slice + """ + return self._handler.fetch(slc, **kwargs) + def prepare( self, segments: Union[List[str], Tuple[str], str, slice], @@ -157,9 +171,157 @@ class DatasetH(Dataset): else: logger.info(f"data_key[{data_key}] is ignored.") + # Handle all kinds of segments format if isinstance(segments, (list, tuple)): - return [self._handler.fetch(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments] + return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments] elif isinstance(segments, str): - return self._handler.fetch(slice(*self._segments[segments]), **fetch_kwargs) + return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs) + elif isinstance(segments, slice): + return self._prepare_seg(segments, **fetch_kwargs) else: raise NotImplementedError(f"This type of input is not supported") + + +class TSDataSampler: + """ + (T)ime-(S)eries DataSampler + This is the result of TSDatasetH + + It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series + dataset based on tabular data. + + If user have further requirements for processing data, user could process + + """ + + def __init__(self, data, start, end, step_len): + self.start = start + self.end = end + self.step_len = step_len + assert get_level_index(data, "datetime") == 0 + self.data = lazy_sort_index(data) + # The index of usable data is between start_idx and end_idx + self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) + # self.index_link = self.build_link(self.data) + self.idx_df, self.idx_map = self.build_index(self.data) + + @staticmethod + def build_index(data: pd.DataFrame) -> dict: + """ + The relation of the data + + Parameters + ---------- + data : pd.DataFrame + The dataframe with + + Returns + ------- + dict: + {: } + # get the previous index of a line given index + """ + # object incase of pandas converting int to flaot + idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object) + idx_df = lazy_sort_index(idx_df.unstack()) + # NOTE: the correctness of `__getitem__` depends on columns sorted here + idx_df = lazy_sort_index(idx_df, axis=1) + + idx_map = {} + for i, (_, row) in enumerate(idx_df.iterrows()): + for j, real_idx in enumerate(row): + if not np.isnan(real_idx): + idx_map[real_idx] = (i, j) + return idx_df, idx_map + + def __getitem__(self, idx: Union[int, Tuple[object, str]]): + """ + # We have two method to get the time-series of a sample + tsds is a instance of TSDataSampler + + # 1) sample by int index directly + tsds[len(tsds) - 1] + + # 2) sample by index + tsds['2016-12-31', "SZ300315"] + + # The return value will be similar to the data retrieved by following code + df.loc(axis=0)['2015-01-01':'2016-12-31', "SZ300315"].iloc[-30:] + + Parameters + ---------- + idx : Union[int, Tuple[object, str]] + """ + # The the right row number `i` and col number `j` in idx_df + if isinstance(idx, int): + real_idx = self.start_idx + idx + if self.start_idx <= real_idx < self.end_idx: + i, j = self.idx_map[real_idx] + elif isinstance(idx, tuple): + # ["datetime", "instruments"] + date, inst = idx + date = pd.Timestamp(date) + i = bisect.bisect_right(self.idx_df.index, date) - 1 + # NOTE: This relies on the idx_df columns sorted in `__init__` + j = bisect.bisect_left(self.idx_df.columns, inst) + else: + raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})") + + data_l = [] + indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values + indices = indices.reshape(-1) + if len(indices) < self.step_len: + indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) + for idx in indices: + if np.isnan(idx): + data_l.append(np.full((self.data.shape[1],), np.nan)) + else: + data_l.append(self.data.iloc[idx]) + return np.array(data_l) + + def __len__(self): + return self.end_idx - self.start_idx + + +class TSDatasetH(DatasetH): + """ + (T)ime-(S)eries Dataset (H)andler + + + Covnert the tabular data to Time-Series data + + Requirements analysis + + The typical workflow of a user to get time-series data for an sample + - process features + - slice proper data from data handler: dimension of sample + - Build relation of samples by index + - Be able to sample times series of data + - It will be better if the interface is like "torch.utils.data.Dataset" + - User could build customized batch based on the data + - The dimension of a batch of data + """ + + def __init__(self, step_len=30, *args, **kwargs): + self.step_len = step_len + super().__init__(*args, **kwargs) + + def setup_data(self, *args, **kwargs): + super().setup_data(*args, **kwargs) + cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique() + cal = sorted(cal) + # Get the datatime index for building timestamp + self.cal = cal + + def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: + # Dataset decide how to slice data(Get more data for timeseries). + start, end = slc.start, slc.stop + start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start)) + pad_start_idx = max(0, start_idx - self.step_len) + pad_start = self.cal[pad_start_idx] + + # TSDatasetH will retrieve more data for complete + data = super()._prepare_seg(slice(pad_start, end), **kwargs) + + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len) + return tsds diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 905fcd623..18f838300 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -155,6 +155,9 @@ class DataHandler(Serializable): select a set of meaningful columns.(e.g. features, columns) + if cal_set == CS_RAW: + the raw dataset will be returned. + - if isinstance(col_set, List[str]): select several sets of meaningful columns, the returned data has multiple levels diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py new file mode 100644 index 000000000..a1b33a2a2 --- /dev/null +++ b/qlib/tests/__init__.py @@ -0,0 +1,20 @@ +import sys +import unittest +from ..utils import exists_qlib_data +from .data import GetData +from .. import init +from ..config import REG_CN + + +class TestAutoData(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + + GetData().qlib_data( + name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri + ) + init(provider_uri=provider_uri, region=REG_CN) diff --git a/qlib/tests/data.py b/qlib/tests/data.py new file mode 100644 index 000000000..66bfb0e29 --- /dev/null +++ b/qlib/tests/data.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import zipfile +import requests +from tqdm import tqdm +from pathlib import Path +from loguru import logger + + +class GetData: + REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads" + + def __init__(self, delete_zip_file=False): + """ + + Parameters + ---------- + delete_zip_file : bool, optional + Whether to delete the zip file, value from True or False, by default False + """ + self.delete_zip_file = delete_zip_file + + def _download_data(self, file_name: str, target_dir: [Path, str]): + target_dir = Path(target_dir).expanduser() + target_dir.mkdir(exist_ok=True, parents=True) + + url = f"{self.REMOTE_URL}/{file_name}" + target_path = target_dir.joinpath(file_name) + + resp = requests.get(url, stream=True) + if resp.status_code != 200: + raise requests.exceptions.HTTPError() + + chuck_size = 1024 + logger.warning( + f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)" + ) + logger.info(f"{file_name} downloading......") + with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar: + with target_path.open("wb") as fp: + for chuck in resp.iter_content(chunk_size=chuck_size): + fp.write(chuck) + p_bar.update(chuck_size) + + self._unzip(target_path, target_dir) + if self.delete_zip_file: + target_path.unlike() + + @staticmethod + def _unzip(file_path: Path, target_dir: Path): + logger.info(f"{file_path} unzipping......") + with zipfile.ZipFile(str(file_path.resolve()), "r") as zp: + for _file in tqdm(zp.namelist()): + zp.extract(_file, str(target_dir.resolve())) + + def qlib_data( + self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn" + ): + """download cn qlib data from remote + + Parameters + ---------- + target_dir: str + data save directory + name: str + dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data + version: str + data version, value from [v0, v1, ..., latest], by default latest + interval: str + data freq, value from [1d], by default 1d + region: str + data region, value from [cn, us], by default cn + + Examples + --------- + python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn + ------- + + """ + file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" + self._download_data(file_name.lower(), target_dir) + + def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): + """download cn csv data from remote + + Parameters + ---------- + target_dir: str + data save directory + + Examples + --------- + python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data + ------- + + """ + file_name = "csv_data_cn.zip" + self._download_data(file_name, target_dir) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 5b313a0ef..138ce35f8 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -622,9 +622,9 @@ def exists_qlib_data(qlib_dir): return True -def lexsort_index(df: pd.DataFrame) -> pd.DataFrame: +def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame: """ - make the df index lexsorted + make the df index sorted df.sort_index() will take a lot of time even when `df.is_lexsorted() == True` This function could avoid such case @@ -638,10 +638,11 @@ def lexsort_index(df: pd.DataFrame) -> pd.DataFrame: pd.DataFrame: sorted dataframe """ - if df.index.is_lexsorted(): + idx = df.index if axis == 0 else df.columns + if idx.is_monotonic_increasing: return df else: - return df.sort_index() + return df.sort_index(axis=axis) def flatten_dict(d, parent_key="", sep="."): diff --git a/scripts/get_data.py b/scripts/get_data.py index f4dba1474..2f6a7494a 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -1,103 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - import fire -import zipfile -import requests -from tqdm import tqdm -from pathlib import Path -from loguru import logger - - -class GetData: - REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads" - - def __init__(self, delete_zip_file=False): - """ - - Parameters - ---------- - delete_zip_file : bool, optional - Whether to delete the zip file, value from True or False, by default False - """ - self.delete_zip_file = delete_zip_file - - def _download_data(self, file_name: str, target_dir: [Path, str]): - target_dir = Path(target_dir).expanduser() - target_dir.mkdir(exist_ok=True, parents=True) - - url = f"{self.REMOTE_URL}/{file_name}" - target_path = target_dir.joinpath(file_name) - - resp = requests.get(url, stream=True) - if resp.status_code != 200: - raise requests.exceptions.HTTPError() - - chuck_size = 1024 - logger.warning( - f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)" - ) - logger.info(f"{file_name} downloading......") - with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar: - with target_path.open("wb") as fp: - for chuck in resp.iter_content(chunk_size=chuck_size): - fp.write(chuck) - p_bar.update(chuck_size) - - self._unzip(target_path, target_dir) - if self.delete_zip_file: - target_path.unlike() - - @staticmethod - def _unzip(file_path: Path, target_dir: Path): - logger.info(f"{file_path} unzipping......") - with zipfile.ZipFile(str(file_path.resolve()), "r") as zp: - for _file in tqdm(zp.namelist()): - zp.extract(_file, str(target_dir.resolve())) - - def qlib_data( - self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn" - ): - """download cn qlib data from remote - - Parameters - ---------- - target_dir: str - data save directory - name: str - dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data - version: str - data version, value from [v0, v1, ..., latest], by default latest - interval: str - data freq, value from [1d], by default 1d - region: str - data region, value from [cn, us], by default cn - - Examples - --------- - python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn - ------- - - """ - file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" - self._download_data(file_name.lower(), target_dir) - - def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): - """download cn csv data from remote - - Parameters - ---------- - target_dir: str - data save directory - - Examples - --------- - python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data - ------- - - """ - file_name = "csv_data_cn.zip" - self._download_data(file_name, target_dir) +from qlib.tests.data import GetData if __name__ == "__main__": diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index befd296b0..93c1a5fee 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -22,6 +22,8 @@ from qlib.contrib.evaluate import ( from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord +from qlib.tests.get_data import GetData +from qlib.tests import TestAutoData market = "csi300" @@ -156,26 +158,12 @@ def backtest_analysis(pred, rid): return analysis_df -class TestAllFlow(unittest.TestCase): +class TestAllFlow(TestAutoData): PRED_SCORE = None REPORT_NORMAL = None POSITIONS = None RID = None - @classmethod - def setUpClass(cls) -> None: - # use default data - provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir - if not exists_qlib_data(provider_uri): - print(f"Qlib data is not found in {provider_uri}") - sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) - from get_data import GetData - - GetData().qlib_data( - name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri - ) - qlib.init(provider_uri=provider_uri, region=REG_CN) - @classmethod def tearDownClass(cls) -> None: shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve())) diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 000000000..dc3042175 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest +import sys +from qlib.tests import TestAutoData +from qlib.data.dataset import TSDatasetH +import numpy as np + + +class TestDataset(TestAutoData): + def testTSDataset(self): + tsdh = TSDatasetH( + handler={ + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": "csi300", + }, + }, + segments={ + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + ) + _ = tsdh.prepare("train") # Test the correctness + tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series + + # The dimension of sample is same as tabular data, but it will return timeseries data of the sample + + # We have two method to get the time-series of a sample + + # 1) sample by int index directly + tsds[len(tsds) - 1] + + # 2) sample by index + data_from_ds = tsds["2016-12-31", "SZ300315"] + + # Check the data + # Get data from DataFrame Directly + data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values + + equal = np.isclose(data_from_df, data_from_ds) + self.assertTrue(equal[~np.isnan(data_from_df)].all()) + + +if __name__ == "__main__": + unittest.main(verbosity=10)