Source code for hyperion.torch.adv_attacks.rand_fgsm_attack

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
import torch

from .adv_attack import AdvAttack


class RandFGSMAttack(AdvAttack):
    """Randomized Fast Gradient Sign Method attack (L-infinity threat model).

    The attack has two steps:

    1. A random step of size ``alpha`` along ``sign(N(0, I))``.
    2. A single gradient-sign step of size ``eps - alpha`` from that
       randomized point.

    Before clamping, the total perturbation therefore has L-infinity norm
    at most ``eps``.

    Args:
      model: Model under attack; called as ``model(x)``.
      eps: Total perturbation budget. Must be strictly greater than ``alpha``.
      alpha: Size of the initial random step.
      loss: Loss forwarded to ``AdvAttack``; ``generate`` calls it as
        ``self.loss(output, target)``. (Presumably the base class picks a
        default when None -- confirm in ``AdvAttack``.)
      targeted: If True, step so as to decrease the loss w.r.t. ``target``
        instead of increasing it.
      range_min: Forwarded to ``AdvAttack``; presumably the lower bound
        applied by ``_clamp`` -- confirm in the base class.
      range_max: Forwarded to ``AdvAttack``; presumably the upper bound
        applied by ``_clamp`` -- confirm in the base class.
    """

    def __init__(
        self,
        model,
        eps,
        alpha,
        loss=None,
        targeted=False,
        range_min=None,
        range_max=None,
    ):
        super().__init__(model, loss, targeted, range_min, range_max)
        # The gradient step uses eps - alpha, so alpha must leave a
        # positive remaining budget.
        assert alpha < eps, "alpha({}) >= eps({})".format(alpha, eps)
        self.eps = eps
        self.alpha = alpha

    @property
    def attack_info(self):
        """Return the base attack info extended with this attack's settings."""
        info = super().attack_info
        info.update(
            {
                "eps": self.eps,
                "alpha": self.alpha,
                "threat_model": "linf",
                "attack_type": "rand-fgsm",
            }
        )
        return info

    def generate(self, input, target):
        """Return adversarial examples for ``input``.

        Side effect: this calls ``self.model.zero_grad()`` and then
        ``loss.backward()``, so gradients from this attack are left in the
        model's parameter ``.grad`` fields.

        Args:
          input: Clean input tensor fed to ``self.model``.
          target: Second argument passed to ``self.loss``.

        Returns:
          The perturbed tensor after ``self._clamp``. The perturbation
          itself is computed without an autograd graph.
        """
        # Random step alpha * sign(N(0, I)). Detach so that x is always a
        # fresh leaf: if `input` requires grad, `input + noise` is a
        # non-leaf, and setting requires_grad on it would raise (and its
        # .grad would never be populated).
        x = (input + self.alpha * torch.randn_like(input).sign()).detach()
        x.requires_grad_(True)

        output = self.model(x)
        loss = self.loss(output, target)
        self.model.zero_grad()
        loss.backward()

        # Untargeted attacks ascend the loss; targeted attacks descend it
        # toward `target`.
        direction = -1 if self.targeted else 1
        with torch.no_grad():
            adv_ex = x + direction * (self.eps - self.alpha) * x.grad.sign()
        return self._clamp(adv_ex)