Shortcuts

Source code for torchtime.datasets.timeseries

import csv
import os
from typing import Tuple, Any, Callable, List, Optional, Union

import torch
from torch.utils.data import Dataset

from ..transforms import Compose


class TimeSeriesDataset(Dataset):
    """Base class for datasets that are compatible with torchtime.

    Subclasses must override both ``__getitem__`` and ``__len__``.

    Args:
        root (string): Root directory of the dataset. When given as a string,
            a leading ``~`` is expanded to the user's home directory; any other
            value is stored unchanged.
        transforms (callable, optional): A function/transform that receives a
            time series and its label and returns transformed versions of both.
        transform (callable, optional): A function/transform that receives a
            uni- or multivariate time series and returns a transformed version,
            e.g. ``transforms.NaN2Value``.
        target_transform (callable, optional): A function/transform that
            receives the target and transforms it.

    .. note::
        :attr:`transforms` and the combination of :attr:`transform` and
        :attr:`target_transform` are mutually exclusive.
    """

    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        torch._C._log_api_usage_once(f"torchtime.datasets.{self.__class__.__name__}")
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        joint_given = transforms is not None
        split_given = not (transform is None and target_transform is None)
        if joint_given and split_given:
            raise ValueError("Only transforms or transform/target_transform can be passed as argument!")

        # Fold the separate data/target callables into a single
        # ``(data, target) -> (data, target)`` callable so subclasses only
        # ever need to consult ``self.transforms``.
        self.transforms = (
            StandardTransform(transform, target_transform) if split_given else transforms
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the sample stored at ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (series, target) where target is index of the target class.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        pad = " " * self._repr_indent
        details = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            details.append("Root location: {}".format(self.root))
        details.extend(self.extra_repr().splitlines())
        if getattr(self, "transforms", None) is not None:
            details.append(repr(self.transforms))
        return "\n".join(["Dataset " + self.__class__.__name__] + [pad + entry for entry in details])

    def extra_repr(self) -> str:
        """Hook for subclasses: extra lines appended to ``repr``; empty by default."""
        return ""
class TsvDataset(TimeSeriesDataset):
    """Create a Dataset for `.tsv` data.

    The first row of the file is treated as a header and stored separately.
    Every following non-empty row is one sample: all columns except the last
    hold the series values, and the last column holds the target label.

    Args:
        root (str or Path): Path to the directory where the dataset is located.
            (Where the ``tsv`` file is present.)
        tsv (str, optional): Name of the tsv file inside ``root`` to load,
            such as ``"train.tsv"`` or ``"test.tsv"``. (default: ``"train.tsv"``)
        transforms (callable, optional): Called as
            ``transforms(series, target)`` on every sample.
        transform (callable, optional): Applied to the series of every sample.
        target_transform (callable, optional): Applied to the target of every
            sample.

    .. note::
        :attr:`transforms` and the combination of :attr:`transform` and
        :attr:`target_transform` are mutually exclusive.
    """

    def __init__(self,
                 root: str,
                 tsv: str = "train.tsv",
                 transforms: Optional[Callable] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None
                 ) -> None:
        super().__init__(root, transforms, transform, target_transform)

        self._tsv = os.path.join(self.root, tsv)

        # ``newline=""`` is what the csv module requires for correct handling
        # of quoted fields and line endings.
        with open(self._tsv, "r", newline="") as tsv_:
            walker = csv.reader(tsv_, delimiter="\t")
            # An empty file yields no header instead of leaking StopIteration.
            self._header = next(walker, [])
            # Skip blank rows (e.g. a trailing newline); csv yields them as []
            # and they would otherwise fail on ``line[-1]`` in __getitem__.
            self._walker = [row for row in walker if row]

    @staticmethod
    def _parse_value(cell: str) -> float:
        """Convert one TSV cell to a float, mapping empty cells to NaN."""
        return float(cell) if cell.strip() else float("nan")

    def __getitem__(self, n: int) -> Tuple[Any, Any]:
        """Load the ``n``-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            tuple: ``(series, target)``. Without transforms, ``series`` is a
            1-D floating-point tensor built from every column of the row
            except the last (empty cells become NaN), and ``target`` is the
            raw string from the last column. When transforms were configured,
            both values are passed through them before being returned.
        """
        line = self._walker[n]
        # csv yields strings; torch.as_tensor rejects str, so parse first.
        sample = torch.as_tensor([self._parse_value(cell) for cell in line[:-1]])
        target = line[-1]
        if self.transforms is not None:
            sample, target = self.transforms(sample, target)
        return sample, target

    def __len__(self) -> int:
        return len(self._walker)


class StandardTransform:
    """Bundle a data transform and a target transform into one callable.

    Calling an instance with ``(data, target)`` applies ``transform`` to
    ``data`` and ``target_transform`` to ``target`` (each only when set) and
    returns the resulting pair.
    """

    def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, data: Any, target: Any) -> Tuple[Any, Any]:
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target

    def _add_transform(self, transform: Union[Callable, List[Callable]], target: bool = False) -> None:
        """Append ``transform`` to the data pipeline, or to the target pipeline if ``target``."""
        current = self.target_transform if target else self.transform
        if current is None:
            # A bare list is not callable, so wrap it before it becomes the
            # transform (copy so later edits to the caller's list don't leak in).
            updated = Compose(list(transform)) if isinstance(transform, list) else transform
        elif isinstance(current, Compose):
            # NOTE(review): a list is forwarded unchanged here, as before;
            # confirm Compose.add_transform accepts lists as well as callables.
            current.add_transform(transform)
            updated = current
        elif isinstance(transform, list):
            updated = Compose([current] + transform)
        else:
            updated = Compose([current, transform])

        if target:
            self.target_transform = updated
        else:
            self.transform = updated

    def add_transform(self,
                      transform: Optional[Union[Callable, List[Callable]]] = None,
                      target_transform: Optional[Union[Callable, List[Callable]]] = None) -> None:
        """Append callables to the data and/or target pipelines.

        Args:
            transform: Callable or list of callables appended after the
                current data transform. ``None`` leaves it unchanged.
            target_transform: Callable or list of callables appended after the
                current target transform. ``None`` leaves it unchanged.
        """
        if transform is not None:
            self._add_transform(transform, target=False)
        if target_transform is not None:
            self._add_transform(target_transform, target=True)

    @staticmethod
    def _format_transform_repr(transform: Callable, head: str) -> List[str]:
        # Prefix the first repr line with ``head``; align later lines under it.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def __repr__(self) -> str:
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")
        return '\n'.join(body)