# Source code for hyperion.torch.models.vae.vq_vae

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import logging

import torch
import torch.nn as nn
import torch.distributions as pdf

from ...torch_model import TorchModel
from ...narchs import TorchNALoader
from ...layers import tensor2pdf as t2pdf
from ...layers import vq


class VQVAE(TorchModel):
    """Vector Quantized Variational Autoencoder class.

    From: https://arxiv.org/abs/1711.00937

    Attributes:
      encoder_net: NArch encoder network object.
      decoder_net: NArch decoder network object.
      z_dim: latent variable dimension.
      kldiv_weight: weight of the KL divergence when computing the ELBO.
      diversity_weight: weight for the log-perplexity of the codebook, it
        intends to maximize the number of codewords used.
      vq_type: type of vector quantizer.
      vq_groups: number of vector quantization groups.
      vq_clusters: number of codewords in each vq group.
      vq_commitment_cost: weight of the commitment loss.
      vq_ema_gamma: exponential moving average decay coeff.
      vq_ema_eps: Laplace smoothing parameter.
      px_pdf: type of prob distribution for the data likelihood.
      flatten_spatial: if True all time/spatial dimensions are generated
        from a single latent vector, if False, we have multiple latents
        depending on the data size.
      spatial_shape: shape of the data, only needed if flatten_spatial=True.
      scale_invariant: for future use.
      data_scale: for future use.
    """

    def __init__(
        self,
        encoder_net,
        decoder_net,
        z_dim,
        kldiv_weight=1,
        diversity_weight=0.1,
        vq_type="multi-ema-k-means-vq",
        vq_groups=1,
        vq_clusters=64,
        vq_commitment_cost=0.25,
        vq_ema_gamma=0.99,
        vq_ema_eps=1e-5,
        px_pdf="normal-glob-diag-cov",
        flatten_spatial=False,
        spatial_shape=None,
        scale_invariant=False,
        data_scale=None,
    ):
        """Stores the hyper-parameters and builds the VQ and output layers.

        See the class docstring for the meaning of each argument.
        """
        super().__init__()
        self.encoder_net = encoder_net
        self.decoder_net = decoder_net
        self.z_dim = z_dim
        self.px_pdf = px_pdf
        self.kldiv_weight = kldiv_weight
        self.diversity_weight = diversity_weight
        self.vq_type = vq_type
        self.vq_groups = vq_groups
        self.vq_clusters = vq_clusters
        self.vq_commitment_cost = vq_commitment_cost
        self.vq_ema_gamma = vq_ema_gamma
        self.vq_ema_eps = vq_ema_eps
        self.flatten_spatial = flatten_spatial
        self.spatial_shape = spatial_shape
        self.scale_invariant = scale_invariant
        self.data_scale = data_scale

        # infer input feat dimension from encoder network
        in_shape = encoder_net.in_shape()
        # number of dimension of input/output enc/dec tensors,
        # needed to connect the blocks
        self._enc_in_dim = len(in_shape)
        self._enc_out_dim = self.encoder_net.out_dim()
        self._dec_in_dim = self.decoder_net.in_dim()
        self._dec_out_dim = self.decoder_net.out_dim()

        # we assume conv nnets with channel in dimension 1
        self.in_channels = in_shape[1]
        self._enc_out_channels = self.encoder_net.out_shape()[1]
        self._dec_out_channels = self.decoder_net.out_shape()[1]

        if self.flatten_spatial:
            # the VQ layer sees a flat (batch, feats) tensor in this mode
            self._compute_flatten_unflatten_shapes()
            qz_in_channels = self._enc_out_tot_feats
            qz_in_dim = 2
        else:
            qz_in_channels = self._enc_out_channels
            qz_in_dim = self._enc_out_dim

        self._make_post_enc_layer()
        self._make_pre_dec_layer()
        self._make_post_dec_layer()
        self._make_vq_layer(qz_in_channels, qz_in_dim)
        self.t2px = self._make_t2pdf_layer(
            px_pdf, self._dec_out_channels, self.in_channels, self._dec_out_dim
        )

    def _compute_flatten_unflatten_shapes(self):
        """Infers the shapes needed to flatten/unflatten the spatial dims.

        Sets ``_enc_out_shape``, ``_enc_out_tot_feats``, ``_dec_in_shape``
        and ``_dec_in_tot_feats``.
        """
        # if we flatten the spatial dimension to have a single
        # latent representation for all time/spatial positions
        # we have to infer the spatial dimension at the encoder
        # output
        assert (
            self.spatial_shape is not None
        ), "you need to specify spatial shape at the input"

        enc_in_shape = None, self.in_channels, *self.spatial_shape
        enc_out_shape = self.encoder_net.out_shape(enc_in_shape)
        self._enc_out_shape = enc_out_shape[1:]

        # this is the total number of flattened features at the encoder output
        enc_out_tot_feats = 1
        for d in self._enc_out_shape:
            enc_out_tot_feats *= d

        self._enc_out_tot_feats = enc_out_tot_feats

        # now we infer the shape at the decoder input
        dec_in_shape = self.decoder_net.in_shape()
        # we keep the spatial dims at the encoder output
        self._dec_in_shape = dec_in_shape[1], *enc_out_shape[2:]
        # this is the total number of flattened features at the decoder input
        dec_in_tot_feats = 1
        for d in self._dec_in_shape:
            dec_in_tot_feats *= d

        self._dec_in_tot_feats = dec_in_tot_feats

    def _flatten(self, x):
        """Reshapes x to 2d with ``_enc_out_tot_feats`` columns."""
        return x.view(-1, self._enc_out_tot_feats)

    def _unflatten(self, x):
        """Reshapes x to (-1, *_dec_in_shape), the decoder input shape."""
        return x.view(-1, *self._dec_in_shape)

    def _make_t2pdf_layer(self, pdf_name, in_channels, channels, ndims):
        """Creates the layer that converts a tensor into a distribution.

        Args:
          pdf_name: one of "normal-i-cov", "normal-glob-diag-cov",
            "normal-diag-cov".
          in_channels: passed to the layer as ``in_feats``.
          channels: passed to the layer as first positional argument.
          ndims: passed to the layer as ``in_dim``.

        Returns:
          The tensor2pdf layer object.

        Raises:
          KeyError: if pdf_name is not one of the supported names.
        """
        pdf_dict = {
            "normal-i-cov": t2pdf.Tensor2NormalICov,
            "normal-glob-diag-cov": t2pdf.Tensor2NormalGlobDiagCov,
            "normal-diag-cov": t2pdf.Tensor2NormalDiagCov,
        }

        t2pdf_layer = pdf_dict[pdf_name](channels, in_feats=in_channels, in_dim=ndims)
        return t2pdf_layer

    def _make_post_enc_layer(self):
        """Placeholder, no layer is needed after the encoder."""
        pass

    def _make_pre_dec_layer(self):
        """Creates the linear projection from z to the flattened decoder input.

        Only needed when flatten_spatial=True.
        """
        if self.flatten_spatial:
            self._pre_dec_linear = nn.Linear(self.z_dim, self._dec_in_tot_feats)

    def _make_post_dec_layer(self):
        """Placeholder, no layer is needed after the decoder."""
        pass

    def _pre_enc(self, x):
        """Adds a channel dimension when x is 3d and the encoder expects 4d."""
        if x.dim() == 3 and self._enc_in_dim == 4:
            return x.unsqueeze(1)

        return x

    def _post_enc(self, x):
        """Flattens the encoder output when flatten_spatial=True."""
        if self.flatten_spatial:
            x = self._flatten(x)

        return x

    def _pre_dec(self, x):
        """Adapts the latent tensor to the number of dims the decoder expects."""
        if self.flatten_spatial:
            x = self._pre_dec_linear(x)  # linear projection
            x = self._unflatten(x)
            return x

        if self._enc_out_dim == 3 and self._dec_in_dim == 4:
            return x.unsqueeze(dim=1)

        if self._enc_out_dim == 4 and self._dec_in_dim == 3:
            # merge all the middle dimensions into a single one
            return x.view(x.size(0), -1, x.size(-1))

        return x

    def _make_vq_layer(self, in_feats, in_dim):
        """Creates the vector quantization layer selected by ``vq_type``.

        Args:
          in_feats: passed to the quantizer as ``in_feats``.
          in_dim: passed to the quantizer as ``in_dim``.

        Raises:
          ValueError: if ``vq_type`` is not supported.
        """
        if self.vq_type == "multi-k-means-vq":
            vq_layer = vq.MultiKMeansVectorQuantizer(
                self.vq_groups,
                self.vq_clusters,
                self.z_dim,
                self.vq_commitment_cost,
                in_feats=in_feats,
                in_dim=in_dim,
            )
        elif self.vq_type == "multi-ema-k-means-vq":
            vq_layer = vq.MultiEMAKMeansVectorQuantizer(
                self.vq_groups,
                self.vq_clusters,
                self.z_dim,
                self.vq_commitment_cost,
                self.vq_ema_gamma,
                self.vq_ema_eps,
                in_feats=in_feats,
                in_dim=in_dim,
            )
        elif self.vq_type == "k-means-vq":
            vq_layer = vq.KMeansVectorQuantizer(
                self.vq_clusters,
                self.z_dim,
                self.vq_commitment_cost,
                in_feats=in_feats,
                in_dim=in_dim,
            )
        elif self.vq_type == "ema-k-means-vq":
            vq_layer = vq.EMAKMeansVectorQuantizer(
                self.vq_clusters,
                self.z_dim,
                self.vq_commitment_cost,
                self.vq_ema_gamma,
                self.vq_ema_eps,
                in_feats=in_feats,
                in_dim=in_dim,
            )
        else:
            raise ValueError("vq_type=%s not supported" % (self.vq_type))

        self.vq_layer = vq_layer

    def forward(
        self,
        x,
        x_target=None,
        return_x_mean=False,
        return_x_sample=False,
        return_z_sample=False,
        return_px=False,
        serialize_pdfs=True,
        use_amp=False,
    ):
        """Forward pass, optionally under CUDA automatic mixed precision.

        Args:
          x: input data tensor.
          x_target: reconstruction target, if None, x is used.
          return_x_mean: if True, adds "x_mean" to the output dict.
          return_x_sample: if True, adds "x_sample" to the output dict.
          return_z_sample: if True, adds "z" to the output dict.
          return_px: forwarded to _forward.
          serialize_pdfs: forwarded to _forward.
          use_amp: if True, runs _forward inside torch.cuda.amp.autocast().

        Returns:
          Dictionary with the losses and the requested outputs,
          see _forward.
        """
        if use_amp:
            with torch.cuda.amp.autocast():
                return self._forward(
                    x,
                    x_target,
                    return_x_mean,
                    return_x_sample,
                    return_z_sample,
                    return_px,
                    serialize_pdfs,
                )

        return self._forward(
            x,
            x_target,
            return_x_mean,
            return_x_sample,
            return_z_sample,
            return_px,
            serialize_pdfs,
        )

    def _forward(
        self,
        x,
        x_target=None,
        return_x_mean=False,
        return_x_sample=False,
        return_z_sample=False,
        return_px=False,
        serialize_pdfs=True,
    ):
        """Encodes, quantizes and decodes x, and computes the loss.

        Args:
          x: input data tensor.
          x_target: reconstruction target, if None, x is used.
          return_x_mean: if True, adds "x_mean" to the output dict.
          return_x_sample: if True, adds "x_sample" to the output dict.
          return_z_sample: if True, adds "z" to the output dict.
          return_px: currently not used by this method.
          serialize_pdfs: currently not used by this method.

        Returns:
          Dictionary with keys "loss", "elbo", "log_px", "kldiv_z",
          "vq_loss", "log_perplexity" and, when requested, "x_mean",
          "x_sample" and "z".
        """
        if x_target is None:
            x_target = x

        xx = self._pre_enc(x)
        xx = self.encoder_net(xx)
        xx = self._post_enc(xx)
        vq_output = self.vq_layer(xx)
        # extract the variables from the dict.
        z, vq_loss, kldiv_z, log_perplexity = (
            vq_output[i] for i in ["z_q", "loss", "kldiv_qrpr", "log_perplexity"]
        )

        zz = self._pre_dec(z)
        zz = self.decoder_net(zz, target_shape=x_target.shape)
        squeeze_dim = None
        if x_target.dim() == 3 and zz.dim() == 4:
            squeeze_dim = 1

        px = self.t2px(zz, squeeze_dim=squeeze_dim)

        # we normalize the elbo by spatial/time samples and feature dimension
        log_px = px.log_prob(x_target).view(x.size(0), -1)
        num_samples = log_px.size(-1)
        log_px = log_px.mean(dim=-1)
        # kldiv must be normalized by number of elements in x, not in z!!
        # out-of-place division, so the tensor stored in vq_output is not
        # modified
        kldiv_z = kldiv_z / num_samples
        elbo = log_px - self.kldiv_weight * kldiv_z
        loss = -elbo + vq_loss - self.diversity_weight * log_perplexity

        # we build the return dict
        r = {
            "loss": loss,
            "elbo": elbo,
            "log_px": log_px,
            "kldiv_z": kldiv_z,
            "vq_loss": vq_loss,
            "log_perplexity": log_perplexity,
        }

        if return_x_mean:
            r["x_mean"] = px.mean

        if return_x_sample:
            # use the reparametrized sample when the pdf supports it
            if px.has_rsample:
                x_sample = px.rsample()
            else:
                x_sample = px.sample()

            r["x_sample"] = x_sample

        if return_z_sample:
            r["z"] = z

        return r

    def compute_z(self, x):
        """Computes the quantized latent for the input x.

        Args:
          x: input data tensor.

        Returns:
          The "z_q" output of the VQ layer, the same entry that
          _forward uses as latent variable.
        """
        xx = self._pre_enc(x)
        xx = self.encoder_net(xx)
        xx = self._post_enc(xx)
        vq_output = self.vq_layer(xx)
        return vq_output["z_q"]

    def compute_px_given_z(self, z, x_shape=None):
        """Computes the data likelihood distribution given the latent z.

        Args:
          z: latent tensor.
          x_shape: shape of the data to generate, passed to the decoder
            as target_shape. It can be None.

        Returns:
          Distribution object produced by the tensor2pdf layer.
        """
        zz = self._pre_dec(z)
        zz = self.decoder_net(zz, target_shape=x_shape)
        squeeze_dim = None
        # remove the channel dimension when the data is 3d but the
        # decoder output is 4d
        if x_shape is not None and len(x_shape) == 3 and zz.dim() == 4:
            squeeze_dim = 1

        px = self.t2px(zz, squeeze_dim=squeeze_dim)
        return px

    def get_config(self):
        """Returns a dict with the base class config plus the VQ-VAE options."""
        enc_cfg = self.encoder_net.get_config()
        dec_cfg = self.decoder_net.get_config()
        config = {
            "encoder_cfg": enc_cfg,
            "decoder_cfg": dec_cfg,
            "z_dim": self.z_dim,
            "vq_type": self.vq_type,
            "vq_groups": self.vq_groups,
            "vq_clusters": self.vq_clusters,
            "vq_commitment_cost": self.vq_commitment_cost,
            "vq_ema_gamma": self.vq_ema_gamma,
            "vq_ema_eps": self.vq_ema_eps,
            "px_pdf": self.px_pdf,
            "kldiv_weight": self.kldiv_weight,
            "diversity_weight": self.diversity_weight,
            "flatten_spatial": self.flatten_spatial,
            "spatial_shape": self.spatial_shape,
            "scale_invariant": self.scale_invariant,
            "data_scale": self.data_scale,
        }

        base_config = super().get_config()
        # keys in config take precedence over the ones in base_config
        return {**base_config, **config}

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Creates the model from a file or from a config/state_dict pair.

        Args:
          file_path: passed to cls._load_cfg_state_dict.
          cfg: passed to cls._load_cfg_state_dict.
          state_dict: passed to cls._load_cfg_state_dict.

        Returns:
          Model object; its parameters are loaded from the state_dict
          when one is available.
        """
        cfg, state_dict = cls._load_cfg_state_dict(file_path, cfg, state_dict)

        encoder_net = TorchNALoader.load_from_cfg(cfg=cfg["encoder_cfg"])
        decoder_net = TorchNALoader.load_from_cfg(cfg=cfg["decoder_cfg"])
        # the networks are passed as objects, so their configs are removed
        # before expanding cfg into the constructor
        for k in ("encoder_cfg", "decoder_cfg"):
            del cfg[k]

        model = cls(encoder_net, decoder_net, **cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)

        return model

    @staticmethod
    def filter_args(**kwargs):
        """Keeps only the keyword arguments that are VQ-VAE options.

        Returns:
          Dictionary with the subset of kwargs whose keys are valid options.
        """
        valid_args = (
            "z_dim",
            "kldiv_weight",
            "diversity_weight",
            "vq_type",
            "vq_groups",
            "vq_clusters",
            "vq_commitment_cost",
            "vq_ema_gamma",
            "vq_ema_eps",
            "px_pdf",
        )
        args = {k: kwargs[k] for k in valid_args if k in kwargs}
        return args

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Adds the VQ-VAE options to an argument parser.

        Args:
          parser: parser object where the options are added.
          prefix: if not None, the options are added to a new parser
            that is attached to the given one under ``--<prefix>``.
        """
        # NOTE(review): ArgumentParser and ActionParser are not imported in
        # the visible part of this file (presumably they come from
        # jsonargparse) -- confirm the import exists, otherwise the
        # prefix branch raises NameError.
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--z-dim", type=int, required=True, help=("latent factor dimension")
        )

        parser.add_argument(
            "--kldiv-weight",
            default=1,
            type=float,
            help=("weight of the KL divergance in the ELBO"),
        )

        parser.add_argument(
            "--diversity-weight",
            default=0.1,
            type=float,
            help=("weight of the log-perplexity in the loss"),
        )

        parser.add_argument(
            "--vq-type",
            default="ema-k-means-vq",
            choices=[
                "k-means-vq",
                "multi-k-means-vq",
                "ema-k-means-vq",
                "multi-ema-k-means-vq",
            ],
            help=("type of vector quantization layer"),
        )

        parser.add_argument(
            "--vq-groups",
            default=1,
            type=int,
            help=("number of groups in mulit-vq layers"),
        )

        parser.add_argument(
            "--vq-clusters", default=64, type=int, help=("size of the codebooks")
        )

        parser.add_argument(
            "--vq-commitment-cost",
            default=0.25,
            type=float,
            help=("commitment loss weight (beta in VQ-VAE paper)"),
        )

        parser.add_argument(
            "--vq-ema-gamma",
            default=0.99,
            type=float,
            help=(
                "decay parameter for exponential moving "
                "average calculation of the embeddings"
            ),
        )

        parser.add_argument(
            "--vq-ema-eps",
            default=1e-5,
            type=float,
            help=(
                "pseudo-count value for Laplace smoothing "
                "of cluster counts for exponential moving "
                "avarage calculation of the embeddings"
            ),
        )

        parser.add_argument(
            "--px-pdf",
            default="normal-glob-diag-cov",
            choices=["normal-i-cov", "normal-glob-diag-cov", "normal-diag-cov"],
            help=("pdf for data likelihood p(x|z)"),
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
            # help='vae options')

    # alias of add_class_args under a second name
    add_argparse_args = add_class_args