"""
Copyright 2018 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import numpy as np
import logging
from abc import ABCMeta, abstractmethod
from ...hyp_defs import float_cpu
from ...utils.math import softmax, logsumexp
from ...utils.queues import GeneratorQueue
from ..core import PDF
[docs]class ExpFamilyMixture(PDF):
__metaclass__ = ABCMeta
def __init__(
    self, num_comp=1, pi=None, eta=None, min_N=0, update_pi=True, **kwargs
):
    """Store the mixture settings and reset the cached quantities.

    Args:
      num_comp: number of mixture components; overridden by ``len(pi)``
        when ``pi`` is given.
      pi: mixture weights, or None when not available yet.
      eta: natural parameters of the components, or None.
      min_N: stored as ``self.min_N`` (not used inside this method).
      update_pi: stored as ``self.update_pi`` (not used inside this method).
      **kwargs: forwarded unchanged to the parent class constructor.
    """
    super().__init__(**kwargs)
    # An explicit weight vector decides how many components there are.
    self.num_comp = num_comp if pi is None else len(pi)
    self.pi, self.eta = pi, eta
    self.min_N = min_N
    self.update_pi = update_pi
    # Filled in later: A is read next to eta by the likelihood methods,
    # _log_pi is the cache behind the ``log_pi`` property.
    self.A = None
    self._log_pi = None
@property
def is_init(self):
    """Whether the model is initialized.

    The first time ``eta``, ``A`` and ``pi`` are all set, ``validate()`` is
    run and the result is remembered in ``self._is_init``.
    """
    if self._is_init:
        return self._is_init
    if not (self.eta is None or self.A is None or self.pi is None):
        self.validate()
        self._is_init = True
    return self._is_init
@property
def log_pi(self):
    """Log of the mixture weights, computed once and cached in ``_log_pi``."""
    cached = self._log_pi
    if cached is None:
        # The small offset keeps the log finite for zero weights.
        cached = np.log(self.pi + 1e-15)
        self._log_pi = cached
    return cached
def _validate_pi(self):
    """Assert that there is exactly one mixture weight per component."""
    num_weights = len(self.pi)
    assert num_weights == self.num_comp
def fit(
    self,
    x,
    sample_weight=None,
    x_val=None,
    sample_weight_val=None,
    epochs=10,
    batch_size=None,
):
    """Train the mixture with EM on in-memory data.

    Args:
      x: training data; ``x.shape[0]`` is used as the number of samples.
      sample_weight: optional per-sample weights for ``x``.
      x_val: optional validation data.
      sample_weight_val: optional per-sample weights for ``x_val``.
      epochs: number of EM iterations.
      batch_size: forwarded to ``Estep``; None processes all data at once.

    Returns:
      ``(elbo, elbo_norm)`` when ``x_val`` is None, otherwise
      ``(elbo, elbo_norm, elbo_val, elbo_val_norm)``. Each item is an array
      with one entry per epoch. The normalized versions are divided by the
      number of rows of the corresponding data set.
    """
    if not self.is_init:
        self.initialize(x)

    do_validation = x_val is not None

    # log h(x) does not depend on the parameters, so accumulate it once.
    log_h = self.accum_log_h(x, sample_weight)
    if do_validation:
        log_h_val = self.accum_log_h(x_val, sample_weight_val)

    elbo = np.zeros((epochs,), dtype=float_cpu())
    elbo_val = np.zeros((epochs,), dtype=float_cpu())
    for epoch in range(epochs):
        N, u_x = self.Estep(x=x, sample_weight=sample_weight, batch_size=batch_size)
        elbo[epoch] = self.elbo(None, N=N, u_x=u_x, log_h=log_h)
        self.Mstep(N, u_x)

        if do_validation:
            N_val, u_x_val = self.Estep(
                x=x_val, sample_weight=sample_weight_val, batch_size=batch_size
            )
            elbo_val[epoch] = self.elbo(
                None, N=N_val, u_x=u_x_val, log_h=log_h_val
            )

    if not do_validation:
        return elbo, elbo / x.shape[0]

    # The validation ELBO is normalized by the validation set size, not by
    # the training set size.
    return elbo, elbo / x.shape[0], elbo_val, elbo_val / x_val.shape[0]
def fit_generator(
    self,
    generator,
    train_steps,
    epochs=10,
    val_data=None,
    val_steps=0,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
):
    """Train the mixture with EM, reading the training data from a generator.

    Args:
      generator: training batch generator; each item must be understood by
        ``tuple2data``.
      train_steps: number of generator batches consumed per epoch.
      epochs: number of EM iterations.
      val_data: optional validation data. Either an iterator (an object
        with ``__next__``/``next``) or in-memory data in any format
        accepted by ``tuple2data``.
      val_steps: number of validation batches per epoch; required when
        ``val_data`` is an iterator.
      max_queue_size: forwarded to ``Estep_generator``.
      workers: forwarded to ``Estep_generator``.
      use_multiprocessing: forwarded to ``Estep_generator``.

    Returns:
      ``(elbo, elbo_norm)`` without validation data, otherwise
      ``(elbo, elbo_norm, elbo_val, elbo_val_norm)``; arrays with one entry
      per epoch. The number of samples is not known up front, so the
      normalized values are divided by the total accumulated responsibility
      ``sum(N)`` of the epoch. That equals the sample count when the
      responsibilities of each sample sum to one and no weights are used.

    Raises:
      ValueError: if ``val_data`` is an iterator and ``val_steps`` is 0.
    """
    # ``is not None`` rather than bool(): array-like data has no single
    # truth value.
    do_validation = val_data is not None
    val_gen = do_validation and (
        hasattr(val_data, "next") or hasattr(val_data, "__next__")
    )
    if val_gen and not val_steps:
        raise ValueError(
            "When using a generator for validation data, "
            "you must specify a value for "
            "`val_steps`."
        )

    if do_validation and not val_gen:
        x_val, u_x_val, sample_weight_val = self.tuple2data(val_data)
        log_h_val = self.accum_log_h(x_val, sample_weight_val)

    elbo = np.zeros((epochs,), dtype=float_cpu())
    elbo_norm = np.zeros((epochs,), dtype=float_cpu())
    elbo_val = np.zeros((epochs,), dtype=float_cpu())
    elbo_val_norm = np.zeros((epochs,), dtype=float_cpu())
    for epoch in range(epochs):
        # Arguments are passed positionally so the call only depends on
        # the parameter order of Estep_generator.
        N, u_x, log_h = self.Estep_generator(
            generator,
            train_steps,
            True,
            max_queue_size,
            workers,
            use_multiprocessing,
        )
        # Same order as fit(): evaluate the bound, then update parameters.
        elbo[epoch] = self.elbo(None, N=N, u_x=u_x, log_h=log_h)
        elbo_norm[epoch] = elbo[epoch] / np.sum(N)
        self.Mstep(N, u_x)

        if not do_validation:
            continue

        if val_gen:
            N_val, acc_u_x_val, log_h_val = self.Estep_generator(
                val_data,
                val_steps,
                True,
                max_queue_size,
                workers,
                use_multiprocessing,
            )
        else:
            N_val, acc_u_x_val = self.Estep(x_val, u_x_val, sample_weight_val)
        elbo_val[epoch] = self.elbo(
            None, N=N_val, u_x=acc_u_x_val, log_h=log_h_val
        )
        elbo_val_norm[epoch] = elbo_val[epoch] / np.sum(N_val)

    if not do_validation:
        return elbo, elbo_norm
    return elbo, elbo_norm, elbo_val, elbo_val_norm
def log_h(self, x):
    """Return the log base measure term for ``x``.

    This default implementation ignores ``x`` and returns the constant 0.
    """
    return 0
def accum_log_h(self, x, sample_weight=None):
    """Sum ``log_h(x)`` over the data, optionally weighting each sample.

    Args:
      x: data passed to ``self.log_h``.
      sample_weight: optional weights multiplied with ``log_h(x)`` before
        summing.

    Returns:
      The (weighted) sum of ``log_h(x)``.
    """
    log_h = self.log_h(x)
    if sample_weight is not None:
        log_h = sample_weight * log_h
    return np.sum(log_h)
def compute_log_pz(self, x, u_x=None, mode="nat"):
    """Compute the unnormalized log responsibilities of every component.

    Args:
      x: input data; only used to derive ``u_x`` when it is not given.
      u_x: precomputed sufficient statistics of ``x``, or None.
      mode: accepted for interface symmetry; not used by this method.

    Returns:
      ``u_x . eta^T - A + log_pi``.
    """
    stats = self.compute_suff_stats(x) if u_x is None else u_x
    log_lik = np.dot(stats, self.eta.T)
    return log_lik - self.A + self.log_pi
def compute_pz(self, x, u_x=None, mode="nat"):
    """Compute the component responsibilities for ``x``.

    Args:
      x: input data.
      u_x: optional sufficient statistics, used only when ``mode == "nat"``.
      mode: ``"nat"`` dispatches to ``compute_pz_nat``; any other value
        dispatches to ``compute_pz_std``.
    """
    if mode != "nat":
        return self.compute_pz_std(x)
    return self.compute_pz_nat(x, u_x)
def compute_pz_nat(self, x, u_x=None):
    """Compute the responsibilities from the natural parameters.

    Args:
      x: input data; only used to derive ``u_x`` when it is not given.
      u_x: precomputed sufficient statistics of ``x``, or None.

    Returns:
      ``softmax(u_x . eta^T - A + log_pi)``.
    """
    stats = u_x if u_x is not None else self.compute_suff_stats(x)
    log_resp = np.dot(stats, self.eta.T) - self.A + self.log_pi
    return softmax(log_resp)
def compute_pz_std(self, x):
    """Compute the responsibilities in "standard" mode.

    This base implementation simply delegates to ``compute_pz_nat``.
    """
    responsibilities = self.compute_pz_nat(x)
    return responsibilities
def compute_suff_stats(self, x):
    """Return the sufficient statistics of ``x``.

    This base implementation is the identity: ``x`` is returned unchanged.
    """
    return x
def accum_suff_stats(self, x, u_x=None, sample_weight=None, batch_size=None):
    """Accumulate the zero-order and first-order statistics of ``x``.

    The data is processed in mini-batches only when ``batch_size`` is given
    and no precomputed ``u_x`` is supplied; otherwise in a single pass.

    Returns:
      ``(N, acc_u_x)`` as produced by the selected helper.
    """
    use_batches = u_x is None and batch_size is not None
    if use_batches:
        return self._accum_suff_stats_nbatches(x, sample_weight, batch_size)
    return self._accum_suff_stats_1batch(x, u_x, sample_weight)
def _accum_suff_stats_1batch(self, x, u_x=None, sample_weight=None):
    """Accumulate the statistics of ``x`` in a single pass.

    Args:
      x: input data.
      u_x: precomputed sufficient statistics; computed from ``x`` if None.
      sample_weight: optional per-sample weights; they scale the
        responsibilities in place.

    Returns:
      ``(N, acc_u_x)``: the responsibilities summed over axis 0, and
      ``z^T . u_x`` where ``z`` are the (weighted) responsibilities.
    """
    stats = self.compute_suff_stats(x) if u_x is None else u_x
    resp = self.compute_pz_nat(x, stats)
    if sample_weight is not None:
        resp *= sample_weight[:, None]
    return np.sum(resp, axis=0), np.dot(resp.T, stats)
def _accum_suff_stats_nbatches(self, x, sample_weight, batch_size):
    """Accumulate the statistics of ``x`` over consecutive mini-batches.

    Args:
      x: input data, sliced along axis 0 in chunks of ``batch_size`` rows.
      sample_weight: optional per-sample weights, sliced alongside ``x``.
      batch_size: maximum number of rows per chunk.

    Returns:
      ``(N, u_x)``: the per-batch results of ``_accum_suff_stats_1batch``
      summed over all batches.
    """
    num_rows = x.shape[0]
    weights_b = None
    for start in range(0, num_rows, batch_size):
        stop = np.minimum(start + batch_size, num_rows)
        if sample_weight is not None:
            weights_b = sample_weight[start:stop]
        N_b, u_x_b = self._accum_suff_stats_1batch(
            x[start:stop, :], sample_weight=weights_b
        )
        if start > 0:
            N += N_b
            u_x += u_x_b
        else:
            # The first batch provides the accumulators.
            N, u_x = N_b, u_x_b
    return N, u_x
def accum_suff_stats_segments(
    self, x, segments, u_x=None, sample_weight=None, batch_size=None
):
    """Accumulate statistics separately for each segment of ``x``.

    Args:
      x: input data, sliced along axis 0.
      segments: indexable collection of ``(first, last)`` row indices; the
        last index is inclusive.
      u_x: optional precomputed sufficient statistics, sliced like ``x``.
      sample_weight: optional per-sample weights, sliced like ``x``.
      batch_size: forwarded to ``accum_suff_stats``.

    Returns:
      ``(N, acc_u_x)`` with one leading entry per segment: shapes
      ``(num_segments, num_comp)`` and
      ``(num_segments, num_comp, eta.shape[1])``.
    """
    num_segments = len(segments)
    num_comp = self.num_comp
    N = np.zeros((num_segments, num_comp), dtype=float_cpu())
    acc_u_x = np.zeros(
        (num_segments, num_comp, self.eta.shape[1]), dtype=float_cpu()
    )
    for i in range(num_segments):
        first = int(segments[i][0])
        last = int(segments[i][1])
        rows = slice(first, last + 1)  # segment end is inclusive
        u_x_seg = None if u_x is None else u_x[rows]
        weights_seg = None if sample_weight is None else sample_weight[rows]
        N[i], acc_u_x[i] = self.accum_suff_stats(
            x[rows], u_x=u_x_seg, sample_weight=weights_seg, batch_size=batch_size
        )
    return N, acc_u_x
def accum_suff_stats_segments_prob(
    self, x, prob, u_x=None, sample_weight=None, batch_size=None
):
    """Accumulate per-segment statistics with soft segment assignments.

    Mini-batch processing is used only when ``batch_size`` is given and no
    precomputed ``u_x`` is supplied; otherwise a single pass is done.

    Returns:
      ``(N, acc_u_x)`` as produced by the selected helper.
    """
    use_batches = u_x is None and batch_size is not None
    if use_batches:
        return self._accum_suff_stats_segments_prob_nbatches(
            x, prob, sample_weight, batch_size
        )
    return self._accum_suff_stats_segments_prob_1batch(x, prob, u_x, sample_weight)
def _accum_suff_stats_segments_prob_1batch(
    self, x, prob, u_x=None, sample_weight=None
):
    """Accumulate per-segment statistics with soft assignments in one pass.

    Args:
      x: input data.
      prob: 2-D array whose column ``i`` weights every sample for segment
        ``i``.
      u_x: precomputed sufficient statistics; computed from ``x`` if None.
      sample_weight: optional per-sample weights; they scale the
        responsibilities in place.

    Returns:
      ``(N, acc_u_x)`` with shapes ``(num_segments, K)`` and
      ``(num_segments, K, eta.shape[1])``, where ``K = len(pi)`` and
      ``num_segments = prob.shape[1]``.
    """
    stats = self.compute_suff_stats(x) if u_x is None else u_x
    resp = self.compute_pz_nat(x, stats)
    if sample_weight is not None:
        resp *= sample_weight[:, None]

    num_comp = len(self.pi)
    num_segments = prob.shape[1]
    N = np.zeros((num_segments, num_comp), float_cpu())
    acc_u_x = np.zeros((num_segments, num_comp, self.eta.shape[1]), float_cpu())
    for seg in range(num_segments):
        seg_prob = prob[:, seg]
        resp_seg = resp * seg_prob[:, None]
        N[seg] = np.sum(resp_seg, axis=0)
        acc_u_x[seg] = np.dot(resp_seg.T, stats)
    return N, acc_u_x
def _accum_suff_stats_segments_prob_nbatches(
    self, x, prob, sample_weight, batch_size
):
    """Accumulate soft-assignment segment statistics over mini-batches.

    Args:
      x: input data, sliced along axis 0 in chunks of ``batch_size`` rows.
      prob: segment assignment weights, sliced along axis 0 like ``x``.
      sample_weight: optional per-sample weights, sliced like ``x``.
      batch_size: maximum number of rows per chunk.

    Returns:
      ``(N, u_x)``: the per-batch results of
      ``_accum_suff_stats_segments_prob_1batch`` summed over all batches.
    """
    num_rows = x.shape[0]
    weights_b = None
    for start in range(0, num_rows, batch_size):
        stop = np.minimum(start + batch_size, num_rows)
        if sample_weight is not None:
            weights_b = sample_weight[start:stop]
        N_b, u_x_b = self._accum_suff_stats_segments_prob_1batch(
            x[start:stop, :], prob[start:stop, :], sample_weight=weights_b
        )
        if start > 0:
            N += N_b
            u_x += u_x_b
        else:
            # The first batch provides the accumulators.
            N, u_x = N_b, u_x_b
    return N, u_x
def accum_suff_stats_sorttime(
    self,
    x,
    frame_length,
    frame_shift,
    u_x=None,
    sample_weight=None,
    batch_size=None,
):
    """Accumulate statistics over sliding windows of frames.

    Mini-batch processing is used only when ``batch_size`` is given and no
    precomputed ``u_x`` is supplied; otherwise a single pass is done.

    Returns:
      ``(N, acc_u_x)`` as produced by the selected helper.
    """
    use_batches = u_x is None and batch_size is not None
    if use_batches:
        return self._accum_suff_stats_sorttime_nbatches(
            x, frame_length, frame_shift, sample_weight, batch_size
        )
    return self._accum_suff_stats_sorttime_1batch(
        x, frame_length, frame_shift, u_x, sample_weight
    )
def _accum_suff_stats_sorttime_1batch(
    self, x, frame_length, frame_shift, u_x=None, sample_weight=None
):
    """Accumulate statistics over sliding windows of frames in one pass.

    Window ``i`` covers frames ``[i * frame_shift, i * frame_shift +
    frame_length - 1]``. The number of windows is
    ``floor((num_frames - frame_length) / frame_shift + 1)``.

    Args:
      x: input data with one frame per row.
      frame_length: window length in frames (used as an index, so it is
        expected to be an integer).
      frame_shift: shift between consecutive windows in frames (used as a
        slice step, so it is expected to be an integer).
      u_x: precomputed sufficient statistics; computed from ``x`` if None.
      sample_weight: optional per-frame weights; they scale the
        responsibilities in place.

    Returns:
      ``(N, acc_u_x)`` with shapes ``(num_segments, K)`` and
      ``(num_segments, K, eta.shape[1])``, where ``K = len(pi)``. When only
      one window fits, the statistics of all frames of ``x`` are returned
      with a leading axis of length 1.
    """
    K = len(self.pi)
    num_frames = x.shape[0]
    num_segments = int(np.floor((num_frames - frame_length) / frame_shift + 1))
    if num_segments == 1:
        # Single window: plain accumulation over all frames. The segment
        # axis is added so the result has the same rank as the general case.
        N_1, acc_u_x_1 = self._accum_suff_stats_1batch(x, u_x, sample_weight)
        return np.asarray(N_1)[None], np.asarray(acc_u_x_1)[None]

    if u_x is None:
        u_x = self.compute_suff_stats(x)
    z = self.compute_pz_nat(x, u_x)
    if sample_weight is not None:
        z *= sample_weight[:, None]

    N = np.zeros((num_segments, K), float_cpu())
    acc_u_x = np.zeros((num_segments, K, self.eta.shape[1]), float_cpu())

    # A window sum is the difference of two cumulative sums: the one at the
    # window's last frame minus the one just before its first frame.
    # [start1:end1] are the "just before" indices of windows 1.., and
    # [start2:end2] are the matching last-frame indices.
    start1 = int(frame_shift - 1)
    end1 = int((num_segments - 1) * frame_shift)
    start2 = int(start1 + frame_length)
    end2 = int(end1 + frame_length)

    cum_N = np.cumsum(z, axis=0)
    N[0] = cum_N[frame_length - 1]
    N[1:] = cum_N[start2:end2:frame_shift] - cum_N[start1:end1:frame_shift]

    for k in range(K):
        cum_u_x_k = np.cumsum(z[:, k][:, None] * u_x, axis=0)
        acc_u_x[0, k] = cum_u_x_k[frame_length - 1]
        acc_u_x[1:, k] = (
            cum_u_x_k[start2:end2:frame_shift] - cum_u_x_k[start1:end1:frame_shift]
        )
    return N, acc_u_x
def _accum_suff_stats_sorttime_nbatches(
    self, x, frame_length, frame_shift, sample_weight, batch_size
):
    """Accumulate sliding-window statistics, processing ``x`` in chunks.

    The windows are the same as in the single-pass version: window ``i``
    covers frames ``[i * frame_shift, i * frame_shift + frame_length - 1]``.
    Frames are processed in chunks that each hold a whole number of
    windows, so no window is split across chunks. Consecutive chunks
    overlap when ``frame_shift < frame_length``.

    Args:
      x: input data with one frame per row.
      frame_length: window length in frames.
      frame_shift: shift between consecutive windows in frames.
      sample_weight: optional per-frame weights, sliced like ``x``.
      batch_size: upper bound on the number of frames per chunk; a chunk
        always holds at least one full window even if that exceeds it.

    Returns:
      ``(N, acc_u_x)`` with shapes ``(num_segments, K)`` and
      ``(num_segments, K, eta.shape[1])``, where ``K = len(pi)``. When only
      one window fits, the statistics of all frames of ``x`` are returned
      with a leading axis of length 1.
    """
    K = len(self.pi)
    num_frames = x.shape[0]
    num_segments = int(np.floor((num_frames - frame_length) / frame_shift + 1))
    if num_segments == 1:
        N_1, acc_u_x_1 = self._accum_suff_stats_1batch(
            x, sample_weight=sample_weight
        )
        return np.asarray(N_1)[None], np.asarray(acc_u_x_1)[None]

    # Number of complete windows that fit in batch_size frames (at least 1).
    segments_per_batch = max(
        int(np.floor((batch_size - frame_length) / frame_shift + 1)), 1
    )

    N = np.zeros((num_segments, K), float_cpu())
    acc_u_x = np.zeros((num_segments, K, self.eta.shape[1]), float_cpu())
    sw_i = None
    for first_segment in range(0, num_segments, segments_per_batch):
        num_segments_i = min(segments_per_batch, num_segments - first_segment)
        # Frame range covering exactly windows
        # [first_segment, first_segment + num_segments_i).
        i1 = int(first_segment * frame_shift)
        i2 = int(i1 + (num_segments_i - 1) * frame_shift + frame_length)
        x_i = x[i1:i2, :]
        if sample_weight is not None:
            sw_i = sample_weight[i1:i2]
        N_i, u_x_i = self._accum_suff_stats_sorttime_1batch(
            x_i, frame_length, frame_shift, sample_weight=sw_i
        )
        last_segment = first_segment + num_segments_i
        N[first_segment:last_segment] = N_i
        acc_u_x[first_segment:last_segment] = u_x_i
    return N, acc_u_x
def Estep(self, x, u_x=None, sample_weight=None, batch_size=None):
    """Run the expectation step.

    Thin wrapper around ``accum_suff_stats``; all arguments are forwarded
    positionally.

    Returns:
      Whatever ``accum_suff_stats`` returns.
    """
    stats = self.accum_suff_stats(x, u_x, sample_weight, batch_size)
    return stats
def Estep_generator(
    self,
    generator,
    num_steps,
    return_log_h,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    use_multiprocessin=None,
):
    """Run the expectation step on batches read from a generator.

    Args:
      generator: batch generator; each item must be understood by
        ``tuple2data``.
      num_steps: number of batches to consume.
      return_log_h: if true, also accumulate and return ``log h(x)``.
      max_queue_size: forwarded to ``GeneratorQueue.start``.
      workers: forwarded to ``GeneratorQueue.start``.
      use_multiprocessing: forwarded to the ``GeneratorQueue`` constructor.
      use_multiprocessin: deprecated misspelled alias of
        ``use_multiprocessing``, kept so existing keyword calls still work;
        when not None it takes precedence.

    Returns:
      ``(N, acc_u_x, log_h)`` if ``return_log_h`` is true, else
      ``(N, acc_u_x)``. ``N`` and ``acc_u_x`` are None when
      ``num_steps`` is 0.
    """
    if use_multiprocessin is not None:
        use_multiprocessing = use_multiprocessin

    wait_time = 0.01  # in secs
    queue = None
    N = None
    acc_u_x = None
    log_h = 0
    try:
        queue = GeneratorQueue(
            generator, use_multiprocessing=use_multiprocessing, wait_time=wait_time
        )
        queue.start(workers=workers, max_queue_size=max_queue_size)
        queue_generator = queue.get()
        for _ in range(num_steps):
            data = next(queue_generator)
            x, u_x, sample_weight = self.tuple2data(data)
            N_i, u_x_i = self.Estep(x, u_x, sample_weight)
            if return_log_h:
                # Weighted the same way as the statistics of this batch.
                log_h += self.accum_log_h(x, sample_weight)
            if N is None:
                N = N_i
                acc_u_x = u_x_i
            else:
                N += N_i
                acc_u_x += u_x_i
    finally:
        # Always shut the queue down, even if a batch raised.
        # NOTE(review): assumes GeneratorQueue exposes stop() -- confirm.
        if queue is not None:
            queue.stop()

    if return_log_h:
        return N, acc_u_x, log_h
    return N, acc_u_x
def sum_suff_stats(self, N, u_x):
    """Add up several sets of accumulated statistics.

    Args:
      N: sequence of zero-order statistics, one item per set.
      u_x: sequence of first-order statistics, same length as ``N``.

    Returns:
      ``(acc_N, acc_u_x)``: the element-wise sums of the items of ``N`` and
      of ``u_x``. The inputs are not modified.

    Raises:
      AssertionError: if ``N`` and ``u_x`` have different lengths.
    """
    assert len(N) == len(u_x)
    acc_N = N[0]
    acc_u_x = u_x[0]
    for i in range(1, len(N)):
        # Out-of-place addition so the caller's first item is not mutated.
        acc_N = acc_N + N[i]
        acc_u_x = acc_u_x + u_x[i]
    return acc_N, acc_u_x
@abstractmethod
def Mstep(self, stats):
    """Maximization step; to be implemented by subclasses.

    NOTE(review): ``fit`` and ``fit_generator`` invoke this as
    ``self.Mstep(N, u_x)``, so concrete implementations must accept that
    two-argument form even though this declaration takes a single ``stats``.
    """
def elbo(self, x, u_x=None, N=1, log_h=None, sample_weight=None, batch_size=None):
    """Evaluate the EM objective from accumulated statistics.

    Args:
      x: input data; only needed when ``u_x`` or ``log_h`` is None.
      u_x: accumulated first-order statistics. When None, both ``N`` and
        ``u_x`` are recomputed from ``x``.
      N: accumulated zero-order statistics.
      log_h: accumulated ``log h(x)``; recomputed from ``x`` when None.
      sample_weight: optional per-sample weights for the recomputations.
      batch_size: forwarded to ``accum_suff_stats`` when recomputing.

    Returns:
      ``log_h + sum(u_x * eta) + inner(N, log_pi - A)``.
    """
    if u_x is None:
        N, u_x = self.accum_suff_stats(
            x, sample_weight=sample_weight, batch_size=batch_size
        )
    if log_h is None:
        log_h = self.accum_log_h(x, sample_weight=sample_weight)

    stats_term = np.sum(u_x * self.eta)
    weights_term = np.inner(N, self.log_pi - self.A)
    return log_h + stats_term + weights_term
def log_prob(self, x, u_x=None, mode="nat"):
    """Compute the log-likelihood of ``x``.

    Args:
      x: input data.
      u_x: optional sufficient statistics, used only when ``mode == "nat"``.
      mode: ``"nat"`` dispatches to ``log_prob_nat``; any other value
        dispatches to ``log_prob_std``.
    """
    if mode != "nat":
        return self.log_prob_std(x)
    return self.log_prob_nat(x, u_x)
def log_prob_nat(self, x, u_x=None):
    """Compute the log-likelihood of ``x`` from the natural parameters.

    Args:
      x: input data.
      u_x: precomputed sufficient statistics; computed from ``x`` if None.

    Returns:
      ``log_h(x) + logsumexp(u_x . eta^T - A + log_pi)``.
    """
    stats = self.compute_suff_stats(x) if u_x is None else u_x
    llk_per_comp = np.dot(stats, self.eta.T) - self.A + self.log_pi
    mixture_llk = logsumexp(llk_per_comp)
    return self.log_h(x) + mixture_llk
@abstractmethod
def log_prob_std(self, x):
    """Log-likelihood of ``x`` in "standard" mode; implemented by subclasses."""
def log_prob_nbest(self, x, u_x=None, mode="nat", nbest_mode="master", nbest=1):
    """Compute the log-likelihood of ``x`` using only the n-best components.

    Args:
      x: input data.
      u_x: optional sufficient statistics, used only when ``mode == "nat"``.
      mode: ``"nat"`` dispatches to ``log_prob_nbest_nat``; any other value
        dispatches to ``log_prob_nbest_std``.
      nbest_mode: forwarded to the selected implementation.
      nbest: forwarded to the selected implementation.
    """
    if mode == "nat":
        return self.log_prob_nbest_nat(x, u_x, nbest_mode=nbest_mode, nbest=nbest)
    # The n-best variant must be used here: log_prob_std does not accept
    # the nbest_mode / nbest keywords.
    return self.log_prob_nbest_std(x, nbest_mode=nbest_mode, nbest=nbest)
def log_prob_nbest_nat(self, x, u_x=None, nbest_mode="master", nbest=1):
    """Compute the n-best log-likelihood from the natural parameters.

    Args:
      x: input data.
      u_x: precomputed sufficient statistics; computed from ``x`` if None.
      nbest_mode: with ``"master"``, ``nbest`` is an integer and the
        ``nbest`` highest-scoring components are selected (per sample when
        ``u_x`` is 2-D). With any other value, ``nbest`` holds the indices
        of the components to evaluate.
      nbest: number of components (``"master"``) or component indices.
        NOTE(review): the index form is assumed to be a 1-D list/array
        shared by all samples -- confirm against callers.

    Returns:
      ``log_h(x) + logsumexp(llk_k)`` where ``llk_k`` contains only the
      selected components.
    """
    if u_x is None:
        u_x = self.compute_suff_stats(x)
    if nbest_mode == "master":
        assert isinstance(nbest, (int, np.integer))
        llk_k = np.dot(u_x, self.eta.T) - self.A + self.log_pi
        # Rank the components along the last axis so each sample gets its
        # own top-n, best first. For 1-D input this matches plain indexing.
        order = np.argsort(llk_k, axis=-1)
        best = order[..., : -(nbest + 1) : -1]
        llk_k = np.take_along_axis(llk_k, best, axis=-1)
    else:
        # A and log_pi are restricted to the same components as eta.
        llk_k = (
            np.dot(u_x, self.eta[nbest, :].T) - self.A[nbest] + self.log_pi[nbest]
        )
    llk = logsumexp(llk_k)
    return self.log_h(x) + llk
@abstractmethod
def log_prob_nbest_std(self, x, nbest_mode="master", nbest=1):
    """N-best log-likelihood in "standard" mode; implemented by subclasses."""
def get_config(self):
    """Return the configuration dictionary of the model.

    The parent class configuration is extended with this class' own
    entries; on a key clash the entries added here win.
    """
    base_config = super().get_config()
    # NOTE(review): the key is "min_n" while the constructor argument is
    # ``min_N`` -- confirm this is what the config consumers expect.
    own_config = {"min_n": self.min_N, "update_pi": self.update_pi}
    return {**base_config, **own_config}
@staticmethod
def tuple2data(data):
    """Split one generator item into ``(x, u_x, sample_weight)``.

    Accepted forms:
      * anything that is not a tuple: taken as ``x`` alone;
      * ``(x, second)``: ``second`` is ``u_x`` if it has 2 dimensions, or
        ``sample_weight`` if it has 1 dimension;
      * ``(x, u_x, sample_weight)``.

    Raises:
      ValueError: for tuples of any other length, or when ``second`` has
        neither 1 nor 2 dimensions.
    """
    if not isinstance(data, tuple):
        return data, None, None

    if len(data) == 3:
        x, u_x, sample_weight = data
        return x, u_x, sample_weight

    if len(data) != 2:
        raise ValueError("Generator output: " + str(data))

    x, second = data
    if second.ndim == 2:
        return x, second, None
    if second.ndim == 1:
        return x, None, second
    raise ValueError("Generator output: " + str(data))
@staticmethod
def compute_A_nat(eta):
    """Placeholder that takes ``eta``; this base version always raises.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
@staticmethod
def compute_A_std(params):
    """Placeholder that takes ``params``; this base version always raises.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
@staticmethod
def compute_eta(param):
    """Placeholder that takes ``param``; this base version always raises.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
@staticmethod
def compute_std(eta):
    """Placeholder that takes ``eta``; this base version always raises.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError()
@abstractmethod
def _compute_nat_params(self):
    """Abstract hook taking no arguments; implemented by subclasses."""
@abstractmethod
def _compute_std_params(self):
    """Abstract hook taking no arguments; implemented by subclasses."""