# Source code for tensornet.data.datasets.cifar10

from torchvision import datasets

from tensornet.data.datasets.dataset import BaseDataset


class CIFAR10(BaseDataset):
    """CIFAR-10 Dataset.

    `Note`: This dataset inherits the ``BaseDataset`` class.
    """

    def _download(self, train=True, apply_transform=True):
        """Download dataset.

        Args:
            train (:obj:`bool`, optional): True for training data,
                False for validation/test data. (default: True)
            apply_transform (:obj:`bool`, optional): True if transform
                is to be applied on the dataset. (default: True)

        Returns:
            The ``torchvision.datasets.CIFAR10`` instance for the requested
            split, stored under ``self.path``.
        """
        # NOTE(review): ``self.path``, ``self.train_transform`` and
        # ``self.val_transform`` are not defined in this class; presumably
        # they are set up by ``BaseDataset`` — confirm there.
        transform = None
        if apply_transform:
            transform = self.train_transform if train else self.val_transform
        return datasets.CIFAR10(
            self.path, train=train, download=True, transform=transform
        )

    def _get_image_size(self):
        """Return shape of data i.e. image size as (channels, height, width)."""
        return (3, 32, 32)

    def _get_classes(self):
        """Return a tuple of the class names in the dataset."""
        return (
            'plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck',
        )

    def _get_mean(self):
        """Return the per-channel mean of the dataset."""
        return (0.49139, 0.48215, 0.44653)

    def _get_std(self):
        """Return the per-channel standard deviation of the dataset."""
        return (0.24703, 0.24348, 0.26158)