mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix tra dataset bug (#1050)
This commit is contained in:
@@ -6,8 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -95,7 +94,7 @@ class MTSDatasetH(DatasetH):
|
||||
shuffle=True,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
assert horizon > 0, "please specify `horizon` to avoid data leakage"
|
||||
@@ -150,8 +149,15 @@ class MTSDatasetH(DatasetH):
|
||||
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
fn = _get_date_parse_fn(self._index[0][1])
|
||||
start_date = fn(slc.start)
|
||||
end_date = fn(slc.stop)
|
||||
|
||||
if isinstance(slc, slice):
|
||||
start, stop = slc.start, slc.stop
|
||||
elif isinstance(slc, (list, tuple)):
|
||||
start, stop = slc
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
start_date = fn(start)
|
||||
end_date = fn(stop)
|
||||
obj = copy.copy(self) # shallow copy
|
||||
# NOTE: Seriable will disable copy `self._data` so we manually assign them here
|
||||
obj._data = self._data
|
||||
|
||||
Reference in New Issue
Block a user