
Source code for torchvision.datasets.mnist

from __future__ import print_function
from .vision import VisionDataset
import warnings
from PIL import Image
import os
import os.path
import gzip
import numpy as np
import torch
import codecs
from .utils import download_url, makedir_exist_ok


class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(MNIST, self).__init__(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder, self.test_file)))

    @staticmethod
    def extract_gzip(gzip_path, remove_finished=False):
        print('Extracting {}'.format(gzip_path))
        with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
                gzip.GzipFile(gzip_path) as zip_f:
            out_f.write(zip_f.read())
        if remove_finished:
            os.unlink(gzip_path)

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            download_url(url, root=self.raw_folder, filename=filename, md5=None)
            self.extract_gzip(gzip_path=file_path, remove_finished=True)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
            and ``Fashion-MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt',
               'Sneaker', 'Bag', 'Ankle boot']
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/processed/training.pt``
            and ``KMNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
    ]
    classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']
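
Because FashionMNIST and KMNIST only override ``urls`` and ``classes``, they are drop-in replacements for MNIST. A short sketch under the same assumptions as the MNIST example above (arbitrary ``./data`` root):

# FashionMNIST and KMNIST reuse the MNIST loading logic; only URLs and class names differ
fashion = datasets.FashionMNIST(root='./data', train=False, download=True)
# targets is a tensor of class indices; classes maps an index to its name
print(fashion.classes[int(fashion.targets[0])])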
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
            and ``EMNIST/processed/test.pt`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
    url = 'https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.split = split
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    @staticmethod
    def _training_file(split):
        return 'training_{}.pt'.format(split)

    @staticmethod
    def _test_file(split):
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already."""
        import shutil
        import zipfile

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(self.raw_folder, filename)
        download_url(self.url, root=self.raw_folder, filename=filename, md5=None)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(self.raw_folder)
        os.unlink(file_path)
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
            )
            test_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
            )
            with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
                torch.save(training_set, f)
            with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
                torch.save(test_set, f)
        shutil.rmtree(gzip_folder)

        print('Done!')
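
EMNIST additionally requires a ``split`` argument, and its ``download()`` processes all six splits the first time it runs, which can take a while. A minimal sketch, assuming the ``letters`` split and the same arbitrary ``./data`` root as above:

# 'letters' is chosen purely for illustration; any name in EMNIST.splits works
emnist = datasets.EMNIST(root='./data', split='letters', train=True, download=True)
img, target = emnist[0]   # a PIL image and an int label within the chosen split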
def get_int(b):
    # interpret a byte string as a big-endian integer (IDX headers are big-endian)
    return int(codecs.encode(b, 'hex'), 16)


def read_label_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2049  # magic number for IDX1 label files
        length = get_int(data[4:8])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
        return torch.from_numpy(parsed).view(length).long()


def read_image_file(path):
    with open(path, 'rb') as f:
        data = f.read()
        assert get_int(data[:4]) == 2051  # magic number for IDX3 image files
        length = get_int(data[4:8])
        num_rows = get_int(data[8:12])
        num_cols = get_int(data[12:16])
        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
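
For illustration (not part of the module): each IDX header field is a 4-byte big-endian integer; the third byte of the magic number encodes the element type (0x08 = unsigned byte) and the fourth the number of dimensions, so the image magic ``0x00000803`` decodes to 2051 and the label magic ``0x00000801`` to 2049:

assert get_int(b'\x00\x00\x08\x03') == 2051  # images: unsigned bytes, 3 dimensions
assert get_int(b'\x00\x00\x08\x01') == 2049  # labels: unsigned bytes, 1 dimension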
