2022-12-03 16:54:06 +08:00
|
|
|
from models.encoder.params_model import *
|
|
|
|
from models.encoder.params_data import *
|
2021-08-07 11:56:00 +08:00
|
|
|
from scipy.interpolate import interp1d
|
|
|
|
from sklearn.metrics import roc_curve
|
|
|
|
from torch.nn.utils import clip_grad_norm_
|
|
|
|
from scipy.optimize import brentq
|
|
|
|
from torch import nn
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerEncoder(nn.Module):
    """
    LSTM-based speaker encoder trained with the GE2E softmax loss.

    A stacked LSTM consumes a batch of mel spectrograms; the final hidden state of its last
    layer is projected by a linear layer, passed through a ReLU and L2-normalized to give one
    embedding per utterance. The network hyperparameters (mel_n_channels, model_hidden_size,
    model_num_layers, model_embedding_size) are module-level names, presumably provided by the
    params modules star-imported at the top of the file.
    """

    def __init__(self, device, loss_device):
        """
        Builds the network and the learnable similarity scaling.

        :param device: device hosting the LSTM, the linear projection and the ReLU.
        :param loss_device: device hosting the similarity weight/bias and the loss function.
        """
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensors are created directly on loss_device: calling .to() on an nn.Parameter
        # returns a plain non-leaf tensor whenever a copy has to be made, which would leave
        # these values out of self.parameters() and state_dict(), and without a .grad.
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """
        Post-processes the gradients in place; meant to be called after backward() and before
        the optimizer step. The gradients of the similarity weight and bias are scaled down by
        a factor 0.01, then the global L2 norm of all gradients is clipped to 3.
        """
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial state of the LSTM, forwarded unchanged to nn.LSTM, which
        expects a tuple (h_0, c_0) of tensors of shape (num_layers, batch_size, hidden_size).
        Will default to zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers. Only the final hidden state is needed; the
        # per-frame outputs and the final cell state are discarded.
        _, (hidden, _) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it (the epsilon guards against an all-zero output of the ReLU)
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according the section 2.1 of GE2E.

        Each utterance is compared with the centroid of every speaker. For its own speaker the
        centroid is computed without that utterance (exclusive centroid), which requires at
        least 2 utterances per speaker: with a single one the exclusive centroid is 0 / 0.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation over utterances and loop over the speakers only.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch, device=self.loss_device)
        # Builtin int replaces the np.int alias, which was removed in NumPy 1.24.
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            # Indices of every speaker but j: those are compared with the inclusive centroid
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            # Utterances of speaker j are compared with their own exclusive centroid
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss: every utterance is a sample to classify among speakers_per_batch speakers
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            # One-hot encoding of the true speaker of each utterance: row i of the identity
            # matrix is the one-hot vector of speaker i.
            labels = np.eye(speakers_per_batch, dtype=int)[ground_truth]
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            # The interpolator is built once rather than at every evaluation made by brentq.
            # The EER is the false positive rate x at which 1 - x equals the interpolated TPR.
            tpr_at = interp1d(fpr, tpr)
            eer = brentq(lambda x: 1. - x - tpr_at(x), 0., 1.)

        return loss, eer
|