"""
Copyright 2020 Johns Hopkins University (Author: Jesus Villalba, Nanxin Chen)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
class VectorQuantizer(nn.Module):
    """Base class for vector quantization layers.

    It stores the codebook configuration and, when ``project=True``, creates a
    layer that maps the ``in_feats`` input channels to ``embed_feats`` channels
    (``nn.Linear`` for 2D inputs, kernel-size-1 convolutions for 3D-5D inputs).

    Args:
      num_embed: number of embedding vectors in the codebook.
      embed_feats: dimension of each embedding vector.
      project: if True, a projection from in_feats to embed_feats is created.
      in_feats: number of input channels; required when project=True. When
        project=False it must be None or equal to embed_feats.
      in_dim: number of dimensions of the input tensors (2 to 5); required
        when project=True.

    Raises:
      AssertionError: if project=True and in_feats or in_dim is missing, or if
        project=False and in_feats differs from embed_feats.
      ValueError: if project=True and in_dim is not one of 2, 3, 4, 5.
    """

    def __init__(
        self, num_embed, embed_feats, project=True, in_feats=None, in_dim=None
    ):
        super().__init__()
        self.num_embed = num_embed
        self.embed_feats = embed_feats
        self.project = project
        self._proj = None
        if project:
            assert (
                in_feats is not None
            ), "input channels must be given to make the projection"
            assert (
                in_dim is not None
            ), "input tensor dim must be given to make the projection"
            self._proj = self._make_proj(in_feats, embed_feats, in_dim)
        elif in_feats is not None:
            # Without projection the input is quantized directly, so its
            # channel count has to match the embedding dimension.
            assert in_feats == embed_feats, (
                "in_feats (%d) != embed_feats (%d), which is required when project=False"
                % (in_feats, embed_feats)
            )
        else:
            in_feats = embed_feats

        self.in_feats = in_feats
        self.in_dim = in_dim

    def __str__(self):
        # Defined in the base class so that __repr__ below never falls back to
        # object.__str__, which itself calls repr() and would recurse forever.
        return (
            "{}(num_embed={}, embed_feats={}, project={}, in_feats={}, in_dim={})"
        ).format(
            self.__class__.__name__,
            self.num_embed,
            self.embed_feats,
            self.project,
            self.in_feats,
            self.in_dim,
        )

    def __repr__(self):
        return self.__str__()

    def _make_proj(self, in_feats, out_feats, ndims):
        """Creates the projection layer for input tensors with ndims dimensions.

        Args:
          in_feats: number of input channels.
          out_feats: number of output channels.
          ndims: number of dimensions of the input tensor (2, 3, 4 or 5).

        Returns:
          nn.Linear for ndims=2, otherwise a kernel-size-1 Conv1d/Conv2d/Conv3d.

        Raises:
          ValueError: if ndims is not supported.
        """
        if ndims == 2:
            return nn.Linear(in_feats, out_feats)

        conv_classes = {3: nn.Conv1d, 4: nn.Conv2d, 5: nn.Conv3d}
        if ndims not in conv_classes:
            raise ValueError("ndim=%d is not supported" % ndims)

        return conv_classes[ndims](in_feats, out_feats, kernel_size=1)
class KMeansVectorQuantizer(VectorQuantizer):
    """Vector quantizer whose codebook is a trainable parameter.

    Each input vector is replaced by its nearest codebook vector (squared
    Euclidean distance). The codebook is learned through the returned loss,
    and gradients reach the encoder through a straight-through estimator.

    Args:
      num_embed: number of embedding vectors in the codebook.
      embed_feats: dimension of each embedding vector.
      commitment_cost: weight of the commitment term in the loss.
      project: if True, inputs are projected from in_feats to embed_feats.
      in_feats: number of input channels; required when project=True.
      in_dim: number of dimensions of the input tensors; required when
        project=True.
    """

    def __init__(
        self,
        num_embed,
        embed_feats,
        commitment_cost=0.25,
        project=True,
        in_feats=None,
        in_dim=None,
    ):
        super().__init__(
            num_embed, embed_feats, project=project, in_feats=in_feats, in_dim=in_dim
        )
        self.commitment_cost = commitment_cost
        self.embed = nn.Parameter(torch.empty(num_embed, embed_feats))
        # Note kept from the original authors: the DeepMind code initializes with
        #   uniform_(-sqrt(3)/sqrt(num_embed), sqrt(3)/sqrt(num_embed))
        # which they describe as equivalent to
        #   nn.init.kaiming_uniform_(..., mode='fan_in', nonlinearity='linear').
        # Normal init seemed to give slightly better results, but the best
        # initialization still needs to be explored.
        nn.init.normal_(self.embed, std=1.0)
        # log(num_embed) is the per-position KL between a one-hot posterior and
        # a uniform prior; cached because it is constant.
        self._log_num_embed = math.log(num_embed)

    def __str__(self):
        s = (
            "{}(num_embed={}, embed_feats={}, commitment_cost={}, project={}, "
            "in_feats={}, in_dim={})"
        ).format(
            self.__class__.__name__,
            self.num_embed,
            self.embed_feats,
            self.commitment_cost,
            self.project,
            self.in_feats,
            self.in_dim,
        )
        return s

    def forward(self, inputs, return_r=False):
        """Quantizes the inputs and computes the VQ losses.

        Args:
          inputs: tensor with the feature/channel axis in dim 1 (z_e in the
            VQ-VAE paper).
          return_r: if True, the one-hot assignment matrix is also returned.

        Returns:
          Dictionary with:
            "z_q": quantized tensor with the same axis layout as the
              (projected) input; its gradient is passed straight to the input.
            "loss": codebook loss + commitment_cost * commitment loss.
            "kldiv_qrpr": tensor of shape (batch, 1) holding
              log(num_embed) * (number of quantized positions per sample).
            "log_perplexity": entropy of the average code usage in the batch.
            "r": (only if return_r) one-hot assignments with shape
              (num_positions, num_embed).
        """
        # inputs -> z_e in paper
        if self.project:
            inputs = self._proj(inputs)

        # swap the channel axis (dim 1) with the last axis so that the
        # features of each position are contiguous in the last dimension
        inputs = inputs.transpose(1, -1).contiguous()
        input_shape = inputs.shape

        # Flatten input to (num_positions, embed_feats)
        flat_inputs = inputs.view(-1, self.embed_feats)

        # Squared distances: ||x||^2 + ||e||^2 - 2 x.e
        d2 = (
            torch.sum(flat_inputs ** 2, dim=1, keepdim=True)
            + torch.sum(self.embed ** 2, dim=1)
            - 2 * torch.matmul(flat_inputs, self.embed.t())
        )

        # Encoding
        # quantization integer indexes
        q_idx = torch.argmin(d2, dim=1).unsqueeze(1)
        # 1-hot responsibilities; created in the codebook dtype so that the
        # matmul below also works when the module has been cast (e.g. .double())
        r = torch.zeros(
            q_idx.shape[0],
            self.num_embed,
            dtype=self.embed.dtype,
            device=inputs.device,
        )
        r.scatter_(1, q_idx, 1)
        z_q = torch.matmul(r, self.embed).view(input_shape)

        # Loss
        vq_loss = F.mse_loss(z_q, inputs.detach())
        commitment_loss = F.mse_loss(z_q.detach(), inputs)
        loss = vq_loss + self.commitment_cost * commitment_loss

        # this allows to backpropagate the gradients as if the output were equal to z_e
        z_q = inputs + (z_q - inputs).detach()

        # compute the perplexity
        probs = torch.mean(r, dim=0)
        log_perplexity = -torch.sum(probs * torch.log(probs + 1e-10))

        # compute KL divergence between r and uniform categorical prior
        # KL = \sum_i \log(1/(1/num_embed)) = \sum_i \log(num_embed) for i = all HxW or T elements
        # KL is constant so it doesn't contribute to the training
        # but we keep it to get a better estimation of the ELBO
        # in the paper they don't use it
        num_spatial_positions = r.size(0) / inputs.size(0)
        kldiv_r = (
            self._log_num_embed
            * num_spatial_positions
            * torch.ones((inputs.size(0), 1), device=inputs.device)
        )

        # undo the axis swap so the output has the same layout as the input
        z_q = z_q.transpose(1, -1).contiguous()
        output = {
            "z_q": z_q,
            "loss": loss,
            "kldiv_qrpr": kldiv_r,
            "log_perplexity": log_perplexity,
        }
        if return_r:
            output["r"] = r

        return output
class MultiKMeansVectorQuantizer(VectorQuantizer):
    """Group-wise vector quantizer built from several KMeansVectorQuantizer.

    The (optionally projected) input is split along dim 1 into num_groups
    chunks and each chunk is quantized by its own codebook. The outputs are
    concatenated back along dim 1.

    Args:
      num_groups: number of groups / independent codebooks.
      num_embed: number of embedding vectors in each codebook.
      embed_feats: total embedding dimension; must be a multiple of num_groups.
      commitment_cost: weight of the commitment term in the loss.
      project: if True, inputs are projected from in_feats to embed_feats.
      in_feats: number of input channels; required when project=True.
      in_dim: number of dimensions of the input tensors; required when
        project=True.
    """

    def __init__(
        self,
        num_groups,
        num_embed,
        embed_feats,
        commitment_cost=0.25,
        project=True,
        in_feats=None,
        in_dim=None,
    ):
        super().__init__(
            num_embed, embed_feats, project=project, in_feats=in_feats, in_dim=in_dim
        )
        assert (
            embed_feats % num_groups == 0
        ), "VQ latent channels (%d) must be multiple of num_groups (%d)" % (
            embed_feats,
            num_groups,
        )
        self.num_groups = num_groups
        feats_per_group = embed_feats // num_groups
        # The projection (if any) is done once by this module, so the
        # per-group quantizers never project.
        self.vq_layers = nn.ModuleList([])
        for _ in range(num_groups):
            self.vq_layers.append(
                KMeansVectorQuantizer(
                    num_embed, feats_per_group, commitment_cost, project=False
                )
            )

    @property
    def commitment_cost(self):
        """Commitment cost, read from the first group quantizer."""
        return self.vq_layers[0].commitment_cost

    def __str__(self):
        template = (
            "{}(num_groups={}, num_embed={}, embed_feats={}, commitment_cost={}, project={}, "
            "in_feats={}, in_dim={})"
        )
        return template.format(
            self.__class__.__name__,
            self.num_groups,
            self.num_embed,
            self.embed_feats,
            self.commitment_cost,
            self.project,
            self.in_feats,
            self.in_dim,
        )

    def forward(self, inputs, return_r=False):
        """Quantizes each channel group with its own codebook.

        Args:
          inputs: tensor with the feature/channel axis in dim 1.
          return_r: if True, the per-group one-hot assignments are returned.

        Returns:
          Dictionary with:
            "z_q": per-group quantized tensors concatenated along dim 1.
            "loss": sum of the group losses (not divided by num_groups).
            "kldiv_qrpr": sum of the group KL terms.
            "log_perplexity": average of the group log-perplexities.
            "r": (only if return_r) list with one assignment matrix per group.
        """
        if self.project:
            inputs = self._proj(inputs)

        chunks = inputs.chunk(self.num_groups, dim=1)

        # Run every group quantizer first, then reduce their outputs.
        group_outputs = []
        for idx, vq_layer in enumerate(self.vq_layers):
            group_outputs.append(vq_layer(chunks[idx], return_r=return_r))

        first = group_outputs[0]
        loss = first["loss"]
        kldiv_r = first["kldiv_qrpr"]
        H = first["log_perplexity"]
        for out in group_outputs[1:]:
            loss += out["loss"]
            kldiv_r += out["kldiv_qrpr"]
            H += out["log_perplexity"]

        z_q = torch.cat(tuple(out["z_q"] for out in group_outputs), dim=1)
        r = [out["r"] for out in group_outputs] if return_r else []

        output = {
            "z_q": z_q,
            "loss": loss,
            "kldiv_qrpr": kldiv_r,
            "log_perplexity": H / self.num_groups,
        }
        if return_r:
            output["r"] = r

        return output
class EMAKMeansVectorQuantizer(VectorQuantizer):
    """Vector quantizer whose codebook is updated by exponential moving average.

    The codebook is stored as a buffer (not a parameter). In training mode
    every forward pass updates EMA estimates of the code counts and of the
    accumulated input vectors, and recomputes the codebook from them. Only
    the commitment loss is returned.

    Args:
      num_embed: number of embedding vectors in the codebook.
      embed_feats: dimension of each embedding vector.
      commitment_cost: weight of the commitment loss.
      gamma: EMA decay factor.
      eps: Laplace-smoothing constant applied to the EMA counts.
      project: if True, inputs are projected from in_feats to embed_feats.
      in_feats: number of input channels; required when project=True.
      in_dim: number of dimensions of the input tensors; required when
        project=True.
    """

    def __init__(
        self,
        num_embed,
        embed_feats,
        commitment_cost=0.25,
        gamma=0.99,
        eps=1e-5,
        project=True,
        in_feats=None,
        in_dim=None,
    ):
        # num_embed and embed_feats are stored by the base class.
        super().__init__(
            num_embed, embed_feats, project=project, in_feats=in_feats, in_dim=in_dim
        )
        self.commitment_cost = commitment_cost
        self.gamma = gamma
        self.eps = eps

        self.register_buffer("embed", torch.empty(num_embed, embed_feats))
        nn.init.normal_(self.embed, std=1.0)

        # EMA of the number of vectors assigned to each code
        self.register_buffer("_ema_N", torch.zeros(num_embed))
        # EMA of the sum of the vectors assigned to each code
        self.register_buffer("_ema_z_acc", torch.empty(num_embed, embed_feats))
        nn.init.normal_(self._ema_z_acc, std=1.0)

        self._log_num_embed = math.log(num_embed)

    def __str__(self):
        s = (
            "{}(num_embed={}, embed_feats={}, commitment_cost={}, "
            "gamma={}, eps={}, project={}, in_feats={}, in_dim={})"
        ).format(
            self.__class__.__name__,
            self.num_embed,
            self.embed_feats,
            self.commitment_cost,
            self.gamma,
            self.eps,
            self.project,
            self.in_feats,
            self.in_dim,
        )
        return s

    def _update_ema(self, r, flat_inputs):
        """Updates the EMA statistics and the codebook from the current batch.

        Args:
          r: one-hot assignments with shape (num_positions, num_embed).
          flat_inputs: input vectors with shape (num_positions, embed_feats).
        """
        # Statistics are only summed across workers when a process group
        # exists; calling all_reduce without one raises an error.
        sync = dist.is_available() and dist.is_initialized()

        # The update is not part of the autograd graph.
        with torch.no_grad():
            N = torch.sum(r, dim=0)
            if sync:
                # required to sync gpus in DDP
                dist.all_reduce(N, op=dist.ReduceOp.SUM)

            ema_N = self._ema_N * self.gamma + (1 - self.gamma) * N
            N_tot = torch.sum(ema_N)
            # Laplace smoothing
            self._ema_N = (
                (ema_N + self.eps) / (N_tot + self.num_embed * self.eps) * N_tot
            )

            z_acc = torch.matmul(r.t(), flat_inputs)
            if sync:
                # required to sync gpus in DDP
                dist.all_reduce(z_acc, op=dist.ReduceOp.SUM)

            self._ema_z_acc = self.gamma * self._ema_z_acc + (1 - self.gamma) * z_acc
            self.embed = self._ema_z_acc / self._ema_N.unsqueeze(1)

    def forward(self, inputs, return_r=False):
        """Quantizes the inputs and, in training mode, updates the codebook.

        Args:
          inputs: tensor with the feature/channel axis in dim 1 (z_e in the
            VQ-VAE paper).
          return_r: if True, the one-hot assignment matrix is also returned.

        Returns:
          Dictionary with:
            "z_q": quantized tensor with the same axis layout as the
              (projected) input; its gradient is passed straight to the input.
            "loss": commitment_cost * commitment loss.
            "kldiv_qrpr": tensor of shape (batch, 1) holding
              log(num_embed) * (number of quantized positions per sample).
            "log_perplexity": entropy of the average code usage in the batch.
            "r": (only if return_r) one-hot assignments with shape
              (num_positions, num_embed).
        """
        # inputs -> z_e in paper
        if self.project:
            inputs = self._proj(inputs)

        # swap the channel axis (dim 1) with the last axis so that the
        # features of each position are contiguous in the last dimension
        inputs = inputs.transpose(1, -1).contiguous()
        input_shape = inputs.shape

        # Flatten input to (num_positions, embed_feats)
        flat_inputs = inputs.view(-1, self.embed_feats)

        # Squared distances: ||x||^2 + ||e||^2 - 2 x.e
        d2 = (
            torch.sum(flat_inputs ** 2, dim=1, keepdim=True)
            + torch.sum(self.embed ** 2, dim=1)
            - 2 * torch.matmul(flat_inputs, self.embed.t())
        )

        # Encoding
        # quantization integer indexes
        q_idx = torch.argmin(d2, dim=1).unsqueeze(1)
        # 1-hot responsibilities; created in the codebook dtype so that the
        # matmul below also works when the module has been cast (e.g. .double())
        r = torch.zeros(
            q_idx.shape[0],
            self.num_embed,
            dtype=self.embed.dtype,
            device=inputs.device,
        )
        r.scatter_(1, q_idx, 1)
        # z_q uses the codebook as it was before this step's EMA update
        z_q = torch.matmul(r, self.embed).view(input_shape)

        # Use Exponential Moving Average (EMA) to update the embedding vectors
        if self.training:
            self._update_ema(r, flat_inputs)

        # Loss
        commitment_loss = F.mse_loss(z_q.detach(), inputs)
        loss = self.commitment_cost * commitment_loss

        # this allows to backpropagate the gradients as if the output were equal to z_e
        z_q = inputs + (z_q - inputs).detach()

        # compute the perplexity
        probs = torch.mean(r, dim=0)
        log_perplexity = -torch.sum(probs * torch.log(probs + 1e-10))

        # compute KL divergence between r and uniform categorical prior
        # KL = \sum_i \log(1/(1/num_embed)) = \sum_i \log(num_embed) for i = all HxW or T elements
        # KL is constant so it doesn't contribute to the training
        # but we keep it to get a better estimation of the ELBO
        # in the paper they don't use it
        num_spatial_positions = r.size(0) / inputs.size(0)
        kldiv_r = (
            self._log_num_embed
            * num_spatial_positions
            * torch.ones((inputs.size(0), 1), device=inputs.device)
        )

        # undo the axis swap so the output has the same layout as the input
        z_q = z_q.transpose(1, -1).contiguous()
        output = {
            "z_q": z_q,
            "loss": loss,
            "kldiv_qrpr": kldiv_r,
            "log_perplexity": log_perplexity,
        }
        if return_r:
            output["r"] = r

        return output
class MultiEMAKMeansVectorQuantizer(VectorQuantizer):
    """Group-wise vector quantizer built from several EMAKMeansVectorQuantizer.

    The (optionally projected) input is split along dim 1 into num_groups
    chunks and each chunk is quantized by its own EMA-updated codebook. The
    outputs are concatenated back along dim 1.

    Args:
      num_groups: number of groups / independent codebooks.
      num_embed: number of embedding vectors in each codebook.
      embed_feats: total embedding dimension; must be a multiple of num_groups.
      commitment_cost: weight of the commitment loss.
      gamma: EMA decay factor of the group quantizers.
      eps: Laplace-smoothing constant of the group quantizers.
      project: if True, inputs are projected from in_feats to embed_feats.
      in_feats: number of input channels; required when project=True.
      in_dim: number of dimensions of the input tensors; required when
        project=True.
    """

    def __init__(
        self,
        num_groups,
        num_embed,
        embed_feats,
        commitment_cost=0.25,
        gamma=0.99,
        eps=1e-5,
        project=True,
        in_feats=None,
        in_dim=None,
    ):
        super().__init__(
            num_embed, embed_feats, project=project, in_feats=in_feats, in_dim=in_dim
        )
        # The channels are split evenly among the groups.
        assert (
            embed_feats % num_groups == 0
        ), "VQ latent channels (%d) must be multiple of num_groups (%d)" % (
            embed_feats,
            num_groups,
        )
        self.num_groups = num_groups
        embed_feats_i = embed_feats // num_groups
        # The projection (if any) is done once by this module, so the
        # per-group quantizers never project.
        self.vq_layers = nn.ModuleList([])
        for _ in range(num_groups):
            vq_i = EMAKMeansVectorQuantizer(
                num_embed, embed_feats_i, commitment_cost, gamma, eps, project=False
            )
            self.vq_layers.append(vq_i)

    @property
    def commitment_cost(self):
        """Commitment cost, read from the first group quantizer."""
        return self.vq_layers[0].commitment_cost

    @property
    def gamma(self):
        """EMA decay factor, read from the first group quantizer."""
        return self.vq_layers[0].gamma

    @property
    def eps(self):
        """Laplace-smoothing constant, read from the first group quantizer."""
        return self.vq_layers[0].eps

    def __str__(self):
        s = (
            "{}(num_groups={}, num_embed={}, embed_feats={}, commitment_cost={}, "
            "gamma={}, eps={}, project={}, in_feats={}, in_dim={})"
        ).format(
            self.__class__.__name__,
            self.num_groups,
            self.num_embed,
            self.embed_feats,
            self.commitment_cost,
            self.gamma,
            self.eps,
            self.project,
            self.in_feats,
            self.in_dim,
        )
        return s

    def forward(self, inputs, return_r=False):
        """Quantizes each channel group with its own EMA codebook.

        Args:
          inputs: tensor with the feature/channel axis in dim 1.
          return_r: if True, the per-group one-hot assignments are returned.

        Returns:
          Dictionary with:
            "z_q": per-group quantized tensors concatenated along dim 1.
            "loss": average of the group losses.
            "kldiv_qrpr": sum of the group KL terms.
            "log_perplexity": average of the group log-perplexities.
            "r": (only if return_r) list with one assignment matrix per group.
        """
        if self.project:
            inputs = self._proj(inputs)

        inputs = inputs.chunk(self.num_groups, dim=1)
        z_q = []
        r = []
        for i, vq_i in enumerate(self.vq_layers):
            # return_r has to be forwarded, otherwise the group output has
            # no "r" entry to collect below.
            output_i = vq_i(inputs[i], return_r=return_r)
            z_q.append(output_i["z_q"])
            if return_r:
                r.append(output_i["r"])

            if i == 0:
                loss = output_i["loss"]
                kldiv_r = output_i["kldiv_qrpr"]
                H = output_i["log_perplexity"]
            else:
                loss = loss + output_i["loss"]
                kldiv_r = kldiv_r + output_i["kldiv_qrpr"]
                H = H + output_i["log_perplexity"]

        z_q = torch.cat(tuple(z_q), dim=1)
        loss = loss / self.num_groups
        log_perplexity = H / self.num_groups
        output = {
            "z_q": z_q,
            "loss": loss,
            "kldiv_qrpr": kldiv_r,
            "log_perplexity": log_perplexity,
        }
        if return_r:
            output["r"] = r

        return output