"""
Copyright 2020 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import logging
import torch
import torch.nn as nn
import torch.distributions as pdf
from ...torch_model import TorchModel
from ...narchs import TorchNALoader
from ...layers import tensor2pdf as t2pdf
from ...layers import pdf_storage
class VAE(TorchModel):
    """Variational Autoencoder class

    From: https://arxiv.org/abs/1312.6114

    Attributes:
      encoder_net: NArch encoder network object
      decoder_net: NArch decoder network object
      z_dim: latent variable dimension
      kldiv_weight: weight of the KL divergence when computing the ELBO
      qz_pdf: type of prob distribution of the approx. latent posterior
      pz_pdf: type of prob distribution of the latent prior
      px_pdf: type of prob distribution for the data likelihood
      flatten_spatial: if True all time/spatial dimensions are generated from a
                       single latent vector, if False, we have multiple latents
                       depending on the data size.
      spatial_shape: time/spatial dims of the input (excluding batch and
                     channel dims), only needed if flatten_spatial=True
      scale_invariant: for future use
      data_scale: for future use
    """

    def __init__(
        self,
        encoder_net,
        decoder_net,
        z_dim,
        kldiv_weight=1,
        qz_pdf="normal-glob-diag-cov",
        pz_pdf="std-normal",
        px_pdf="normal-glob-diag-cov",
        flatten_spatial=False,
        spatial_shape=None,
        scale_invariant=False,
        data_scale=None,
    ):
        super().__init__()
        self.encoder_net = encoder_net
        self.decoder_net = decoder_net
        self.z_dim = z_dim
        self.qz_pdf = qz_pdf
        self.pz_pdf = pz_pdf
        self.px_pdf = px_pdf
        self.kldiv_weight = kldiv_weight
        self.flatten_spatial = flatten_spatial
        self.spatial_shape = spatial_shape
        self.scale_invariant = scale_invariant
        self.data_scale = data_scale

        # infer input feat dimension from encoder network
        in_shape = encoder_net.in_shape()
        # number of dimensions of input/output enc/dec tensors,
        # needed to connect the blocks
        self._enc_in_dim = len(in_shape)
        self._enc_out_dim = self.encoder_net.out_dim()
        self._dec_in_dim = self.decoder_net.in_dim()
        self._dec_out_dim = self.decoder_net.out_dim()

        # we assume conv nnets with channel in dimension 1
        self.in_channels = in_shape[1]
        self._enc_out_channels = self.encoder_net.out_shape()[1]
        self._dec_out_channels = self.decoder_net.out_shape()[1]

        if self.flatten_spatial:
            # q(z) sees a 2-dim (batch, feats) tensor made of all the
            # flattened encoder-output features
            self._compute_flatten_unflatten_shapes()
            qz_in_channels = self._enc_out_tot_feats
            qz_in_dim = 2
        else:
            qz_in_channels = self._enc_out_channels
            qz_in_dim = self._enc_out_dim

        self._make_post_enc_layer()
        self._make_pre_dec_layer()
        self._make_post_dec_layer()

        self.t2qz = self._make_t2pdf_layer(
            qz_pdf, qz_in_channels, self.z_dim, qz_in_dim
        )
        self.t2px = self._make_t2pdf_layer(
            px_pdf, self._dec_out_channels, self.in_channels, self._dec_out_dim
        )

        self._make_prior()

    @property
    def pz(self):
        """Prior distribution p(z), obtained by calling the prior module built in _make_prior."""
        return self._pz()

    def _compute_flatten_unflatten_shapes(self):
        """Infers encoder-output / decoder-input shapes when flatten_spatial=True.

        If we flatten the spatial dimensions to have a single latent
        representation for all time/spatial positions, we have to infer the
        spatial dimensions at the encoder output from ``spatial_shape``.

        Raises:
          ValueError: if spatial_shape was not given.
        """
        if self.spatial_shape is None:
            raise ValueError("you need to specify spatial shape at the input")

        enc_in_shape = (None, self.in_channels, *self.spatial_shape)
        enc_out_shape = self.encoder_net.out_shape(enc_in_shape)
        # drop the batch dimension
        self._enc_out_shape = enc_out_shape[1:]
        # total number of flattened features at the encoder output
        enc_out_tot_feats = 1
        for d in self._enc_out_shape:
            enc_out_tot_feats *= d
        self._enc_out_tot_feats = enc_out_tot_feats

        # the decoder input keeps the decoder's own channel dimension and
        # the spatial dims found at the encoder output
        dec_in_shape = self.decoder_net.in_shape()
        self._dec_in_shape = (dec_in_shape[1], *enc_out_shape[2:])
        # total number of flattened features at the decoder input
        dec_in_tot_feats = 1
        for d in self._dec_in_shape:
            dec_in_tot_feats *= d
        self._dec_in_tot_feats = dec_in_tot_feats

    def _flatten(self, x):
        """Flattens encoder outputs to (batch, enc_out_tot_feats)."""
        return x.view(-1, self._enc_out_tot_feats)

    def _unflatten(self, x):
        """Reshapes flat vectors to (batch, *dec_in_shape) for the decoder."""
        return x.view(-1, *self._dec_in_shape)

    def _make_prior(self):
        """Builds the latent prior module p(z).

        Raises:
          ValueError: if pz_pdf is not supported.
        """
        if self.flatten_spatial:
            shape = (self.z_dim,)
        else:
            # singleton spatial dims so the prior broadcasts over positions
            shape = (self.z_dim, *(1,) * (self._enc_out_dim - 2))

        if self.pz_pdf == "std-normal":
            self._pz = pdf_storage.StdNormal(shape)
        else:
            raise ValueError("pz=%s not supported" % self.pz_pdf)

    def _make_t2pdf_layer(self, pdf_name, in_channels, channels, ndims):
        """Builds a layer that maps a tensor to a probability distribution.

        Args:
          pdf_name: name of the distribution type (keys of pdf_dict below).
          in_channels: number of input features/channels of the tensor.
          channels: number of channels of the modeled variable.
          ndims: number of dimensions of the input tensor.

        Returns:
          tensor2pdf layer object.

        Raises:
          ValueError: if pdf_name is not supported.
        """
        pdf_dict = {
            "normal-i-cov": t2pdf.Tensor2NormalICov,
            "normal-glob-diag-cov": t2pdf.Tensor2NormalGlobDiagCov,
            "normal-diag-cov": t2pdf.Tensor2NormalDiagCov,
            "bay-normal-i-cov": t2pdf.Tensor2BayNormalICovGivenNormalPrior,
            "bay-normal-glob-diag-cov": t2pdf.Tensor2BayNormalGlobDiagCovGivenNormalPrior,
            "bay-normal-diag-cov": t2pdf.Tensor2BayNormalDiagCovGivenNormalPrior,
        }
        try:
            pdf_class = pdf_dict[pdf_name]
        except KeyError:
            raise ValueError("pdf=%s not supported" % pdf_name) from None

        return pdf_class(channels, in_feats=in_channels, in_dim=ndims)

    def _make_post_enc_layer(self):
        """Hook for a layer after the encoder (none by default)."""
        pass

    def _make_pre_dec_layer(self):
        """Builds the projection from z to the flattened decoder input."""
        if self.flatten_spatial:
            self._pre_dec_linear = nn.Linear(self.z_dim, self._dec_in_tot_feats)

    def _make_post_dec_layer(self):
        """Hook for a layer after the decoder (none by default)."""
        pass

    def _pre_enc(self, x):
        """Adds a channel dim when a 3-dim input feeds a 4-dim encoder."""
        if x.dim() == 3 and self._enc_in_dim == 4:
            return x.unsqueeze(1)

        return x

    def _post_enc(self, x):
        """Flattens the encoder output when flatten_spatial=True."""
        if self.flatten_spatial:
            x = self._flatten(x)

        return x

    def _pre_dec(self, x):
        """Adapts latent samples to the tensor layout expected by the decoder."""
        if self.flatten_spatial:
            x = self._pre_dec_linear(x)  # linear projection
            return self._unflatten(x)

        if self._enc_out_dim == 3 and self._dec_in_dim == 4:
            return x.unsqueeze(dim=1)

        if self._enc_out_dim == 4 and self._dec_in_dim == 3:
            # merge channel and first spatial dim
            return x.view(x.size(0), -1, x.size(-1))

        return x

    def _post_px(self, px, x_shape):
        # NOTE(review): not called anywhere in this class, and `squeeze_pdf`
        # is not defined or imported in this file — calling this with a
        # 4-dim px and 3-dim x_shape would raise NameError. Confirm intent.
        px_shape = px.batch_shape
        if len(px_shape) == 4 and len(x_shape) == 3:
            if px_shape[1] == 1:
                px = squeeze_pdf(px, dim=1)
            else:
                raise ValueError("P(x|z)-shape != x-shape")

        return px

    def forward(
        self,
        x,
        x_target=None,
        return_x_mean=False,
        return_x_sample=False,
        return_z_sample=False,
        return_px=False,
        return_qz=False,
        serialize_pdfs=True,
        use_amp=False,
    ):
        """Computes the ELBO of a batch of data.

        Args:
          x: input data tensor, batch in dimension 0.
          x_target: reconstruction target; defaults to x.
          return_x_mean: if True, adds the mean of p(x|z) as "x_mean".
          return_x_sample: if True, adds a sample of p(x|z) as "x_sample".
          return_z_sample: if True, adds the latent sample as "z".
          return_px: NOTE(review): currently ignored by _forward.
          return_qz: NOTE(review): currently ignored by _forward.
          serialize_pdfs: NOTE(review): currently ignored by _forward.
          use_amp: if True, runs inside torch.cuda.amp.autocast().

        Returns:
          dict with per-batch-element "elbo", "log_px" and "kldiv_z", plus the
          optional entries requested above.
        """
        args = (
            x,
            x_target,
            return_x_mean,
            return_x_sample,
            return_z_sample,
            return_px,
            return_qz,
            serialize_pdfs,
        )
        if use_amp:
            with torch.cuda.amp.autocast():
                return self._forward(*args)

        return self._forward(*args)

    def _forward(
        self,
        x,
        x_target=None,
        return_x_mean=False,
        return_x_sample=False,
        return_z_sample=False,
        return_px=False,
        return_qz=False,
        serialize_pdfs=True,
    ):
        """Implementation of forward (see forward for argument docs)."""
        if x_target is None:
            x_target = x

        # encoder -> approximate posterior q(z|x)
        x = self._pre_enc(x)
        xx = self.encoder_net(x)
        xx = self._post_enc(xx)
        qz = self.t2qz(xx, prior=self._pz())

        batch_size = x.size(0)
        kldiv_qzpz = (
            pdf.kl.kl_divergence(qz, self._pz()).view(batch_size, -1).sum(dim=-1)
        )

        # reparameterized sample -> decoder -> likelihood p(x|z)
        z = qz.rsample()
        zz = self._pre_dec(z)
        zz = self.decoder_net(zz, target_shape=x_target.shape)

        squeeze_dim = None
        if x_target.dim() == 3 and zz.dim() == 4:
            squeeze_dim = 1
        px = self.t2px(zz, squeeze_dim=squeeze_dim)

        # we normalize the elbo by spatial/time samples and feature dimension
        log_px = px.log_prob(x_target).view(batch_size, -1)
        num_samples = log_px.size(-1)
        log_px = log_px.mean(dim=-1)
        # kldiv must be normalized by number of elements in x, not in z!!
        kldiv_qzpz = kldiv_qzpz / num_samples
        elbo = log_px - self.kldiv_weight * kldiv_qzpz

        # we build the return dict
        r = {"elbo": elbo, "log_px": log_px, "kldiv_z": kldiv_qzpz}

        if return_x_mean:
            r["x_mean"] = px.mean

        if return_x_sample:
            x_sample = px.rsample() if px.has_rsample else px.sample()
            r["x_sample"] = x_sample

        if return_z_sample:
            r["z"] = z

        return r

    def compute_qz(self, x):
        """Returns the approximate posterior q(z|x) for input x."""
        xx = self._pre_enc(x)
        xx = self.encoder_net(xx)
        xx = self._post_enc(xx)
        qz = self.t2qz(xx, prior=self.pz)
        return qz

    def compute_px_given_z(self, z, x_shape=None):
        """Returns the likelihood p(x|z) for latent samples z.

        Args:
          z: latent samples.
          x_shape: optional target shape of x passed to the decoder; when it
            has 3 dims and the decoder output has 4, dim 1 is squeezed.
        """
        zz = self._pre_dec(z)
        zz = self.decoder_net(zz, target_shape=x_shape)

        squeeze_dim = None
        if x_shape is not None and len(x_shape) == 3 and zz.dim() == 4:
            squeeze_dim = 1

        px = self.t2px(zz, squeeze_dim=squeeze_dim)
        return px

    def get_config(self):
        """Returns the model configuration (base config updated with VAE keys)."""
        enc_cfg = self.encoder_net.get_config()
        dec_cfg = self.decoder_net.get_config()
        config = {
            "encoder_cfg": enc_cfg,
            "decoder_cfg": dec_cfg,
            "z_dim": self.z_dim,
            "qz_pdf": self.qz_pdf,
            "pz_pdf": self.pz_pdf,
            "px_pdf": self.px_pdf,
            "kldiv_weight": self.kldiv_weight,
            "flatten_spatial": self.flatten_spatial,
            "spatial_shape": self.spatial_shape,
            "scale_invariant": self.scale_invariant,
            "data_scale": self.data_scale,
        }

        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def load(cls, file_path=None, cfg=None, state_dict=None):
        """Builds a VAE from a file or a config/state_dict pair."""
        cfg, state_dict = cls._load_cfg_state_dict(file_path, cfg, state_dict)

        encoder_net = TorchNALoader.load_from_cfg(cfg=cfg["encoder_cfg"])
        decoder_net = TorchNALoader.load_from_cfg(cfg=cfg["decoder_cfg"])
        # build the constructor kwargs without mutating cfg, which may be
        # the dict passed in by the caller
        model_cfg = {
            k: v for k, v in cfg.items() if k not in ("encoder_cfg", "decoder_cfg")
        }

        model = cls(encoder_net, decoder_net, **model_cfg)
        if state_dict is not None:
            model.load_state_dict(state_dict)

        return model

    @staticmethod
    def filter_args(**kwargs):
        """Keeps only the kwargs accepted from the command line."""
        valid_args = ("z_dim", "kldiv_weight", "qz_pdf", "px_pdf")
        return {k: kwargs[k] for k in valid_args if k in kwargs}

    @staticmethod
    def add_class_args(parser, prefix=None):
        """Adds the VAE command line options to parser."""
        # NOTE(review): ArgumentParser / ActionParser are not imported at the
        # top of this file (presumably from jsonargparse) — add the import.
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--z-dim", type=int, required=True, help=("latent factor dimension")
        )

        parser.add_argument(
            "--kldiv-weight",
            default=1,
            type=float,
            help=("weight of the KL divergance in the ELBO"),
        )

        parser.add_argument(
            "--qz-pdf",
            default="normal-glob-diag-cov",
            choices=[
                "normal-i-cov",
                "normal-glob-diag-cov",
                "normal-diag-cov",
                "bay-normal-i-cov",
                "bay-normal-glob-diag-cov",
                "bay-normal-diag-cov",
            ],
            help=("pdf for approx posterior q(z)"),
        )

        parser.add_argument(
            "--px-pdf",
            default="normal-glob-diag-cov",
            choices=["normal-i-cov", "normal-glob-diag-cov", "normal-diag-cov"],
            help=("pdf for data likelihood p(x|z)"),
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    add_argparse_args = add_class_args