"""Source code for tensornet.data.datasets.tinyimagenet."""

import os
import csv
import requests
import zipfile
import numpy as np

from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset

from tensornet.data.datasets.dataset import BaseDataset


class TinyImageNet(BaseDataset):
    """Tiny ImageNet Dataset.

    `Note`: This dataset inherits the ``BaseDataset`` class.
    """

    def _download(self, train=True, apply_transform=True):
        """Download dataset.

        Args:
            train (:obj:`bool`, optional): True for training data.
                (default: True)
            apply_transform (:obj:`bool`, optional): True if transform
                is to be applied on the dataset. (default: True)

        Returns:
            (`TinyImageNetDataset`): Downloaded dataset.
        """
        if not self.path.endswith('tinyimagenet'):
            self.path = os.path.join(self.path, 'tinyimagenet')

        transform = None
        if apply_transform:
            transform = self.train_transform if train else self.val_transform

        return TinyImageNetDataset(
            self.path, train=train, train_split=self.train_split,
            transform=transform
        )

    def _get_image_size(self):
        """Return shape of data i.e. image size."""
        return (3, 64, 64)

    def _get_classes(self):
        """Return list of classes in the dataset."""
        return self.train_data.classes

    def _get_mean(self):
        """Returns mean of the entire dataset."""
        return (0.4914, 0.4822, 0.4465)

    def _get_std(self):
        """Returns standard deviation of the entire dataset."""
        return (0.2023, 0.1994, 0.2010)
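# Illustrative sketch (not part of the module): the mean and std returned
# above are the statistics a typical torchvision normalization pipeline would
# consume, roughly as below. This is a hypothetical example; the actual
# transforms are supplied externally, presumably via ``train_transform`` /
# ``val_transform`` on ``BaseDataset``.
#
#     from torchvision import transforms
#
#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
#                              std=(0.2023, 0.1994, 0.2010)),
#     ])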
class TinyImageNetDataset(Dataset):
    """Create Tiny ImageNet Dataset.

    Args:
        path (str): Path where dataset will be downloaded.
        train (:obj:`bool`, optional): True for training data.
            (default: True)
        train_split (:obj:`float`, optional): Fraction of dataset to
            assign for training. (default: 0.7)
        download (:obj:`bool`, optional): If True, dataset will be
            downloaded. (default: True)
        random_seed (:obj:`int`, optional): Random seed value. This is
            required for splitting the data into training and validation
            datasets. (default: 1)
        transform (optional): Transformations to apply on the dataset.
            (default: None)
    """

    def __init__(
        self, path, train=True, train_split=0.7, download=True,
        random_seed=1, transform=None
    ):
        """Initializes the dataset for loading."""
        super(TinyImageNetDataset, self).__init__()
        self.path = path
        self.train = train
        self.train_split = train_split
        self.transform = transform
        self._validate_params()

        # Download dataset
        if download:
            self.download()

        self._class_ids = self._get_class_map()
        self.data, self.targets = self._load_data()

        # Shuffle the image indices deterministically, then carve out the
        # train and validation subsets
        self._image_indices = np.arange(len(self.targets))

        np.random.seed(random_seed)
        np.random.shuffle(self._image_indices)

        split_idx = int(len(self._image_indices) * train_split)
        self._image_indices = (
            self._image_indices[:split_idx] if train
            else self._image_indices[split_idx:]
        )

    def __len__(self):
        """Returns length of the dataset."""
        return len(self._image_indices)

    def __getitem__(self, index):
        """Fetch an item from the dataset.

        Args:
            index (int): Index of the item to fetch.

        Returns:
            An image and its corresponding label.
        """
        image_index = self._image_indices[index]

        image = self.data[image_index]
        if self.transform is not None:
            image = self.transform(image)

        return image, self.targets[image_index]

    def __repr__(self):
        """Representation string for the dataset object."""
        head = 'Dataset TinyImageNet'
        body = ['Number of datapoints: {}'.format(self.__len__())]
        if self.path is not None:
            body.append('Root location: {}'.format(self.path))
        body += [f'Split: {"Train" if self.train else "Test"}']
        if self.transform is not None:
            body += [repr(self.transform)]
        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def _validate_params(self):
        """Validate input parameters."""
        if not 0 < self.train_split <= 1:
            raise ValueError('train_split must lie in the interval (0, 1]')

    @property
    def classes(self):
        """List of classes present in the dataset."""
        return tuple(x[1]['name'] for x in sorted(
            self._class_ids.items(), key=lambda y: y[1]['id']
        ))

    def _get_class_map(self):
        """Create a mapping from class id to the class name."""
        # wnids.txt lists one WordNet ID (class id) per line
        with open(os.path.join(self.path, 'wnids.txt')) as f:
            class_ids = {x.strip(): '' for x in f.readlines()}

        # words.txt maps each WordNet ID to its human-readable names
        with open(os.path.join(self.path, 'words.txt')) as f:
            class_id = 0
            for line in csv.reader(f, delimiter='\t'):
                if line[0] in class_ids:
                    # class_ids[line[0]] = line[1].split(',')[0].lower()
                    class_ids[line[0]] = {
                        'name': line[1],
                        'id': class_id
                    }
                    class_id += 1

        return class_ids

    def _load_image(self, image_path):
        """Load an image from the dataset.

        Args:
            image_path (str): Path of the image.

        Returns:
            PIL object of the image.
""" image = Image.open(image_path) # Convert grayscale image to RGB if image.mode == 'L': image = np.array(image) image = np.stack((image,) * 3, axis=-1) image = Image.fromarray(image.astype('uint8'), 'RGB') return image def _load_data(self): """Fetch data from each data directory and store them in a list.""" data, targets = [], [] # Fetch train dir images train_path = os.path.join(self.path, 'train') for class_dir in os.listdir(train_path): train_images_path = os.path.join(train_path, class_dir, 'images') for image in os.listdir(train_images_path): if image.lower().endswith('.jpeg'): data.append( self._load_image(os.path.join(train_images_path, image)) ) targets.append(self._class_ids[class_dir]['id']) # Fetch val dir images val_path = os.path.join(self.path, 'val') val_images_path = os.path.join(val_path, 'images') with open(os.path.join(val_path, 'val_annotations.txt')) as f: for line in csv.reader(f, delimiter='\t'): data.append( self._load_image(os.path.join(val_images_path, line[0])) ) targets.append(self._class_ids[line[1]]['id']) return data, targets def download(self): """Download the data if it does not exist.""" if not os.path.exists(self.path): print('Downloading dataset...') r = requests.get('http://cs231n.stanford.edu/tiny-imagenet-200.zip', stream=True) zip_ref = zipfile.ZipFile(BytesIO(r.content)) zip_ref.extractall(os.path.dirname(self.path)) zip_ref.close() # Move file to appropriate location os.rename( os.path.join(os.path.dirname(self.path), 'tiny-imagenet-200'), self.path ) print('Done.') else: print('Files already downloaded.')