# Source code for hyperion.pdfs.plda.plda

"""
 Copyright 2018 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import numpy as np
from scipy import linalg as sla

from ...hyp_defs import float_cpu
from ...utils.math import invert_pdmat, invert_trimat, logdet_pdmat
from .plda_base import PLDABase


[docs]class PLDA(PLDABase):
def __init__(
    self,
    y_dim=None,
    z_dim=None,
    mu=None,
    V=None,
    U=None,
    D=None,
    floor_iD=1e-5,
    update_mu=True,
    update_V=True,
    update_U=True,
    update_D=True,
    **kwargs
):
    """Store the model parameters and reset the cache of derived matrices.

    Args:
      y_dim: dimension of the latent variable y. Replaced by
        ``V.shape[0]`` when V is given.
      z_dim: dimension of the latent variable z. Replaced by
        ``U.shape[0]`` when U is given.
      mu: mean vector, forwarded to the base-class constructor.
      V: matrix multiplying y; ``validate`` requires ``V.shape[1]`` to
        equal ``mu.shape[0]``.
      U: matrix multiplying z; ``validate`` requires ``U.shape[1]`` to
        equal ``mu.shape[0]``.
      D: vector used as a diagonal precision by the likelihood methods;
        ``validate`` requires ``D.shape[0] == mu.shape[0]``.
      floor_iD: lower bound applied to ``1 / D`` whenever D is re-estimated.
      update_mu: whether the M-steps may change mu (also forwarded to the
        base class).
      update_V: whether the M-steps may change V.
      update_U: whether the M-steps may change U.
      update_D: whether the ML M-step may change D.
      **kwargs: forwarded unchanged to the base-class constructor.
    """
    super(PLDA, self).__init__(y_dim=y_dim, mu=mu, update_mu=update_mu, **kwargs)
    self.z_dim = z_dim
    if V is not None:
        self.y_dim = V.shape[0]
    if U is not None:
        self.z_dim = U.shape[0]

    self.V = V
    self.U = U
    self.D = D
    self.floor_iD = floor_iD
    self.update_V = update_V
    self.update_U = update_U
    self.update_D = update_D

    # Cache filled by compute_aux(); None means "not computed yet".
    self._DU = None
    # Fix: compute_aux() writes and Estep() reads `_J`, so that is the slot
    # that must exist from construction time.
    self._J = None
    # Kept only for backward compatibility; no method visible in this file
    # reads or writes `_Jt`.
    self._Jt = None
    self._Lz = None
    self._mult_iLz = None
    self._log_Lz = None
    self._W = None
    self._VW = None
    self._VWV = None
def validate(self):
    """Assert that mu, V, U and D have mutually consistent shapes.

    V and U must have as many columns as mu has entries and no more rows
    than that; D must have one entry per entry of mu.

    Raises:
      AssertionError: if any of the shape relations does not hold.
    """
    for factor in (self.V, self.U):
        assert self.mu.shape[0] >= factor.shape[0]
        assert self.mu.shape[0] == factor.shape[1]
    assert self.mu.shape[0] == self.D.shape[0]
@property
def is_init(self):
    """Tell whether the model has all its parameters and cached matrices.

    The first time all of mu, V, U and D are present, the shapes are
    validated, the cached matrices are computed if ``_VWV`` is still
    missing, and the result is remembered in ``_is_init``.
    """
    if self._is_init:
        return True

    incomplete = (
        self.mu is None or self.V is None or self.U is None or self.D is None
    )
    if not incomplete:
        self.validate()
        if self._VWV is None:
            self.compute_aux()
        self._is_init = True

    return self._is_init
def compute_aux(self):
    """Recompute the cached matrices that depend on V, U and D.

    Fills ``_DU``, ``_J``, ``_Lz``, ``_mult_iLz``, ``_log_Lz``, ``_W``,
    ``_VW`` and ``_VWV``. Must be called again whenever V, U or D change.
    """
    precision = self.D
    V_scaled = self.V * precision
    U_scaled = self.U * precision

    self._DU = U_scaled
    self._J = np.dot(self.V, U_scaled.T)
    self._Lz = np.eye(self.z_dim, dtype=float_cpu()) + np.dot(U_scaled, self.U.T)

    # NOTE(review): invert_pdmat is defined elsewhere; from its use here it is
    # assumed to return (multiply-by-inverse callable, <unused>, log-det).
    inv_result = invert_pdmat(self._Lz, right_inv=True, return_logdet=True)
    self._mult_iLz = inv_result[0]
    self._log_Lz = inv_result[2]

    # Shared factor (D U^T) Lz^{-1} of the two corrections below.
    correction = self._mult_iLz(U_scaled.T)
    self._W = np.diag(precision) - np.dot(correction, U_scaled)
    self._VW = V_scaled.T - np.dot(correction, self._J.T)
    self._VWV = np.dot(self.V, self._VW)
def initialize(self, D):
    """Set mu, V, U and D from sufficient statistics, then fill the cache.

    Args:
      D: tuple ``(N, F, S)``. N is indexed as ``N[:, None]`` against the
        rows of F, and S is combined with ``F.T``-shaped products, so N is
        per-row counts, F per-row first-order sums and S a square
        second-order accumulator.
    """
    N, F, S = D
    self.x_dim = F.shape[1]

    # Per-row means and their global average.
    row_means = F / N[:, None]
    mu = np.mean(row_means, axis=0)

    # V: leading right singular vectors of the centred means, scaled by
    # their singular values.
    _, sing_vals, right_vecs = sla.svd(
        row_means - mu, full_matrices=False, overwrite_a=True
    )
    V = sing_vals[: self.y_dim, None] * right_vecs[: self.y_dim, :]

    # Scatter left after removing the per-row means, normalised by the
    # total count.
    weighted_means = N[:, None] * row_means
    residual_cov = (S - np.dot(weighted_means.T, row_means)) / np.sum(N)

    # U: top z_dim eigen-directions (eigh returns ascending order, hence the
    # left-right flip), each scaled by the square root of its eigenvalue.
    eig_vals, eig_vecs = sla.eigh(residual_cov)
    U = np.fliplr(eig_vecs * np.sqrt(eig_vals))[:, : self.z_dim].T

    # Whatever diagonal variance U does not explain, floored.
    iD = np.diag(residual_cov - np.dot(U.T, U)).copy()
    iD[iD < self.floor_iD] = self.floor_iD

    self.mu = mu
    self.V = V
    self.U = U
    self.D = 1 / iD
    self.compute_aux()
def compute_py_g_x(
    self, D, return_cov=False, return_logpy_0=False, return_acc=False
):
    """Compute the posterior of y for every row of the statistics.

    Args:
      D: tuple ``(N, F, S)``; N holds one count per row of F, F holds the
        per-row first-order sums. S is not used here.
      return_cov: also return the posterior covariance of each row.
      return_logpy_0: also return, per row, the log posterior density of y
        evaluated at ``y = 0``.
      return_acc: also return the accumulated posterior covariances
        (count-weighted ``Ry`` and unweighted ``Py``).

    Returns:
      The posterior means ``y`` alone when no optional output is requested;
      otherwise a tuple starting with ``y`` followed, in this order, by the
      requested items: ``Sigma_y``, ``logpy``, ``Ry``, ``Py``.
    """
    assert self.is_init
    N, F, S = D
    # Fix: each row of F is a sum over N_i samples, so N_i copies of mu have
    # to be removed (the original subtracted a single mu, which is only
    # right when N_i == 1). This matches the centering in llr_NvsM_book.
    Fc = F - N[:, None] * self.mu

    M = F.shape[0]
    y_dim = self.y_dim

    compute_inv = return_cov or return_acc
    return_tuple = compute_inv or return_logpy_0

    # With integer counts, rows sharing a count share the same posterior
    # precision, so one inversion per distinct count is enough.
    N_is_int = bool(np.all(np.ceil(N) == N))

    I = np.eye(y_dim, dtype=float_cpu())
    gamma = np.dot(Fc, self._VW)
    if N_is_int:
        iterator = np.unique(N)
    else:
        iterator = range(M)

    y = np.zeros((M, y_dim), dtype=float_cpu())
    if return_cov:
        Sigma_y = np.zeros((M, y_dim, y_dim), dtype=float_cpu())
    else:
        Sigma_y = None

    if return_logpy_0:
        logpy = -0.5 * y_dim * np.log(2 * np.pi) * np.ones((M,), dtype=float_cpu())

    if return_acc:
        Py = np.zeros((y_dim, y_dim), dtype=float_cpu())
        Ry = np.zeros((y_dim, y_dim), dtype=float_cpu())

    for k in iterator:
        if N_is_int:
            # k is a count value; i indexes every row having that count.
            i = (N == k).nonzero()[0]
            N_i = k
            M_i = len(i)
        else:
            # k is a row index.
            i = k
            N_i = N[k]
            M_i = 1

        L_i = I + N_i * self._VWV
        # NOTE(review): invert_pdmat is defined elsewhere; the indexing below
        # assumes r[0] multiplies by the inverse, r[2] is the log-determinant
        # and r[-1] is the explicit inverse when requested.
        r = invert_pdmat(
            L_i,
            right_inv=True,
            return_logdet=return_logpy_0,
            return_inv=compute_inv,
        )
        mult_iL = r[0]
        if return_logpy_0:
            logL = r[2]
        if compute_inv:
            iL = r[-1]

        y[i, :] = mult_iL(gamma[i, :])

        if return_cov:
            Sigma_y[i, :, :] = iL

        if return_logpy_0:
            logpy[i] += 0.5 * (logL - np.sum(y[i, :] * gamma[i, :], axis=-1))

        if return_acc:
            Py += M_i * iL
            Ry += N_i * M_i * iL

    if not return_tuple:
        return y

    r = [y]
    if return_cov:
        r += [Sigma_y]
    if return_logpy_0:
        r += [logpy]
    if return_acc:
        r += [Ry, Py]
    return tuple(r)
def Estep(self, D):
    """Accumulate the statistics consumed by the M-steps and the ELBO.

    Args:
      D: tuple ``(N, F, S)`` of per-row counts, per-row first-order sums
        and the second-order accumulator.

    Returns:
      Tuple ``(N_tot, M, F_tot, S, logpy_acc, y_acc, Ry1, Ry, Cy, Py, Rz1,
      Rz, Ryz, Cz)`` in exactly the order unpacked by ``MstepML`` and
      ``MstepMD``.
    """
    N, F, S = D
    y, logpy, Ry, Py = self.compute_py_g_x(D, return_logpy_0=True, return_acc=True)

    M = F.shape[0]
    N_tot = np.sum(N)
    F_tot = np.sum(F, axis=0)

    # Cy: cross-statistic between the data sums and the posterior means.
    y_acc = np.sum(y, axis=0)
    Cy = np.dot(F.T, y)

    # Cz
    # Fix: F_tot and mu are both 1-D, so np.dot(F_tot.T, mu) collapsed to a
    # scalar that was then subtracted from every entry of S. The term that
    # belongs here is the matrix F_tot mu^T.
    A = np.dot(S - np.outer(F_tot, self.mu), self._DU.T) - np.dot(Cy, self._J)
    Cz = self._mult_iLz(A)

    # Ry1, Ry, Py: add the outer products of the posterior means to the
    # covariance accumulators returned by compute_py_g_x.
    Niy = y * N[:, None]
    Ry1 = np.sum(Niy, axis=0)
    Ry += np.dot(Niy.T, y)
    Py += np.dot(y.T, y)

    # Accumulated log-posterior.
    logpy_acc = np.sum(logpy)

    # Rz1, Ryz, Rz
    # NOTE(review): center_stats is provided outside this file; it is
    # assumed to return the statistics centred on the given mean.
    _, Fc, Sc = self.center_stats(D, self.mu)
    Fc_acc = np.sum(Fc, axis=0)
    Rz1 = self._mult_iLz(np.dot(Fc_acc, self._DU.T) - np.dot(Ry1, self._J))

    Cbary = Cy - np.outer(self.mu, Ry1)
    Ryz = self._mult_iLz(np.dot(Cbary.T, self._DU.T) - np.dot(Ry, self._J))

    A = np.dot(np.dot(self._DU, Cbary), self._J)
    B = (
        np.dot(np.dot(self._DU, Sc), self._DU.T)
        - A
        - A.T
        + np.dot(np.dot(self._J.T, Ry), self._J)
    )
    # Apply Lz^{-1} on both sides of B, then add the count-weighted
    # Lz^{-1} term.
    B = self._mult_iLz(B)
    Rz = self._mult_iLz(B.T).T + N_tot * self._mult_iLz(
        np.eye(self.z_dim, dtype=float_cpu())
    )

    stats = (
        N_tot,
        M,
        F_tot,
        S,
        logpy_acc,
        y_acc,
        Ry1,
        Ry,
        Cy,
        Py,
        Rz1,
        Rz,
        Ryz,
        Cz,
    )
    return stats
def elbo(self, stats):
    """Evaluate the evidence lower bound from accumulated statistics.

    Args:
      stats: sequence whose first five entries are the total count, the
        number of rows, the summed first-order statistics, the second-order
        accumulator and the accumulated log-posterior term (the leading
        part of what ``Estep`` returns).

    Returns:
      The bound ``logpx_y + logpy - logpy_x``.
    """
    n_total, n_rows, first_order, second_order, logpy_x = stats[:5]
    log_2pi = np.log(2 * np.pi)

    # Second-order statistics re-centred on the current mean.
    cross = np.outer(first_order, self.mu)
    centred = second_order - cross - cross.T + n_total * np.outer(self.mu, self.mu)

    const_term = -n_total * self.x_dim * log_2pi
    logdet_term = n_total * (np.sum(np.log(self.D)) - self._log_Lz)
    quad_term = np.inner(self._W.ravel(), centred.ravel())
    logpx_y = 0.5 * (const_term + logdet_term - quad_term)

    logpy = -0.5 * n_rows * self.y_dim * log_2pi
    return logpx_y + logpy - logpy_x
def MstepML(self, stats):
    """Maximum-likelihood M-step: re-estimate mu, V, U and D.

    Only the parameters whose ``update_*`` flag is set are changed; those
    that are updated together are solved for jointly. The cached matrices
    are refreshed at the end.

    Args:
      stats: the tuple returned by ``Estep``.
    """
    N, M, F, S, _, y_acc, Ry1, Ry, Cy, Py, Rz1, Rz, Ryz, Cz = stats

    upd_mu = bool(self.update_mu)
    upd_V = bool(self.update_V)
    upd_U = bool(self.update_U)

    # Full system over the stacked parameters [V; U; mu]. Needed by the
    # all-parameters case and by the D update; it does not depend on the
    # current parameter values.
    a = np.hstack((Ry, Ryz, Ry1[:, None]))
    b = np.hstack((Ryz.T, Rz, Rz1[:, None]))
    c = np.hstack((Ry1, Rz1, N))
    Rytilde = np.vstack((a, b, c))
    Cytilde = np.hstack((Cy, Cz, F[:, None]))

    # NOTE(review): invert_pdmat is defined elsewhere; with right_inv=False
    # its first return value is assumed to left-multiply by the inverse.
    if upd_mu and upd_V and upd_U:
        iRytilde_mult = invert_pdmat(Rytilde, right_inv=False)[0]
        Vtilde = iRytilde_mult(Cytilde.T)
        self.V = Vtilde[: self.y_dim, :]
        self.U = Vtilde[self.y_dim : -1]
        self.mu = Vtilde[-1]

    elif upd_V and upd_U:
        # V and U jointly, mu fixed.
        R = np.vstack((np.hstack((Ry, Ryz)), np.hstack((Ryz.T, Rz))))
        C = np.vstack(
            (Cy.T - np.outer(Ry1, self.mu), Cz.T - np.outer(Rz1, self.mu))
        )
        Vtilde = invert_pdmat(R, right_inv=False)[0](C)
        self.V = Vtilde[: self.y_dim]
        self.U = Vtilde[self.y_dim :]

    elif upd_mu and upd_U:
        # U and mu jointly, V fixed.
        # Fix: the V contribution is a matrix product (as in the U-only
        # case), not an outer product, and F enters as a row so that it can
        # be stacked under the U block.
        R = np.vstack((np.hstack((Rz, Rz1[:, None])), np.hstack((Rz1, N))))
        C = np.vstack(
            (
                Cz.T - np.dot(Ryz.T, self.V),
                (F - np.dot(Ry1, self.V))[None, :],
            )
        )
        Vtilde = invert_pdmat(R, right_inv=False)[0](C)
        self.U = Vtilde[:-1]
        self.mu = Vtilde[-1]

    elif upd_mu and upd_V:
        # V and mu jointly, U fixed.
        # Fix: this case was guarded by a copy of the V-only condition and
        # stored the mean row in self.U.
        R = np.vstack((np.hstack((Ry, Ry1[:, None])), np.hstack((Ry1, N))))
        C = np.vstack(
            (
                Cy.T - np.dot(Ryz, self.U),
                (F - np.dot(Rz1, self.U))[None, :],
            )
        )
        Vtilde = invert_pdmat(R, right_inv=False)[0](C)
        self.V = Vtilde[:-1]
        self.mu = Vtilde[-1]

    elif upd_mu:
        self.mu = (F - np.dot(Ry1, self.V) - np.dot(Rz1, self.U)) / N

    elif upd_V:
        C = Cy.T - np.outer(Ry1, self.mu) - np.dot(Ryz, self.U)
        self.V = invert_pdmat(Ry, right_inv=False)[0](C)

    elif upd_U:
        C = Cz.T - np.dot(Ryz.T, self.V) - np.outer(Rz1, self.mu)
        self.U = invert_pdmat(Rz, right_inv=False)[0](C)

    if self.update_D:
        # Diagonal of the residual second moment under the (possibly just
        # updated) parameters, floored before inversion.
        Vtilde = np.vstack((self.V, self.U, self.mu))
        CVt = np.dot(Cytilde, Vtilde)
        iD = np.diag(
            (S - CVt - CVt.T + np.dot(np.dot(Vtilde.T, Rytilde), Vtilde)) / N
        ).copy()
        iD[iD < self.floor_iD] = self.floor_iD
        self.D = 1 / iD

    self.compute_aux()
def MstepMD(self, stats):
    """Minimum-divergence M-step: re-parameterise mu, V and U.

    The empirical mean and covariance of the two latent variables are
    estimated from the accumulated statistics and folded into the
    parameters whose ``update_*`` flag is set. The cached matrices are
    refreshed at the end.

    Args:
      stats: the tuple returned by ``Estep``.
    """
    N, M, F, S, _, y_acc, Ry1, Ry, Cy, Py, Rz1, Rz, Ryz, Cz = stats

    # Empirical mean / covariance of y and the upper Cholesky factor of the
    # covariance.
    y_mean = y_acc / M
    y_cov = Py / M - np.outer(y_mean, y_mean)
    y_chol = sla.cholesky(y_cov, lower=False, overwrite_a=True)

    # H: coefficients linking z to y, from the mean-removed statistics.
    # NOTE(review): invert_pdmat is defined elsewhere; with right_inv=True
    # its first return value is assumed to right-multiply by the inverse.
    Ry_centred = Ry - np.outer(Ry1, Ry1) / N
    right_solve = invert_pdmat(Ry_centred, right_inv=True)[0]
    H = right_solve(Ryz.T - np.outer(Rz1, Ry1) / N)

    # Mean / covariance of what remains of z, and its upper Cholesky factor.
    z_mean = (Rz1 - np.dot(Ry1, H.T)) / N
    RzyH = np.dot(Ryz.T, H.T)
    z_second = (Rz - RzyH - RzyH.T + np.dot(np.dot(H, Ry), H.T)) / N
    z_cov = z_second - np.outer(z_mean, z_mean)
    z_chol = sla.cholesky(z_cov, lower=False, overwrite_a=True)

    # Uses the V and U from before this step; mu must therefore be updated
    # before V and U are overwritten.
    V_eff = self.V + np.dot(H.T, self.U)

    if self.update_mu:
        self.mu += np.dot(y_mean, V_eff) + np.dot(z_mean, self.U)

    if self.update_V:
        self.V = np.dot(y_chol, V_eff)

    if self.update_U:
        self.U = np.dot(z_chol, self.U)

    self.compute_aux()
def get_config(self):
    """Return the base-class configuration extended with this class's flags.

    Returns:
      dict with every entry of the base-class config plus ``update_D``,
      ``update_U``, ``update_V`` and ``floor_iD``; the entries added here
      win on key clashes.
    """
    own_entries = {
        "update_D": self.update_D,
        "update_U": self.update_U,
        "update_V": self.update_V,
        "floor_iD": self.floor_iD,
    }
    merged = dict(super(PLDA, self).get_config().items())
    merged.update(own_entries.items())
    return merged
def save_params(self, f):
    """Write mu, V, U and D through ``_save_params_from_dict``.

    Args:
      f: destination handle, passed through unchanged.
    """
    names = ("mu", "V", "U", "D")
    params = {name: getattr(self, name) for name in names}
    self._save_params_from_dict(f, params)
@classmethod
def load_params(cls, f, config):
    """Build a model from a config dict and stored parameters.

    Args:
      f: source handle, passed to ``_load_params_to_dict``.
      config: constructor keyword arguments; must contain ``"name"``, which
        is passed to ``_load_params_to_dict``.

    Returns:
      ``cls(**kwargs)`` where kwargs is config overlaid with the loaded
      ``mu``, ``V``, ``U`` and ``D``.
    """
    loaded = cls._load_params_to_dict(f, config["name"], ["mu", "V", "U", "D"])
    ctor_args = dict(config.items())
    ctor_args.update(loaded.items())
    return cls(**ctor_args)
def log_probx_g_y(self, x, y):
    """Gaussian log-density of x given y only.

    The mean is ``mu + y V`` and the covariance is
    ``diag(1 / D) + U^T U``.

    Args:
      x: observations; the last axis is the feature axis.
      y: values of the latent variable y, one row per row of x.

    Returns:
      One log-density per row of x.
    """
    cov = np.diag(1 / self.D) + np.dot(self.U.T, self.U)
    # NOTE(review): invert_pdmat is defined elsewhere; it is assumed to
    # return (multiply-by-inverse callable, <unused>, log-determinant).
    solve_cov, _, logdet_cov = invert_pdmat(cov, return_logdet=True)

    residual = x - self.mu - np.dot(y, self.V)
    mahalanobis = np.sum(solve_cov(residual) * residual, axis=-1)

    logp = -x.shape[-1] * np.log(2 * np.pi) - logdet_cov - mahalanobis
    return logp / 2
def log_probx_g_yz(self, x, y, z):
    """Gaussian log-density of x given both latent variables.

    The mean is ``mu + y V + z U`` and D is used as a diagonal precision.

    Args:
      x: observations; the last axis is the feature axis.
      y: values of the latent variable y, one row per row of x.
      z: values of the latent variable z, one row per row of x.

    Returns:
      One log-density per row of x.
    """
    residual = x - self.mu - np.dot(y, self.V) - np.dot(z, self.U)
    log_norm = -x.shape[-1] * np.log(2 * np.pi) + np.sum(np.log(self.D))
    weighted_sq = np.sum(self.D * residual ** 2, axis=-1)
    return (log_norm - weighted_sq) / 2
def llr_1vs1(self, x1, x2):
    """Log-likelihood ratio between every row of x1 and every row of x2.

    Args:
      x1: first set of vectors, one per row.
      x2: second set of vectors, one per row.

    Returns:
      Score matrix with one row per row of x1 and one column per row of x2.
    """
    assert self.is_init

    eye = np.eye(self.y_dim, dtype=float_cpu())

    def factorize(L):
        # NOTE(review): invert_trimat is defined elsewhere; it is assumed to
        # return (multiply-by-inverse-factor callable, log-det of the
        # factor, ...). Twice the factor's log-det is the log-det of L.
        mult, logdet_factor = invert_trimat(
            sla.cholesky(L, lower=False, overwrite_a=True),
            right_inv=True,
            return_logdet=True,
        )[:2]
        return mult, 2 * logdet_factor

    # One-sample posterior precision and pooled two-sample precision.
    whiten_non, logLnon = factorize(eye + self._VWV)
    whiten_tar, logLtar = factorize(eye + 2 * self._VWV)

    proj_1 = np.dot(x1 - self.mu, self._VW)
    proj_2 = np.dot(x2 - self.mu, self._VW)

    non_1 = whiten_non(proj_1)
    non_2 = whiten_non(proj_2)
    Qnon_1 = np.sum(non_1 * non_1, axis=1)[:, None]
    Qnon_2 = np.sum(non_2 * non_2, axis=1)

    tar_1 = whiten_tar(proj_1)
    tar_2 = whiten_tar(proj_2)
    Qtar_1 = np.sum(tar_1 * tar_1, axis=1)[:, None]
    Qtar_2 = np.sum(tar_2 * tar_2, axis=1)

    scores = 2 * np.dot(tar_1, tar_2.T)
    scores += Qtar_1 - Qnon_1 + Qtar_2 - Qnon_2
    scores += 2 * logLnon - logLtar
    scores *= 0.5
    return scores
def llr_NvsM_book(self, D1, D2):
    """Log-likelihood ratio between multi-sample statistics.

    Args:
      D1: tuple ``(N1, F1, _)`` of per-row counts and first-order sums for
        the first side; the third entry is ignored.
      D2: same for the second side.

    Returns:
      Score matrix of shape ``(len(N1), len(N2))``.
    """
    assert self.is_init

    N1, F1, _ = D1
    N2, F2, _ = D2

    # Fix: the cached projection is `_VW` (see compute_aux and llr_1vs1);
    # `self._WV` does not exist.
    VW = self._VW
    VWV = self._VWV
    eye = np.eye(self.y_dim, dtype=float_cpu())

    # Fix: centre out of place so the caller's statistics are not modified.
    proj_1 = np.dot(F1 - N1[:, None] * self.mu, VW)
    proj_2 = np.dot(F2 - N2[:, None] * self.mu, VW)

    def factorize(count):
        # NOTE(review): invert_trimat is defined elsewhere; it is assumed to
        # return (multiply-by-inverse-factor callable, log-det of the
        # factor, ...). Twice the factor's log-det is the log-det of L.
        mult, logdet_factor = invert_trimat(
            sla.cholesky(eye + count * VWV, lower=False, overwrite_a=True),
            right_inv=True,
            return_logdet=True,
        )[:2]
        return mult, 2 * logdet_factor

    # Everything that depends only on the second side's count is computed
    # once per distinct count instead of once per (N1, N2) pair.
    side_2 = []
    for N2_j in np.unique(N2):
        j = np.where(N2 == N2_j)[0]
        mult_2, logL2 = factorize(N2_j)
        non_2 = mult_2(proj_2[j, :])
        side_2.append((N2_j, j, np.sum(non_2 * non_2, axis=1), logL2))

    scores = np.zeros((len(N1), len(N2)), dtype=float_cpu())
    for N1_i in np.unique(N1):
        i = np.where(N1 == N1_i)[0]
        mult_1, logL1 = factorize(N1_i)
        VWF1 = proj_1[i, :]
        non_1 = mult_1(VWF1)
        Qnon_1 = np.sum(non_1 * non_1, axis=1)[:, None]

        for N2_j, j, Qnon_2, logL2 in side_2:
            # Pooled precision when both sides share the same latent y.
            mult_tar, logLtar = factorize(N1_i + N2_j)

            tar_1 = mult_tar(VWF1)
            tar_2 = mult_tar(proj_2[j, :])
            Qtar_1 = np.sum(tar_1 * tar_1, axis=1)[:, None]
            Qtar_2 = np.sum(tar_2 * tar_2, axis=1)

            scores_ij = 2 * np.dot(tar_1, tar_2.T)
            scores_ij += Qtar_1 - Qnon_1 + Qtar_2 - Qnon_2
            scores_ij += logL1 + logL2 - logLtar
            scores[np.ix_(i, j)] = scores_ij

    scores *= 0.5
    return scores
def sample(self, num_classes, num_samples_per_class, rng=None, seed=1024):
    """Draw synthetic observations from the model.

    Each class gets one draw of y shared by all its samples; every sample
    additionally gets its own draw of z and its own diagonal noise.

    Args:
      num_classes: number of classes to generate.
      num_samples_per_class: number of samples generated for each class.
      rng: ``np.random.RandomState`` to draw from; a new one seeded with
        ``seed`` is created when None.
      seed: seed used only when ``rng`` is None.

    Returns:
      Array with ``num_classes * num_samples_per_class`` rows and
      ``mu.shape[0]`` columns; the samples of one class are contiguous.
    """
    if rng is None:
        rng = np.random.RandomState(seed=seed)

    num_samples = num_classes * num_samples_per_class
    x_dim = self.mu.shape[0]

    # Diagonal noise.
    # Fix: D is a precision (see log_probx_g_yz), so unit-variance draws are
    # scaled by 1/sqrt(D); dividing by D produced variance 1/D**2.
    z1 = rng.normal(size=(num_samples, x_dim)).astype(dtype=float_cpu(), copy=False)
    z1 /= np.sqrt(self.D)

    # Per-sample latent z mapped through U.
    z2 = rng.normal(size=(num_samples, self.z_dim)).astype(
        dtype=float_cpu(), copy=False
    )
    z2 = np.dot(z2, self.U)

    # Per-class latent y mapped through V, repeated for every sample of the
    # class.
    y = rng.normal(size=(num_classes, self.y_dim)).astype(
        dtype=float_cpu(), copy=False
    )
    y = np.dot(y, self.V) + self.mu
    y = np.repeat(y, num_samples_per_class, axis=0)

    return y + z1 + z2
def weighted_avg_params(self, mu, V, U, D, w_mu, w_B, w_W):
    """Interpolate this model's parameters with another parameter set.

    Args:
      mu: mean of the other model; handled by the base class with w_mu.
      V: V matrix of the other model.
      U: U matrix of the other model.
      D: D vector of the other model.
      w_mu: weight given to the other mean.
      w_B: weight given to the other model's ``V^T V``; V is left untouched
        unless this is positive.
      w_W: weight given to the other model's ``U^T U + diag(1 / D)``; U and
        D are left untouched unless this is positive.
    """
    # NOTE(review): the base-class method name is spelled "weigthed" here;
    # confirm it matches the definition in PLDABase.
    super(PLDA, self).weigthed_avg_params(mu, w_mu)

    if w_B > 0:
        own_between = np.dot(self.V.T, self.V)
        other_between = np.dot(V.T, V)
        between = w_B * other_between + (1 - w_B) * own_between
        # eigh returns ascending eigenvalues: the last y_dim columns are the
        # dominant directions.
        eig_vals, eig_vecs = sla.eigh(between, overwrite_a=True)
        scaled = np.sqrt(eig_vals) * eig_vecs
        self.V = scaled[:, -self.y_dim :].T

    if w_W > 0:
        own_within = np.dot(self.U.T, self.U) + np.diag(1 / self.D)
        other_within = np.dot(U.T, U) + np.diag(1 / D)
        within = w_W * other_within + (1 - w_W) * own_within
        # `within` is needed again below, so it must not be overwritten.
        eig_vals, eig_vecs = sla.eigh(within, overwrite_a=False)
        scaled = np.sqrt(eig_vals) * eig_vecs
        self.U = scaled[:, -self.z_dim :].T

        # Diagonal variance not captured by the new U, floored.
        iD = np.diag(within - np.dot(self.U.T, self.U)).copy()
        iD[iD < self.floor_iD] = self.floor_iD
        self.D = 1 / iD
# if w_W > 0:
#     Sw0 = np.dot(self.U.T, self.U)
#     Sw = np.dot(U.T, U)
#     Sw = w_W*Sw + (1-w_W)*Sw0
#     w, U = sla.eigh(Sw, overwrite_a=True)
#     U = np.sqrt(w)*U
#     U = U[:,-self.z_dim:]
#     self.U = U.T
# if w_D > 0:
#     Sd0 = 1/self.D
#     Sd = 1/D
#     Sd = w_D*Sd + (1-w_D)*Sd0
#     self.D = 1/Sd
def weighted_avg_model(self, plda, w_mu, w_B, w_W):
    """Interpolate this model with another model object.

    Args:
      plda: object exposing ``mu``, ``V``, ``U`` and ``D``.
      w_mu: forwarded to ``weighted_avg_params``.
      w_B: forwarded to ``weighted_avg_params``.
      w_W: forwarded to ``weighted_avg_params``.
    """
    other_params = (plda.mu, plda.V, plda.U, plda.D)
    self.weighted_avg_params(*other_params, w_mu, w_B, w_W)