Check that the required training arguments are present.
- args (argparse object) – the arguments to check
- eval_only (bool) – whether to check only the arguments for evaluation
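The behavior of such a check can be sketched in a few lines. The helper name and the exact required-attribute list below are illustrative, not the library's internals; only the attribute names themselves come from the documented training arguments:

```python
# Hypothetical sketch of a required-argument check. The attribute names
# mirror documented training arguments, but this helper is illustrative.
REQUIRED_TRAIN_ARGS = ["adv_train", "epochs", "lr", "weight_decay", "momentum"]

def check_required(args, eval_only=False):
    # Evaluation-only runs do not need the optimizer hyperparameters.
    required = [] if eval_only else REQUIRED_TRAIN_ARGS
    missing = [name for name in required if getattr(args, name, None) is None]
    if missing:
        raise ValueError(f"Missing required args: {missing}")
```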
make_optimizer_and_schedule(args, model, checkpoint)
Internal Function (called directly from train_model)
Creates an optimizer and a learning rate schedule for a given model; if checkpoint is not None, the optimizer and schedule state are restored from it.
- An optimizer (ch.optim.Optimizer) and a learning rate scheduler
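The optimizer itself is a torch SGD optimizer built from the lr, momentum, and weight_decay arguments documented under train_model. The helper below only sketches the documented step_lr decay rule (drop the learning rate by 10x every step_lr epochs); it is an illustration, not the library's scheduler:

```python
def step_lr_schedule(base_lr, step_lr, epoch):
    # Drop the learning rate by 10x every `step_lr` epochs, as documented
    # for the `step_lr` training argument. Illustrative sketch only.
    return base_lr * (0.1 ** (epoch // step_lr))
```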
eval_model(args, model, loader, store)
Evaluate a model for standard (and optionally adversarial) accuracy.
- args (object) – an arguments object; any Python object with the required attributes (see train_model)
- model (AttackerModel) – model to evaluate
- loader (iterable) – a dataloader serving (input, label) batches from the validation set
- store (cox.Store) – store for saving results in (via tensorboardX)
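The standard-accuracy half of this evaluation reduces to iterating the loader and comparing predictions with labels. The sketch below assumes a generic `predict` callable in place of the AttackerModel, and omits the adversarial evaluation entirely:

```python
def standard_accuracy(predict, loader):
    # Minimal sketch of standard-accuracy evaluation: `predict` maps a
    # batch of inputs to predicted labels, and `loader` yields
    # (input, label) batches as described above. Illustrative only.
    correct = total = 0
    for inputs, labels in loader:
        preds = predict(inputs)
        correct += sum(p == y for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total
```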
train_model(args, model, loaders, *, checkpoint=None, store=None)
Main function for training a model.
- args (object) –
A Python object for arguments, implementing
setattr() and having the following attributes. See
robustness.defaults.TRAINING_ARGS for the full list of arguments; you can use
robustness.defaults.check_and_fill_args() to make sure that all required arguments are present and to fill missing args with reasonable defaults:
- adv_train (int or bool, required)
- if 1/True, train adversarially; if 0/False, do standard training
- epochs (int, required)
- number of epochs to train for
- lr (float, required)
- learning rate for SGD optimizer
- weight_decay (float, required)
- weight decay for SGD optimizer
- momentum (float, required)
- momentum parameter for SGD optimizer
- step_lr (int)
- if given, drop learning rate by 10x every step_lr steps
- custom_schedule (str)
- If given, use a custom LR schedule (format: [(epoch, LR),…])
- adv_eval (int or bool)
- If True/1, then also do adversarial evaluation, otherwise skip (ignored if adv_train is True)
- log_iters (int, required)
- How frequently (in epochs) to save training logs
- save_ckpt_iters (int, required)
- How frequently (in epochs) to save checkpoints (if -1, then only save latest and best ckpts)
- attack_lr (float or str, required if adv_train or adv_eval)
- float (or float-parseable string) for the adv attack step size
- constraint (str, required if adv_train or adv_eval)
- the type of adversary constraint
- eps (float or str, required if adv_train or adv_eval)
- float (or float-parseable string) for the adv attack budget
- attack_steps (int, required if adv_train or adv_eval)
- number of steps to take in adv attack
- eps_fadein_epochs (int, required if adv_train or adv_eval)
- If greater than 0, fade in epsilon over this many epochs
- use_best (int or bool, required if adv_train or adv_eval)
- If True/1, use the best (in terms of loss) PGD step as the attack, if False/0 use the last step
- random_restarts (int, required if adv_train or adv_eval)
- Number of random restarts to use for adversarial evaluation
- custom_train_loss (function, optional)
- If given, a custom loss instead of the default CrossEntropyLoss. Takes in (logits, targets) and returns a scalar.
- custom_adv_loss (function, required if custom_train_loss)
- If given, a custom loss function for the adversary. The custom loss function takes in model, input, target and should return a vector representing the loss for each element of the batch, as well as the classifier output.
- regularizer (function, optional)
- If given, this function of model, input, target returns a scalar that is added to the training loss and is not subject to adversarial attack
- iteration_hook (function, optional)
- If given, this function is called every training iteration by the training loop (useful for custom logging). The function is given arguments model, iteration #, loop_type [train/eval], current_batch_ims, current_batch_labels.
- epoch_hook (function, optional)
- Similar to iteration_hook but called every epoch instead, and given arguments model, log_info where log_info is a dictionary with keys epoch, nat_prec1, adv_prec1, nat_loss, adv_loss, train_prec1, train_loss.
- model (AttackerModel) – the model to train.
- loaders (tuple[iterable]) – tuple of data loaders of the form (train_loader, val_loader)
- checkpoint (dict) – a loaded checkpoint previously saved by this library (if resuming from checkpoint)
- store (cox.Store) – a cox store for logging training progress
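A minimal args object for standard (non-adversarial) training can be assembled from the attribute list above; the attribute names below come from that list, while the values are illustrative. The custom_schedule string format can be parsed with the standard library, as shown at the end:

```python
import ast
from argparse import Namespace

# Minimal args for standard training; attribute names follow the list
# above, values are illustrative.
train_args = Namespace(
    adv_train=0,        # 0/False: standard (non-adversarial) training
    epochs=150,
    lr=0.1,
    weight_decay=5e-4,
    momentum=0.9,
    step_lr=50,         # drop LR by 10x every 50 epochs
    log_iters=5,
    save_ckpt_iters=-1, # only save latest and best checkpoints
)

# The documented custom_schedule format, [(epoch, LR), ...], given as a
# string, can be parsed safely with ast.literal_eval:
schedule = ast.literal_eval("[(0, 0.1), (100, 0.01)]")
```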