robustness.datasets module

Module containing all the supported datasets, which are subclasses of the abstract class robustness.datasets.DataSet.

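Currently supported datasets:

  • ImageNet (robustness.datasets.ImageNet)
  • RestrictedImageNet (robustness.datasets.RestrictedImageNet)
  • CIFAR-10 (robustness.datasets.CIFAR)
  • CINIC-10 (robustness.datasets.CINIC)
  • A2B: Horse <-> Zebra, Apple <-> Orange, Summer <-> Winter (robustness.datasets.A2B)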

Using robustness as a general training library (Part 2: Customizing training) shows how to add custom datasets to the library.

class robustness.datasets.DataSet(ds_name, data_path, **kwargs)

Bases: object

Base class for representing a dataset. Meant to be subclassed, with subclasses implementing the get_model function.

Parameters:
  • ds_name (str) – string identifier for the dataset
  • data_path (str) – path to the dataset
  • num_classes (int) – required kwarg, the number of classes in the dataset
  • mean (ch.tensor) – required kwarg, the mean to normalize the dataset with (e.g. ch.tensor([0.4914, 0.4822, 0.4465]) for CIFAR-10)
  • std (ch.tensor) – required kwarg, the standard deviation to normalize the dataset with (e.g. ch.tensor([0.2023, 0.1994, 0.2010]) for CIFAR-10)
  • custom_class (type) – required kwarg, a torchvision.models class corresponding to the dataset, if it exists (otherwise None)
  • label_mapping (dict[int,str]) – required kwarg, a dictionary mapping from class numbers to human-interpretable class names (can be None)
  • transform_train (torchvision.transforms) – required kwarg, transforms to apply to the training images from the dataset
  • transform_test (torchvision.transforms) – required kwarg, transforms to apply to the validation images from the dataset
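To make the "required kwarg" contract concrete, here is a minimal pure-Python sketch of a base class that refuses to build without every required kwarg, plus a subclass that supplies them itself. SketchDataSet, MyCIFARLike and REQUIRED_KWARGS are hypothetical names, not the library's code; plain lists stand in for ch.tensor and None stands in for the torchvision transforms.

```python
# Hypothetical illustration of the kwargs contract; NOT robustness's implementation.
REQUIRED_KWARGS = (
    "num_classes", "mean", "std", "custom_class",
    "label_mapping", "transform_train", "transform_test",
)


class SketchDataSet:
    """Store ds_name/data_path and insist on every required kwarg."""

    def __init__(self, ds_name, data_path, **kwargs):
        missing = [k for k in REQUIRED_KWARGS if k not in kwargs]
        if missing:
            raise ValueError(f"missing required kwargs: {missing}")
        self.ds_name = ds_name
        self.data_path = data_path
        # Expose each required kwarg as an attribute (e.g. self.num_classes).
        for key in REQUIRED_KWARGS:
            setattr(self, key, kwargs[key])


class MyCIFARLike(SketchDataSet):
    """Hypothetical subclass that fills in the dataset-specific kwargs itself."""

    def __init__(self, data_path, **kwargs):
        ds_kwargs = {
            "num_classes": 10,
            "mean": [0.4914, 0.4822, 0.4465],  # a ch.tensor in the real library
            "std": [0.2023, 0.1994, 0.2010],
            "custom_class": None,               # no matching torchvision.models class
            "label_mapping": None,              # allowed to be None
            "transform_train": None,            # would be a torchvision transform
            "transform_test": None,
        }
        ds_kwargs.update(kwargs)  # let callers override any default
        super().__init__("my_cifar_like", data_path, **ds_kwargs)
```

With this pattern, MyCIFARLike('/tmp/') succeeds, while SketchDataSet('x', '/tmp/') raises ValueError because none of the required kwargs were passed.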
get_model(arch, pretrained)

Should be overridden by subclasses. You will probably never need to call this function directly; use model_utils.make_and_restore_model instead.

Parameters:
  • arch (str) – name of architecture
  • pretrained (bool) – whether to try to load torchvision pretrained checkpoint
Returns:

A model with the given architecture that works for each dataset (e.g. with the right input/output dimensions).
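The override pattern can be sketched in plain Python: the base method raises, and a subclass builds a model whose output size matches its own number of classes. TinyNet, ARCHITECTURES and both dataset classes below are hypothetical stand-ins (TinyNet is not a torch.nn.Module), so this shows the shape of the contract rather than how robustness builds its models.

```python
class TinyNet:
    """Stand-in for a torch.nn.Module; it only records how it was built."""

    def __init__(self, num_classes, pretrained=False):
        self.num_classes = num_classes  # size of the (imaginary) output layer
        self.pretrained = pretrained


# Hypothetical architecture registry: arch name -> model constructor.
ARCHITECTURES = {"tinynet": TinyNet}


class BaseDataSetSketch:
    """Base class: get_model must be overridden by subclasses."""

    num_classes = None

    def get_model(self, arch, pretrained):
        raise NotImplementedError("subclasses must implement get_model")


class TenClassDataSetSketch(BaseDataSetSketch):
    num_classes = 10

    def get_model(self, arch, pretrained):
        if arch not in ARCHITECTURES:
            raise ValueError(f"unknown arch {arch!r}; choose from {sorted(ARCHITECTURES)}")
        # Size the output layer to this dataset's number of classes.
        return ARCHITECTURES[arch](num_classes=self.num_classes, pretrained=pretrained)
```

Keeping the dataset responsible for the output dimension is what lets the same architecture name work across datasets with different class counts.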

make_loaders(workers, batch_size, data_aug=True, subset=None, subset_start=0, subset_type='rand', val_batch_size=None, only_val=False)
Parameters:
  • workers (int) – number of workers for data fetching (required).
  • batch_size (int) – batch size for the data loaders (required).
  • data_aug (bool) – whether or not to do train data augmentation.
  • subset (None|int) – if given, the returned training data loader will only use a subset of the training data; this should be a number specifying the number of training data points to use.
  • subset_start (int) – only used if subset is not None; this specifies the starting index of the subset.
  • subset_type ("rand"|"first"|"last") – only used if subset is not None; "rand" selects the subset randomly, "first" uses the first subset images of the training data, and "last" uses the last subset images of the training data.
  • val_batch_size (None|int) – if not None, specifies a different batch size for the validation set loader.
  • only_val (bool) – If True, returns None in place of the training data loader
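The interaction of subset, subset_start and subset_type can be illustrated by computing which training indices would be kept. This is a pure-Python sketch under stated assumptions: select_subset and its seed parameter are hypothetical, subset_start is applied as an offset for "first" and "rand" and ignored for "last", and the library's exact sampling may differ.

```python
import random


def select_subset(num_train, subset=None, subset_start=0, subset_type="rand", seed=0):
    """Return the training indices a subset-restricted loader would draw from.

    Sketch only: `seed` is a hypothetical parameter, and the real library's
    handling of subset_start may differ from what is shown here.
    """
    all_idx = list(range(num_train))
    if subset is None:
        return all_idx  # no subsetting: use the whole training set
    if subset_type == "rand":
        # Draw subset_start + subset distinct indices, then skip the first subset_start.
        rng = random.Random(seed)
        drawn = rng.sample(all_idx, subset_start + subset)
        return drawn[subset_start:]
    if subset_type == "first":
        return all_idx[subset_start:subset_start + subset]
    if subset_type == "last":
        return all_idx[num_train - subset:]
    raise ValueError(f"unknown subset_type {subset_type!r}")
```

For example, with 10 training points and subset=3, "first" keeps indices 0-2 and "last" keeps indices 7-9; a fixed seed makes "rand" reproducible across runs.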
Returns:

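A training loader and a validation loader, configured according to the given parameters. These are standard PyTorch data loaders, so they can be iterated over directly:

>>> from robustness import datasets
>>> ds = datasets.CIFAR('/path/to/cifar')
>>> train_loader, val_loader = ds.make_loaders(workers=8, batch_size=128)
>>> for im, lab in train_loader:
...     pass  # Do stuff with each batch of images and labels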

class robustness.datasets.ImageNet(data_path, **kwargs)

Bases: robustness.datasets.DataSet

ImageNet Dataset [DDS+09].

Requires ImageNet in ImageFolder-readable format. ImageNet can be downloaded from http://www.image-net.org. See the torchvision ImageFolder documentation for more information about the format.

[DDS+09]Deng, J., Dong, W., Socher, R., Li, L., Li, K., & Fei-Fei, L. (2009). ImageNet: A large-scale hierarchical image database. 2009 IEEE Conference on Computer Vision and Pattern Recognition, 248-255.
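ImageFolder treats each subdirectory of a split as one class and assigns class indices in sorted order of the directory names. The sketch below reproduces that discovery with the standard library and sanity-checks a dataset root. The data_path/train and data_path/val layout is an assumption about what this class expects, and check_imagenet_root is a hypothetical helper, not part of robustness.

```python
import os


def imagefolder_classes(split_dir):
    """Mimic ImageFolder's class discovery: one subdirectory per class,
    sorted by name, indexed from 0 in that order."""
    with os.scandir(split_dir) as entries:
        classes = sorted(e.name for e in entries if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def check_imagenet_root(data_path):
    """Check an assumed layout data_path/{train,val}/<class>/<images>
    and confirm both splits expose the same class folders."""
    splits = {}
    for split in ("train", "val"):
        split_dir = os.path.join(data_path, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"expected directory {split_dir}")
        splits[split] = imagefolder_classes(split_dir)
    if splits["train"] != splits["val"]:
        raise ValueError("train and val have different class folders")
    return splits["train"]
```

Because indices follow sorted folder names, renaming or adding a class folder silently shifts every label after it, so both splits must contain exactly the same folders.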
get_model(arch, pretrained)
class robustness.datasets.RestrictedImageNet(data_path, **kwargs)

Bases: robustness.datasets.DataSet

RestrictedImageNet Dataset [TSE+19].

A subset of ImageNet with the following labels:

  • Dog (classes 151-268)
  • Cat (classes 281-285)
  • Frog (classes 30-32)
  • Turtle (classes 33-37)
  • Bird (classes 80-100)
  • Monkey (classes 365-382)
  • Fish (classes 389-397)
  • Crab (classes 118-121)
  • Insect (classes 300-319)

To initialize, just provide the path to the full ImageNet dataset (no special formatting required).

[TSE+19]Tsipras, D., Santurkar, S., Engstrom, L., Turner, A., & Madry, A. (2019). Robustness May Be at Odds with Accuracy. ICLR 2019.
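The label ranges above amount to a many-to-one mapping from ImageNet class indices to nine coarse labels. This sketch assumes the ranges are inclusive (matching the "classes X-Y" wording) and that the coarse label index follows the order of the list; restricted_label is a hypothetical helper, not the library's function.

```python
# (coarse name, first ImageNet class, last ImageNet class), in the order listed above.
RESTRICTED_RANGES = [
    ("Dog", 151, 268),
    ("Cat", 281, 285),
    ("Frog", 30, 32),
    ("Turtle", 33, 37),
    ("Bird", 80, 100),
    ("Monkey", 365, 382),
    ("Fish", 389, 397),
    ("Crab", 118, 121),
    ("Insect", 300, 319),
]


def restricted_label(imagenet_label):
    """Map an ImageNet class index to a coarse label, or None if the class is unused."""
    for new_label, (_, low, high) in enumerate(RESTRICTED_RANGES):
        if low <= imagenet_label <= high:  # ranges treated as inclusive
            return new_label
    return None
```

For example, ImageNet class 200 falls in the Dog range and maps to 0, class 31 maps to Frog (2), and class 500 lies outside every range, so it is dropped (None).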
get_model(arch, pretrained)
class robustness.datasets.CIFAR(data_path='/tmp/', **kwargs)

Bases: robustness.datasets.DataSet

CIFAR-10 dataset [Kri09].

A dataset with 50k training images and 10k testing images, with the following classes:

  • Airplane
  • Automobile
  • Bird
  • Cat
  • Deer
  • Dog
  • Frog
  • Horse
  • Ship
  • Truck
[Kri09]Krizhevsky, A. (2009). Learning Multiple Layers of Features from Tiny Images. Technical Report.
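Two DataSet kwargs can be illustrated with CIFAR-10 in plain Python: a label_mapping dictionary in the standard CIFAR-10 class order (the order listed above), and the per-channel (x - mean) / std normalization using the mean and std values quoted in the DataSet parameter list. This is an arithmetic sketch on a single RGB pixel, not the library's normalization code.

```python
# Class index -> name, in the standard CIFAR-10 order.
CIFAR_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                 "dog", "frog", "horse", "ship", "truck"]
label_mapping = {i: name for i, name in enumerate(CIFAR_CLASSES)}

# Per-channel (R, G, B) statistics quoted in the DataSet parameter list.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb):
    """Apply (x - mean) / std channel-wise to one RGB pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))
```

A pixel equal to the mean normalizes to (0.0, 0.0, 0.0), and a pure-white pixel's red channel becomes (1 - 0.4914) / 0.2023, roughly 2.51.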
get_model(arch, pretrained)
class robustness.datasets.CINIC(data_path, **kwargs)

Bases: robustness.datasets.DataSet

CINIC-10 dataset [DCA+18].

A dataset with the same classes as CIFAR-10, but with downscaled images from various matching ImageNet classes added in to increase the size of the dataset.

[DCA+18]Darlow, L.N., Crowley, E.J., Antoniou, A., & Storkey, A.J. (2018). CINIC-10 is not ImageNet or CIFAR-10. Report EDI-INF-ANC-1802 (arXiv:1810.03505).
get_model(arch, pretrained)
class robustness.datasets.A2B(data_path, **kwargs)

Bases: robustness.datasets.DataSet

A-to-B datasets [ZPI+17]

A general class for image-to-image translation datasets. Currently supported are:

  • Horse <-> Zebra
  • Apple <-> Orange
  • Summer <-> Winter
[ZPI+17]Zhu, J., Park, T., Isola, P., & Efros, A.A. (2017). Unpaired Image-to-Image Translation Using Cycle-Consistent Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV), 2242-2251.
get_model(arch, pretrained=False)
robustness.datasets.DATASETS = {'a2b': <class 'robustness.datasets.A2B'>, 'cifar': <class 'robustness.datasets.CIFAR'>, 'cinic': <class 'robustness.datasets.CINIC'>, 'imagenet': <class 'robustness.datasets.ImageNet'>, 'restricted_imagenet': <class 'robustness.datasets.RestrictedImageNet'>}
Dictionary of datasets. A dataset class can be accessed as:

>>> from robustness import datasets
>>> ds = datasets.DATASETS['cifar']('/path/to/cifar')
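When a dataset name comes from user input (e.g. a command-line flag), a lookup in a dictionary like DATASETS can fail with a bare KeyError. The sketch below wraps such a lookup in a friendlier error. CIFARStub, REGISTRY and make_dataset are hypothetical stand-ins used only to keep the example self-contained; with the real library you would pass datasets.DATASETS as the registry.

```python
class CIFARStub:
    """Placeholder dataset class that just records the path it was given."""

    def __init__(self, data_path='/tmp/', **kwargs):
        self.data_path = data_path


# Hypothetical registry with the same shape as DATASETS: name -> class.
REGISTRY = {"cifar": CIFARStub}


def make_dataset(name, data_path, registry=REGISTRY, **kwargs):
    """Instantiate registry[name](data_path, **kwargs) with a readable error."""
    try:
        ds_class = registry[name]
    except KeyError:
        raise ValueError(f"unknown dataset {name!r}; available: {sorted(registry)}") from None
    return ds_class(data_path, **kwargs)
```

Listing the available keys in the error message makes typos such as 'cifar10' for 'cifar' easy to spot.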