Source code for tensornet.data.datasets.mnist

from torchvision import datasets

from tensornet.data.datasets.dataset import BaseDataset


class MNIST(BaseDataset):
    """MNIST dataset wrapper built on ``torchvision.datasets.MNIST``.

    `Note`: This dataset inherits the ``BaseDataset`` class.
    """

    def _download(self, train=True, apply_transform=True):
        """Fetch one split of the dataset through ``torchvision``.

        Args:
            train (:obj:`bool`, optional): Forwarded to ``datasets.MNIST`` as
                its ``train`` argument; also decides which transform is used.
                (default: True)
            apply_transform (:obj:`bool`, optional): When True, attach
                ``self.train_transform`` (if ``train`` is truthy) or
                ``self.val_transform`` (otherwise). When False, no transform
                is attached. (default: True)

        Returns:
            The ``datasets.MNIST`` object created for ``self.path``.
        """
        # Pick the transform up front so the constructor call stays flat.
        if not apply_transform:
            selected_transform = None
        elif train:
            selected_transform = self.train_transform
        else:
            selected_transform = self.val_transform

        # ``download=True`` is always passed to torchvision.
        return datasets.MNIST(
            self.path,
            train=train,
            download=True,
            transform=selected_transform,
        )

    def _get_image_size(self):
        """Return shape of data i.e. image size, as a 3-tuple."""
        # NOTE(review): presumably (channels, height, width) - confirm
        # against how BaseDataset consumes this value.
        return (1, 28, 28)

    def _get_classes(self):
        """Return the class labels as a tuple of strings ``'0'`` .. ``'9'``."""
        return tuple(str(digit) for digit in range(10))

    def _get_mean(self):
        """Return the mean of the entire dataset (a single float)."""
        return 0.1307

    def _get_std(self):
        """Return the standard deviation of the entire dataset (a single float)."""
        return 0.3081