Source code for

import os
import csv
import requests
import zipfile
import numpy as np

from io import BytesIO
from PIL import Image
from import Dataset

from import BaseDataset

class TinyImageNet(BaseDataset):
    """Tiny ImageNet Dataset.

    `Note`: This dataset inherits the ``BaseDataset`` class.
    """

    def _download(self, train=True, apply_transform=True):
        """Build the Tiny ImageNet split rooted at ``self.path``.

        ``self.path`` is extended with a ``tinyimagenet`` sub-directory
        unless it already ends with that name.

        Args:
            train (:obj:`bool`, optional): True for training data.
                (default: True)
            apply_transform (:obj:`bool`, optional): True if transform
                is to be applied on the dataset. (default: True)

        Returns:
            (`TinyImageNetDataset`): Downloaded dataset.
        """
        root = self.path
        if not root.endswith('tinyimagenet'):
            self.path = os.path.join(root, 'tinyimagenet')

        # Pick the transform matching the requested split, or none at all.
        if not apply_transform:
            selected_transform = None
        elif train:
            selected_transform = self.train_transform
        else:
            selected_transform = self.val_transform

        return TinyImageNetDataset(
            self.path,
            train=train,
            train_split=self.train_split,
            transform=selected_transform,
        )

    def _get_image_size(self):
        """Return the data shape as a ``(3, 64, 64)`` tuple."""
        return (3, 64, 64)

    def _get_classes(self):
        """Return the classes exposed by the training dataset."""
        return self.train_data.classes

    def _get_mean(self):
        """Return the per-channel mean used for the dataset."""
        # NOTE(review): these values look like the commonly quoted CIFAR-10
        # statistics - confirm they are intended for Tiny ImageNet.
        return (0.4914, 0.4822, 0.4465)

    def _get_std(self):
        """Return the per-channel standard deviation used for the dataset."""
        # NOTE(review): same concern as the mean - verify against the data.
        return (0.2023, 0.1994, 0.2010)
class TinyImageNetDataset(Dataset):
    """Create Tiny ImageNet Dataset.

    Images from the ``train`` and ``val`` directories are loaded into memory,
    pooled together, shuffled with a fixed seed and then divided into a
    training part and a validation part according to ``train_split``.

    Args:
        path (str): Path where dataset will be downloaded.
        train (:obj:`bool`, optional): True for training data. (default: True)
        train_split (:obj:`float`, optional): Fraction of dataset to assign
            for training. Must lie in ``[0, 1]``. (default: 0.7)
        download (:obj:`bool`, optional): If True, dataset will be
            downloaded. (default: True)
        random_seed (:obj:`int`, optional): Random seed value. This is
            required for splitting the data into training and validation
            datasets. (default: 1)
        transform (optional): Transformations to apply on the dataset.
            (default: None)

    Raises:
        ValueError: If ``train_split`` is outside ``[0, 1]``.
    """

    # Archive that unpacks into a ``tiny-imagenet-200`` directory.
    _URL = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

    def __init__(self, path, train=True, train_split=0.7, download=True, random_seed=1, transform=None):
        """Initializes the dataset for loading."""
        super(TinyImageNetDataset, self).__init__()

        self.path = path
        self.train = train
        self.train_split = train_split
        self.transform = transform
        self._validate_params()

        # Download dataset
        if download:
            self.download()

        self._class_ids = self._get_class_map()
        self.data, self.targets = self._load_data()

        # Shuffle with a private generator: it yields the same permutation as
        # seeding the global NumPy RNG, without resetting the caller's state.
        indices = np.arange(len(self.targets))
        np.random.RandomState(random_seed).shuffle(indices)

        split_idx = int(len(indices) * train_split)
        self._image_indices = indices[:split_idx] if train else indices[split_idx:]

    def __len__(self):
        """Returns length of the dataset."""
        return len(self._image_indices)

    def __getitem__(self, index):
        """Fetch an item from the dataset.

        Args:
            index (int): Index of the item to fetch.

        Returns:
            An image and its corresponding label.
        """
        image_index = self._image_indices[index]
        image = self.data[image_index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.targets[image_index]

    def __repr__(self):
        """Representation string for the dataset object."""
        head = 'Dataset TinyImageNet'
        body = ['Number of datapoints: {}'.format(self.__len__())]
        if self.path is not None:
            body.append('Root location: {}'.format(self.path))
        body += [f'Split: {"Train" if self.train else "Test"}']
        # The attribute set in __init__ is ``transform`` (singular).
        transform = getattr(self, 'transform', None)
        if transform is not None:
            body += [repr(transform)]
        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)

    def _validate_params(self):
        """Validate input parameters.

        Raises:
            ValueError: If ``train_split`` is not within ``[0, 1]``.
        """
        if not 0 <= self.train_split <= 1:
            raise ValueError('train_split must be between 0 and 1')

    @property
    def classes(self):
        """List of classes present in the dataset."""
        return tuple(x[1]['name'] for x in sorted(
            self._class_ids.items(), key=lambda y: y[1]['id']
        ))

    def _get_class_map(self):
        """Create a mapping from class id to the class name.

        Returns:
            dict: Maps every id listed in ``wnids.txt`` to a dict holding its
            ``name`` (taken from ``words.txt``) and an integer ``id`` assigned
            in the order the ids appear in ``words.txt``.
        """
        with open(os.path.join(self.path, 'wnids.txt')) as f:
            # strip() instead of slicing off the last character, so an id on
            # a final line without a trailing newline is not truncated.
            class_ids = {line.strip(): '' for line in f if line.strip()}

        with open(os.path.join(self.path, 'words.txt'), newline='') as f:
            class_id = 0
            for row in csv.reader(f, delimiter='\t'):
                if row and row[0] in class_ids:
                    class_ids[row[0]] = {
                        'name': row[1],
                        'id': class_id
                    }
                    class_id += 1

        return class_ids

    def _load_image(self, image_path):
        """Load an image from the dataset.

        Args:
            image_path (str): Path of the image.

        Returns:
            PIL object of the image, always in RGB mode.
        """
        # convert() forces the pixel data to be read and returns a detached
        # copy, so the file handle can be closed right away instead of
        # staying open for every image kept in memory. Grayscale images are
        # expanded to three identical channels.
        with Image.open(image_path) as image:
            return image.convert('RGB')

    def _load_data(self):
        """Fetch data from each data directory and store them in a list.

        Returns:
            tuple: ``(data, targets)`` where ``data`` is a list of images and
            ``targets`` the list of matching integer class ids.
        """
        data, targets = [], []

        # Fetch train dir images. Directory listings are sorted so the data
        # order, and therefore the seeded split, is reproducible.
        train_path = os.path.join(self.path, 'train')
        for class_dir in sorted(os.listdir(train_path)):
            train_images_path = os.path.join(train_path, class_dir, 'images')
            if not os.path.isdir(train_images_path):
                continue  # Skip stray files that are not class directories
            target = self._class_ids[class_dir]['id']
            for image in sorted(os.listdir(train_images_path)):
                if image.lower().endswith('.jpeg'):
                    data.append(
                        self._load_image(os.path.join(train_images_path, image))
                    )
                    targets.append(target)

        # Fetch val dir images
        val_path = os.path.join(self.path, 'val')
        val_images_path = os.path.join(val_path, 'images')
        with open(os.path.join(val_path, 'val_annotations.txt'), newline='') as f:
            for row in csv.reader(f, delimiter='\t'):
                if not row:
                    continue
                data.append(
                    self._load_image(os.path.join(val_images_path, row[0]))
                )
                targets.append(self._class_ids[row[1]]['id'])

        return data, targets

    def download(self):
        """Download the data if it does not exist.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if os.path.exists(self.path):
            print('Files already downloaded.')
            return

        print('Downloading dataset...')
        response = requests.get(self._URL, stream=True)
        # Fail loudly instead of handing an error page to zipfile.
        response.raise_for_status()

        parent_dir = os.path.dirname(self.path)
        with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
            zip_ref.extractall(parent_dir)

        # Move file to appropriate location
        os.rename(os.path.join(parent_dir, 'tiny-imagenet-200'), self.path)
        print('Done.')