"""
Copyright 2020 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from .adv_attack import AdvAttack
class CarliniWagner(AdvAttack):
    """Base class for Carlini-Wagner style adversarial attacks.

    Stores the attack hyper-parameters and provides the pieces shared by
    concrete attacks: the tanh change of variable between the bounded
    signal ``x`` and the unbounded variable ``w`` (``x_w`` / ``w_x``) and
    the margin objective ``f``.  ``generate`` is not implemented here and
    raises ``NotImplementedError``.

    Args:
        model: model under attack; forwarded to ``AdvAttack.__init__``.
        confidence: constant added to the logit margin inside ``f``
            before the ReLU.
        lr: stored as ``self.lr``; not used in this class
            (presumably the optimizer learning rate -- confirm in subclasses).
        max_iter: stored as ``self.max_iter``; not used in this class
            (presumably the maximum number of optimization steps).
        abort_early: stored as ``self.abort_early``; not used in this class.
        initial_c: stored as ``self.initial_c``; not used in this class
            (presumably the starting trade-off constant -- confirm).
        norm_time: stored as ``self.norm_time``; not used in this class.
        time_dim: stored as ``self.time_dim``; not used in this class.
        use_snr: stored as ``self.use_snr``; not used in this class.
        targeted: forwarded to ``AdvAttack.__init__``; ``f`` reads
            ``self.targeted`` to pick the sign of the margin.
        range_min: lower bound of the valid input range; forwarded to
            ``AdvAttack.__init__``.
        range_max: upper bound of the valid input range; forwarded to
            ``AdvAttack.__init__``.
    """

    def __init__(
        self,
        model,
        confidence=0.0,
        lr=1e-2,
        max_iter=10000,
        abort_early=True,
        initial_c=1e-3,
        norm_time=False,
        time_dim=None,
        use_snr=False,
        targeted=False,
        range_min=None,
        range_max=None,
    ):
        # The base class receives None in its second positional slot.
        super().__init__(model, None, targeted, range_min, range_max)
        self.confidence = confidence
        self.lr = lr
        self.max_iter = max_iter
        self.abort_early = abort_early
        self.initial_c = initial_c
        # Unknown at construction time; ``f`` treats a falsy value as the
        # multi-class case.  NOTE(review): presumably set by subclasses
        # before ``f`` is called -- confirm.
        self.is_binary = None
        # Half-width and centre of the [range_min, range_max] box used by
        # the tanh change of variable.
        # NOTE(review): if the base class keeps range_min/range_max as the
        # default None, this arithmetic raises TypeError -- confirm that
        # callers always pass numeric bounds.
        self.box_scale = (self.range_max - self.range_min) / 2
        self.box_bias = (self.range_max + self.range_min) / 2
        self.norm_time = norm_time
        self.time_dim = time_dim
        self.use_snr = use_snr

    @property
    def attack_info(self):
        """Return the base-class attack info extended with this attack's
        hyper-parameters.

        NOTE(review): ``time_dim`` is stored on the instance but is not
        reported here.
        """
        info = super().attack_info
        new_info = {
            "confidence": self.confidence,
            "lr": self.lr,
            "max_iter": self.max_iter,
            "abort_early": self.abort_early,
            "initial_c": self.initial_c,
            "norm_time": self.norm_time,
            "use_snr": self.use_snr,
        }
        info.update(new_info)
        return info

    @staticmethod
    def atanh(x, eps=1e-6):
        """Inverse hyperbolic tangent computed from its log definition.

        Args:
            x: input tensor.
            eps: ``x`` is scaled by ``(1 - eps)`` first, so that inputs of
                exactly +/-1 do not produce a division by zero or log(0).

        Returns:
            Tensor ``0.5 * log((1 + x') / (1 - x'))`` with
            ``x' = (1 - eps) * x``.
        """
        x = (1 - eps) * x
        return 0.5 * torch.log((1 + x) / (1 - x))

    def x_w(self, w):
        """Map the unbounded variable ``w`` to the bounded signal:
        ``box_scale * tanh(w) + box_bias``."""
        return self.box_scale * torch.tanh(w) + self.box_bias

    def w_x(self, x):
        """Map the bounded signal ``x`` to the unbounded variable ``w``;
        inverse of ``x_w`` up to the ``eps`` shrink applied in ``atanh``."""
        return self.atanh((x - self.box_bias) / self.box_scale)

    def f(self, z, target):
        """Margin objective on the model outputs.

        Args:
            z: model output scores.  In the multi-class branch it is
                indexed as ``z[i, target[i]]``, so it needs at least two
                dimensions (presumably ``(batch, num_classes)`` -- confirm).
                In the binary branch it is masked with ``target == 0``.
            target: class labels; integer indices in the multi-class
                branch, values compared against 0 in the binary branch.

        Returns:
            ``relu(z_other - z_t + confidence)`` when ``self.targeted`` is
            true, ``relu(z_t - z_other + confidence)`` otherwise.
        """
        if self.is_binary:
            # Flip the sign of the score where the label is 0, so z_t is
            # the score in favour of the labelled class; the competing
            # score is taken as 0.
            z_t = z.clone()
            z_t[target == 0] *= -1
            z_other = 0
        else:
            idx = torch.arange(0, z.shape[0], device=z.device)
            z_t = z[idx, target]
            # Mask the target entry with a very negative value so the max
            # picks the best non-target class.
            z_clone = z.clone()
            z_clone[idx, target] = -1e10
            z_other = torch.max(z_clone, dim=-1)[0]

        if self.targeted:
            # max(0, z_other - z_target + k)
            margin = F.relu(z_other - z_t + self.confidence)
        else:
            # max(0, z_target - z_other + k)
            margin = F.relu(z_t - z_other + self.confidence)

        return margin

    def generate(self, input, target):
        """Generate adversarial examples; not implemented in this class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()