Source code for hyperion.torch.layers.interpolate

"""
 Copyright 2021 Johns Hopkins University  (Author: Jesus Villalba, Nanxin Chen)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import torch
import torch.nn as nn
import torch.nn.functional as nnf


[docs]class Interpolate(nn.Module):
[docs] def __init__(self, scale_factor, mode="nearest"): super().__init__() self.interp = nnf.interpolate self.scale_factor = scale_factor self.mode = mode
def __repr__(self): s = "{}(scale_factor={}, mode={})".format( self.__class__.__name__, self.scale_factor, self.mode, ) return s
[docs] def forward(self, x): x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode) return x