Source code for hyperion.torch.layers.pdf_storage

"""
 Copyright 2020 Johns Hopkins University  (Author: Jesus Villalba, Nanxin Chen)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
#

import torch
import torch.nn as nn
import torch.distributions as pdf


[docs]class StdNormal(nn.Module): """Storage for Standard Normal distribution"""
[docs] def __init__(self, shape): super().__init__() self.register_buffer("loc", torch.zeros(shape)) self.register_buffer("scale", torch.ones(shape))
# self.loc = nn.Parameter(torch.zeros(shape), requires_grad=False) # self.scale = nn.Parameter(torch.ones(shape), requires_grad=False) @property def pdf(self): return pdf.normal.Normal(self.loc, self.scale)
[docs] def forward(self): return self.pdf