mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
* Fix TSDataSampler Slicing Bug #1716 * Fix TSDataSampler Slicing Bug #1716 * Fix TSDataSampler Slicing Bug #1716 * Fix TSDataSampler Slicing Bug with simplyer implmentation#1716 with Simplified Implementation * Refactor: Fix CI errors by addressing pylint formatting issues * Refactor: Remove extraneous whitespace for improved code formatting with Black
This commit is contained in:
@@ -5,8 +5,9 @@ import unittest
|
||||
import pytest
|
||||
import sys
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
from qlib.data.dataset import TSDatasetH, TSDataSampler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
@@ -98,6 +99,54 @@ class TestDataset(TestAutoData):
|
||||
print(idx[i])
|
||||
|
||||
|
||||
class TestTSDataSampler(unittest.TestCase):
|
||||
def test_TSDataSampler(self):
|
||||
"""
|
||||
Test TSDataSampler for issue #1716
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
assert len(dataset[0]) == 2
|
||||
self.assertTrue(np.isnan(dataset[0][0]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[1][1], dataset[2][0])
|
||||
self.assertEqual(dataset[2][1], dataset[3][0])
|
||||
|
||||
def test_TSDataSampler2(self):
|
||||
"""
|
||||
Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front
|
||||
"""
|
||||
datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"]
|
||||
instruments = ["000001", "000002", "000003", "000004", "000005"]
|
||||
index = pd.MultiIndex.from_product(
|
||||
[pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"]
|
||||
)
|
||||
data = np.random.randn(len(datetime_list) * len(instruments))
|
||||
test_df = pd.DataFrame(data=data, index=index, columns=["factor"])
|
||||
dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3)
|
||||
print()
|
||||
print("--------------dataset[0]--------------")
|
||||
print(dataset[0])
|
||||
print("--------------dataset[1]--------------")
|
||||
print(dataset[1])
|
||||
for i in range(3):
|
||||
self.assertFalse(np.isnan(dataset[0][i]))
|
||||
self.assertFalse(np.isnan(dataset[1][i]))
|
||||
self.assertEqual(dataset[0][1], dataset[1][0])
|
||||
self.assertEqual(dataset[0][2], dataset[1][1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user