
Source code for torchvision.datasets.voc

import os
import sys
import tarfile
import collections
from .vision import VisionDataset

# Python 2 ships a faster C implementation under a separate module name;
# on Python 3 the plain ElementTree module already uses the C accelerator.
# The ``else`` branch is required so that ``ET`` is bound on Python 3.
if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

from PIL import Image
from .utils import download_url, check_integrity, verify_str_arg

    '2012': {
        'url': '',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    '2011': {
        'url': '',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    '2010': {
        'url': '',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    '2009': {
        'url': '',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    '2008': {
        'url': '',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    '2007': {
        'url': '',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'

class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval``
            or ``val``. ``test`` is additionally accepted when ``year='2007'``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and
            its target as entry and returns a transformed version.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
        self.year = year
        year_info = DATASET_YEAR_DICT[year]
        self.url = year_info['url']
        self.filename = year_info['filename']
        self.md5 = year_info['md5']

        valid_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_sets)

        voc_root = os.path.join(self.root, year_info['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # The split file lists one image id (file stem) per line.
        splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
        split_f = os.path.join(splits_dir, self.image_set + '.txt')
        with open(split_f, "r") as f:
            # Skip blank lines so they do not turn into bogus ".jpg" paths.
            file_names = [line.strip() for line in f if line.strip()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert len(self.images) == len(self.masks)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is the RGB-converted JPEG and target
            is the segmentation mask opened from the matching PNG file.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)
class VOCDetection(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval``
            or ``val``. ``test`` is additionally accepted when ``year='2007'``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and
            its target as entry and returns a transformed version.
    """

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
        self.year = year
        year_info = DATASET_YEAR_DICT[year]
        self.url = year_info['url']
        self.filename = year_info['filename']
        self.md5 = year_info['md5']

        valid_sets = ["train", "trainval", "val"]
        if year == "2007":
            valid_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", valid_sets)

        voc_root = os.path.join(self.root, year_info['base_dir'])
        image_dir = os.path.join(voc_root, 'JPEGImages')
        annotation_dir = os.path.join(voc_root, 'Annotations')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # The split file lists one image id (file stem) per line.
        splits_dir = os.path.join(voc_root, 'ImageSets/Main')
        split_f = os.path.join(splits_dir, self.image_set + '.txt')
        with open(split_f, "r") as f:
            # Skip blank lines so they do not turn into bogus ".jpg" paths.
            file_names = [line.strip() for line in f if line.strip()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
        assert len(self.images) == len(self.annotations)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree
            (see :meth:`parse_voc_xml`).
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = self.parse_voc_xml(
            ET.parse(self.annotations[index]).getroot())

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.images)

    def parse_voc_xml(self, node):
        """Recursively convert an ElementTree node into nested dictionaries.

        A leaf element becomes ``{tag: stripped_text}``. An element with children
        becomes ``{tag: {child_tag: value, ...}}``. A child tag that occurs once
        maps to its single value, and a tag that repeats maps to a list of values.
        Consequently a tag such as ``object`` is a dict when it appears once and
        a list when it appears several times; callers should handle both forms.
        A leaf whose ``text`` is ``None`` yields ``{}`` and is therefore dropped
        from its parent.
        """
        voc_dict = {}
        children = list(node)
        if children:
            # Group every child's parsed dict by tag, preserving document order.
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            voc_dict = {
                node.tag:
                    {ind: v[0] if len(v) == 1 else v
                     for ind, v in def_dic.items()}
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict
def download_extract(url, root, filename, md5):
    """Download ``url`` into ``root/filename`` and extract the tar archive into ``root``.

    Args:
        url (string): Address of the tar archive.
        root (string): Directory the archive is saved to and extracted into.
        filename (string): Local file name of the archive inside ``root``.
        md5 (string): Checksum forwarded to ``download_url`` (presumably used to
            verify the download; confirm against ``download_url``).
    """
    download_url(url, root, filename, md5)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        # Where the interpreter supports extraction filters (PEP 706), use the
        # 'data' filter so members cannot write outside ``root`` via absolute
        # paths, ``..`` components or links; older interpreters fall back to
        # the unfiltered extraction used previously.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=root, filter='data')
        else:
            tar.extractall(path=root)


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources