mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 03:50:57 +08:00
update TimeSeriesDataset
This commit is contained in:
@@ -5,6 +5,10 @@ from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import bisect
|
||||
from ...utils import lazy_sort_index
|
||||
from .utils import get_level_index
|
||||
|
||||
|
||||
class Dataset(Serializable):
|
||||
@@ -115,6 +119,16 @@ class DatasetH(Dataset):
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments.copy()
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self._handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
@@ -157,9 +171,157 @@ class DatasetH(Dataset):
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Handle all kinds of segments format
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [self._handler.fetch(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
return [self._prepare_seg(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._handler.fetch(slice(*self._segments[segments]), **fetch_kwargs)
|
||||
return self._prepare_seg(slice(*self._segments[segments]), **fetch_kwargs)
|
||||
elif isinstance(segments, slice):
|
||||
return self._prepare_seg(segments, **fetch_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
|
||||
class TSDataSampler:
|
||||
"""
|
||||
(T)ime-(S)eries DataSampler
|
||||
This is the result of TSDatasetH
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data, start, end, step_len):
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.step_len = step_len
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
"""
|
||||
The relation of the data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : pd.DataFrame
|
||||
The dataframe with <datetime, DataFrame>
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
{<index>: <prev_index or None>}
|
||||
# get the previous index of a line given index
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=np.object)
|
||||
idx_df = lazy_sort_index(idx_df.unstack())
|
||||
# NOTE: the correctness of `__getitem__` depends on columns sorted here
|
||||
idx_df = lazy_sort_index(idx_df, axis=1)
|
||||
|
||||
idx_map = {}
|
||||
for i, (_, row) in enumerate(idx_df.iterrows()):
|
||||
for j, real_idx in enumerate(row):
|
||||
if not np.isnan(real_idx):
|
||||
idx_map[real_idx] = (i, j)
|
||||
return idx_df, idx_map
|
||||
|
||||
def __getitem__(self, idx: Union[int, Tuple[object, str]]):
|
||||
"""
|
||||
# We have two method to get the time-series of a sample
|
||||
tsds is a instance of TSDataSampler
|
||||
|
||||
# 1) sample by int index directly
|
||||
tsds[len(tsds) - 1]
|
||||
|
||||
# 2) sample by <datetime,instrument> index
|
||||
tsds['2016-12-31', "SZ300315"]
|
||||
|
||||
# The return value will be similar to the data retrieved by following code
|
||||
df.loc(axis=0)['2015-01-01':'2016-12-31', "SZ300315"].iloc[-30:]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : Union[int, Tuple[object, str]]
|
||||
"""
|
||||
# The the right row number `i` and col number `j` in idx_df
|
||||
if isinstance(idx, int):
|
||||
real_idx = self.start_idx + idx
|
||||
if self.start_idx <= real_idx < self.end_idx:
|
||||
i, j = self.idx_map[real_idx]
|
||||
elif isinstance(idx, tuple):
|
||||
# <TSDataSampler object>["datetime", "instruments"]
|
||||
date, inst = idx
|
||||
date = pd.Timestamp(date)
|
||||
i = bisect.bisect_right(self.idx_df.index, date) - 1
|
||||
# NOTE: This relies on the idx_df columns sorted in `__init__`
|
||||
j = bisect.bisect_left(self.idx_df.columns, inst)
|
||||
else:
|
||||
raise KeyError(f"{real_idx} is out of [{self.start_idx}, {self.end_idx})")
|
||||
|
||||
data_l = []
|
||||
indices = self.idx_df.iloc[max(i - self.step_len + 1, 0) : i + 1, j].values
|
||||
indices = indices.reshape(-1)
|
||||
if len(indices) < self.step_len:
|
||||
indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices])
|
||||
for idx in indices:
|
||||
if np.isnan(idx):
|
||||
data_l.append(np.full((self.data.shape[1],), np.nan))
|
||||
else:
|
||||
data_l.append(self.data.iloc[idx])
|
||||
return np.array(data_l)
|
||||
|
||||
def __len__(self):
|
||||
return self.end_idx - self.start_idx
|
||||
|
||||
|
||||
class TSDatasetH(DatasetH):
|
||||
"""
|
||||
(T)ime-(S)eries Dataset (H)andler
|
||||
|
||||
|
||||
Covnert the tabular data to Time-Series data
|
||||
|
||||
Requirements analysis
|
||||
|
||||
The typical workflow of a user to get time-series data for an sample
|
||||
- process features
|
||||
- slice proper data from data handler: dimension of sample <feature, >
|
||||
- Build relation of samples by <time, instrument> index
|
||||
- Be able to sample times series of data <timestep, feature>
|
||||
- It will be better if the interface is like "torch.utils.data.Dataset"
|
||||
- User could build customized batch based on the data
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
cal = self._handler.fetch(col_set=self._handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datatime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
pad_start_idx = max(0, start_idx - self.step_len)
|
||||
pad_start = self.cal[pad_start_idx]
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
|
||||
return tsds
|
||||
|
||||
@@ -155,6 +155,9 @@ class DataHandler(Serializable):
|
||||
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
|
||||
if cal_set == CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
|
||||
20
qlib/tests/__init__.py
Normal file
20
qlib/tests/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import sys
|
||||
import unittest
|
||||
from ..utils import exists_qlib_data
|
||||
from .data import GetData
|
||||
from .. import init
|
||||
from ..config import REG_CN
|
||||
|
||||
|
||||
class TestAutoData(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
)
|
||||
init(provider_uri=provider_uri, region=REG_CN)
|
||||
98
qlib/tests/data.py
Normal file
98
qlib/tests/data.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import zipfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delete_zip_file : bool, optional
|
||||
Whether to delete the zip file, value from True or False, by default False
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data(
|
||||
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
name: str
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
@@ -622,9 +622,9 @@ def exists_qlib_data(qlib_dir):
|
||||
return True
|
||||
|
||||
|
||||
def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
|
||||
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
"""
|
||||
make the df index lexsorted
|
||||
make the df index sorted
|
||||
|
||||
df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`
|
||||
This function could avoid such case
|
||||
@@ -638,10 +638,11 @@ def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
|
||||
pd.DataFrame:
|
||||
sorted dataframe
|
||||
"""
|
||||
if df.index.is_lexsorted():
|
||||
idx = df.index if axis == 0 else df.columns
|
||||
if idx.is_monotonic_increasing:
|
||||
return df
|
||||
else:
|
||||
return df.sort_index()
|
||||
return df.sort_index(axis=axis)
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="."):
|
||||
|
||||
@@ -1,103 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import fire
|
||||
import zipfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GetData:
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
delete_zip_file : bool, optional
|
||||
Whether to delete the zip file, value from True or False, by default False
|
||||
"""
|
||||
self.delete_zip_file = delete_zip_file
|
||||
|
||||
def _download_data(self, file_name: str, target_dir: [Path, str]):
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
target_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
url = f"{self.REMOTE_URL}/{file_name}"
|
||||
target_path = target_dir.joinpath(file_name)
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
|
||||
self._unzip(target_path, target_dir)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
|
||||
@staticmethod
|
||||
def _unzip(file_path: Path, target_dir: Path):
|
||||
logger.info(f"{file_path} unzipping......")
|
||||
with zipfile.ZipFile(str(file_path.resolve()), "r") as zp:
|
||||
for _file in tqdm(zp.namelist()):
|
||||
zp.extract(_file, str(target_dir.resolve()))
|
||||
|
||||
def qlib_data(
|
||||
self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
|
||||
):
|
||||
"""download cn qlib data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
name: str
|
||||
dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data
|
||||
version: str
|
||||
data version, value from [v0, v1, ..., latest], by default latest
|
||||
interval: str
|
||||
data freq, value from [1d], by default 1d
|
||||
region: str
|
||||
data region, value from [cn, us], by default cn
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
|
||||
self._download_data(file_name.lower(), target_dir)
|
||||
|
||||
def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"):
|
||||
"""download cn csv data from remote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_dir: str
|
||||
data save directory
|
||||
|
||||
Examples
|
||||
---------
|
||||
python get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
-------
|
||||
|
||||
"""
|
||||
file_name = "csv_data_cn.zip"
|
||||
self._download_data(file_name, target_dir)
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -22,6 +22,8 @@ from qlib.contrib.evaluate import (
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests.get_data import GetData
|
||||
from qlib.tests import TestAutoData
|
||||
|
||||
|
||||
market = "csi300"
|
||||
@@ -156,26 +158,12 @@ def backtest_analysis(pred, rid):
|
||||
return analysis_df
|
||||
|
||||
|
||||
class TestAllFlow(unittest.TestCase):
|
||||
class TestAllFlow(TestAutoData):
|
||||
PRED_SCORE = None
|
||||
REPORT_NORMAL = None
|
||||
POSITIONS = None
|
||||
RID = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data_simple" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(
|
||||
name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
|
||||
)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
53
tests/test_dataset.py
Normal file
53
tests/test_dataset.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestDataset(TestAutoData):
|
||||
def testTSDataset(self):
|
||||
tsdh = TSDatasetH(
|
||||
handler={
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi300",
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
)
|
||||
_ = tsdh.prepare("train") # Test the correctness
|
||||
tsds = tsdh.prepare("valid") # prepare a dataset with is friendly to converting tabular data to time-series
|
||||
|
||||
# The dimension of sample is same as tabular data, but it will return timeseries data of the sample
|
||||
|
||||
# We have two method to get the time-series of a sample
|
||||
|
||||
# 1) sample by int index directly
|
||||
tsds[len(tsds) - 1]
|
||||
|
||||
# 2) sample by <datetime,instrument> index
|
||||
data_from_ds = tsds["2016-12-31", "SZ300315"]
|
||||
|
||||
# Check the data
|
||||
# Get data from DataFrame Directly
|
||||
data_from_df = tsdh._handler.fetch().loc(axis=0)["2015-01-01":"2016-12-31", "SZ300315"].iloc[-30:].values
|
||||
|
||||
equal = np.isclose(data_from_df, data_from_ds)
|
||||
self.assertTrue(equal[~np.isnan(data_from_df)].all())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
Reference in New Issue
Block a user