robustness.train module

robustness.train.check_required_args(args, eval_only=False)

Check that the required training arguments are present.

Parameters:
  • args (argparse object) – the arguments to check
  • eval_only (bool) – whether to check only the arguments for evaluation
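
As an illustration of the kind of check this function performs, here is a pure-Python sketch. The helper check_required and both argument lists are hypothetical. The lists are assembled from the train_model() documentation, and the library's own lists may differ. The one rule taken directly from the docs is that custom_adv_loss is required whenever custom_train_loss is given.

```python
from argparse import Namespace

# Illustrative argument lists assembled from the train_model() docs;
# the library's own lists may differ (assumption).
TRAIN_REQUIRED = ["adv_train", "epochs", "lr", "weight_decay", "momentum",
                  "log_iters", "save_ckpt_iters"]
ADV_REQUIRED = ["attack_lr", "constraint", "eps", "attack_steps",
                "eps_fadein_epochs", "use_best", "random_restarts"]


def check_required(args, eval_only=False):
    """Hypothetical re-implementation: raise ValueError naming every
    required attribute that is missing (or None) on args."""
    # Assumption: evaluation alone needs no base training arguments.
    required = [] if eval_only else list(TRAIN_REQUIRED)
    # Attack settings are only required when an attack will actually run.
    if getattr(args, "adv_train", False) or getattr(args, "adv_eval", False):
        required += ADV_REQUIRED
    # A custom training loss also requires a custom adversarial loss.
    if getattr(args, "custom_train_loss", None) is not None:
        required.append("custom_adv_loss")
    missing = [name for name in required if getattr(args, name, None) is None]
    if missing:
        raise ValueError(f"Missing required arguments: {missing}")


# Usage: a standard (non-adversarial) training configuration passes.
check_required(Namespace(adv_train=0, epochs=10, lr=0.1, weight_decay=5e-4,
                         momentum=0.9, log_iters=5, save_ckpt_iters=-1))
```

Requesting adversarial evaluation without attack settings would raise a ValueError that lists eps, attack_steps, and the other attack arguments.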

robustness.train.make_optimizer_and_schedule(args, model, checkpoint)

Internal function (called directly from train_model).

Creates an optimizer and a schedule for a given model, restoring both from checkpoint if it is not None.

Parameters:
  • args (object) – an arguments object, see train_model() for details
  • model (AttackerModel) – the model to create the optimizer for
  • checkpoint (dict) – a loaded checkpoint saved by this library and loaded with ch.load
Returns:

An optimizer (ch.optim.Optimizer, where ch is torch) and a scheduler (from the ch.optim.lr_scheduler module).
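
The step_lr and custom_schedule options documented under train_model() determine this schedule. The pure-Python sketch below shows one plausible reading of them. The helper lr_at_epoch is hypothetical. Stepping the schedule once per epoch and treating custom_schedule as piecewise constant are both assumptions, not confirmed library behavior.

```python
import ast


def lr_at_epoch(base_lr, epoch, step_lr=None, custom_schedule=None):
    """Hypothetical helper returning the learning rate for a 0-indexed epoch.

    Assumes the scheduler is stepped once per epoch.
    """
    if custom_schedule is not None:
        # Accept the documented string form, e.g. "[(0, 0.1), (50, 0.01)]".
        if isinstance(custom_schedule, str):
            custom_schedule = ast.literal_eval(custom_schedule)
        # Assumption: piecewise-constant schedule; use the LR of the latest
        # milestone reached, falling back to base_lr before the first one.
        lr = base_lr
        for milestone, milestone_lr in sorted(custom_schedule):
            if epoch >= milestone:
                lr = milestone_lr
        return lr
    if step_lr:
        # Drop the LR by 10x every step_lr epochs.
        return base_lr * 0.1 ** (epoch // step_lr)
    return base_lr
```

To drive a PyTorch optimizer with such a function, you could use torch.optim.lr_scheduler.LambdaLR(optimizer, lambda ep: lr_at_epoch(args.lr, ep, step_lr=args.step_lr) / args.lr). This works because LambdaLR multiplies the initial learning rate by the returned factor. The step_lr-only case corresponds to StepLR(optimizer, step_size=step_lr, gamma=0.1).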

robustness.train.eval_model(args, model, loader, store)

Evaluate a model for standard (and optionally adversarial) accuracy.

Parameters:
  • args (object) – An arguments object; should be a Python object implementing getattr() and setattr().
  • model (AttackerModel) – model to evaluate
  • loader (iterable) – a dataloader serving (input, label) batches from the validation set
  • store (cox.Store) – store in which to save results (via tensorboardX)
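
For intuition, the standard (top-1) accuracy that eval_model reports can be sketched in plain Python. This is a hypothetical illustration, not the library's code. The model is assumed to be any callable that returns one list of class scores per input. Adversarial accuracy would be the same metric computed on attacked inputs instead of clean ones.

```python
def top1_accuracy(model, loader):
    """Percentage of examples whose highest-scoring class equals the label.

    model: callable mapping a batch of inputs to a list of per-class scores
    loader: iterable of (inputs, labels) batches, as eval_model expects
    """
    correct = total = 0
    for inputs, labels in loader:
        scores = model(inputs)
        for row, label in zip(scores, labels):
            # Index of the largest score is the predicted class.
            pred = max(range(len(row)), key=row.__getitem__)
            correct += int(pred == label)
            total += 1
    return 100.0 * correct / total if total else 0.0


# Toy usage: class 1 wins whenever the input is positive.
toy_model = lambda xs: [[-x, x] for x in xs]
loader = [([1.0, -2.0], [1, 0]), ([3.0, 0.5], [0, 1])]
print(top1_accuracy(toy_model, loader))  # 3 of 4 correct -> 75.0
```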

robustness.train.train_model(args, model, loaders, *, checkpoint=None, store=None)

Main function for training a model.

Parameters:
  • args (object) –

    A Python object for arguments, implementing getattr() and setattr() and having the following attributes. See robustness.defaults.TRAINING_ARGS for a list of arguments; use robustness.defaults.check_and_fill_args() to make sure that all required arguments are present and to fill in missing arguments with reasonable defaults:

    adv_train (int or bool, required)
    If 1/True, train adversarially; if 0/False, do standard training
    epochs (int, required)
    number of epochs to train for
    lr (float, required)
    learning rate for the SGD optimizer
    weight_decay (float, required)
    weight decay for the SGD optimizer
    momentum (float, required)
    momentum parameter for the SGD optimizer
    step_lr (int)
    If given, drop the learning rate by 10x every step_lr steps
    custom_schedule (str)
    If given, use a custom LR schedule (format: [(epoch, LR),…])
    adv_eval (int or bool)
    If True/1, then also do adversarial evaluation, otherwise skip (ignored if adv_train is True)
    log_iters (int, required)
    How frequently (in epochs) to save training logs
    save_ckpt_iters (int, required)
    How frequently (in epochs) to save checkpoints (if -1, then only save latest and best ckpts)
    attack_lr (float or str, required if adv_train or adv_eval)
    float (or float-parseable string) for the adv attack step size
    constraint (str, required if adv_train or adv_eval)
    the type of adversary constraint (robustness.attacker.STEPS)
    eps (float or str, required if adv_train or adv_eval)
    float (or float-parseable string) for the adv attack budget
    attack_steps (int, required if adv_train or adv_eval)
    number of steps to take in adv attack
    eps_fadein_epochs (int, required if adv_train or adv_eval)
    If greater than 0, fade in epsilon over this many epochs
    use_best (int or bool, required if adv_train or adv_eval)
    If True/1, use the best (in terms of loss) PGD step as the attack, if False/0 use the last step
    random_restarts (int, required if adv_train or adv_eval)
    Number of random restarts to use for adversarial evaluation
    custom_train_loss (function, optional)
    If given, a custom loss instead of the default CrossEntropyLoss. Takes in (logits, targets) and returns a scalar.
    custom_adv_loss (function, required if custom_train_loss)
    If given, a custom loss function for the adversary. The custom loss function takes in model, input, target and should return a vector representing the loss for each element of the batch, as well as the classifier output.
    regularizer (function, optional)
    If given, this function of model, input, target returns a scalar that is added to the training loss without being subject to adversarial attack
    iteration_hook (function, optional)
    If given, this function is called every training iteration by the training loop (useful for custom logging). The function is given arguments model, iteration #, loop_type [train/eval], current_batch_ims, current_batch_labels.
    epoch_hook (function, optional)
    Similar to iteration_hook but called every epoch instead, and given arguments model, log_info where log_info is a dictionary with keys epoch, nat_prec1, adv_prec1, nat_loss, adv_loss, train_prec1, train_loss.
  • model (AttackerModel) – the model to train.
  • loaders (tuple[iterable]) – tuple of data loaders of the form (train_loader, val_loader)
  • checkpoint (dict) – a loaded checkpoint previously saved by this library (if resuming from checkpoint)
  • store (cox.Store) – a cox store for logging training progress
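
The sketch below shows hook functions with the documented iteration_hook and epoch_hook signatures. They are attached to a plain args object via setattr-style assignment. The driver toy_loop is hypothetical and only shows when the hooks are called; it does no training. Several details are illustrative assumptions: the "train" loop_type value, the SimpleNamespace args object, and the placeholder metric values.

```python
from types import SimpleNamespace

records = []


def iteration_hook(model, i, loop_type, inp, target):
    # Called every iteration; here it just records the loop type and index.
    records.append((loop_type, i))


def epoch_hook(model, log_info):
    # Called every epoch with the log_info dictionary described above.
    records.append(("epoch", log_info["epoch"]))


def toy_loop(args, model, train_loader):
    """Hypothetical stand-in for train_model's loop: it only calls the hooks."""
    for epoch in range(args.epochs):
        for i, (inp, target) in enumerate(train_loader):
            # A real loop would run the forward/backward pass here.
            hook = getattr(args, "iteration_hook", None)
            if hook is not None:
                hook(model, i, "train", inp, target)
        hook = getattr(args, "epoch_hook", None)
        if hook is not None:
            # Placeholder zeros; a real loop would report measured metrics.
            keys = ["nat_prec1", "adv_prec1", "nat_loss", "adv_loss",
                    "train_prec1", "train_loss"]
            log_info = {"epoch": epoch, **{k: 0.0 for k in keys}}
            hook(model, log_info)


# Any object supporting getattr()/setattr() works as args in this sketch.
args = SimpleNamespace(epochs=2)
args.iteration_hook = iteration_hook
args.epoch_hook = epoch_hook
toy_loop(args, model=None, train_loader=[("x0", "y0"), ("x1", "y1")])
```

With two batches and two epochs, records holds two iteration entries followed by one epoch entry for each epoch.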