1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 02:21:18 +08:00
Files
qlib/qlib/model/utils.py
you-n-g 60d45ad770 Enhance pytorch nn (#917)
* enhance pytorch_nn

* fix dim bug

* Black format

* Fix pylint error
2022-02-15 19:22:48 +08:00

27 lines
579 B
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from torch.utils.data import Dataset
class ConcatDataset(Dataset):
    """Zip several datasets together so each item is a tuple of their items.

    Unlike ``torch.utils.data.ConcatDataset`` (which chains datasets end to end),
    this class combines them element-wise: ``self[i]`` is
    ``(datasets[0][i], datasets[1][i], ...)``. The length is that of the
    shortest wrapped dataset, mirroring the built-in :func:`zip`.

    Parameters
    ----------
    *datasets
        Indexable, sized objects (e.g. ``Dataset`` instances) to combine.
    """

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        """Return a tuple holding the ``i``-th item of every wrapped dataset."""
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        """Return the length of the shortest wrapped dataset (0 if there are none)."""
        # ``default=0`` keeps an empty ``ConcatDataset()`` usable instead of raising
        # ``ValueError: min() arg is an empty sequence``.
        return min((len(d) for d in self.datasets), default=0)
class IndexSampler:
    """Wrap an indexable ``sampler`` so every lookup also reports its position.

    Indexing yields the pair ``(sampler[i], i)``; the length is delegated to
    the wrapped object.

    Parameters
    ----------
    sampler
        Any object supporting ``__getitem__`` and ``__len__``.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __getitem__(self, i: int):
        """Return the wrapped item at ``i`` together with ``i`` itself."""
        item = self.sampler[i]
        return item, i

    def __len__(self):
        """Return the number of items in the wrapped sampler."""
        wrapped = self.sampler
        return len(wrapped)