mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
74a3fc97d0
Need readme
196 lines
8.5 KiB
Python
196 lines
8.5 KiB
Python
from models.encoder.params_data import *
|
|
from models.encoder.model import SpeakerEncoder
|
|
from models.encoder.audio import preprocess_wav # We want to expose this function from here
|
|
from matplotlib import cm
|
|
from models.encoder import audio
|
|
from pathlib import Path
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Module-wide encoder state. Both remain None until load_model() or set_model() stores a
# model and the device inference should run on.
_model: "SpeakerEncoder | None" = None
_device: "torch.device | None" = None
|
|
|
|
|
|
def load_model(weights_fpath: Path, device=None):
    """
    Loads the encoder weights into memory and makes them the module's active model.

    This (or set_model()) must be called before any embedding function: embed_frames_batch()
    raises if no model has been loaded.

    :param weights_fpath: the path to saved model weights (a Path or a str).
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
    model will be loaded and will run on this device. Outputs will however always be on the cpu.
    If None, will default to your GPU if it's available, otherwise your CPU.
    :return: the loaded model, also stored as the module's active model.
    """
    # TODO: I think the slow loading of the encoder might have something to do with the device it
    # was saved on. Worth investigating.
    global _model, _device
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        # Already a torch.device: use it as-is (previously _device was left unset here).
        _device = device

    # Accept plain strings as well as Path objects; `.name` is used in the log line below.
    weights_fpath = Path(weights_fpath)

    # NOTE(review): the second argument is presumably the device the loss is computed on —
    # confirm against SpeakerEncoder's constructor.
    _model = SpeakerEncoder(_device, torch.device("cpu"))
    # NOTE: torch.load unpickles the checkpoint; only load weights files from trusted sources.
    checkpoint = torch.load(weights_fpath, map_location=_device)
    _model.load_state_dict(checkpoint["model_state"])
    _model.eval()
    print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
    return _model
|
|
|
|
def set_model(model, device=None):
    """
    Installs an already-constructed model as the module's active encoder.

    :param model: the model to use for subsequent embedding calls. It is moved to the chosen
    device with ``model.to(...)``.
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
    If None, will default to your GPU if it's available, otherwise your CPU.
    """
    global _model, _device
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        _device = device
    _model = model
    # Use the resolved device; the previous code moved the model to the raw `device`
    # argument, which is None when the default is requested.
    _model.to(_device)
|
|
|
|
def is_loaded():
    """
    Tells whether an encoder model is currently held in module state.

    :return: False while no model has been stored, True otherwise.
    """
    if _model is None:
        return False
    return True
|
|
|
|
|
|
def embed_frames_batch(frames_batch):
    """
    Computes embeddings for a batch of mel spectrograms.

    :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
    (batch_size, n_frames, n_channels)
    :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
    :raises Exception: if no model has been loaded yet (see load_model() / set_model()).
    """
    if _model is None:
        raise Exception("Model was not loaded. Call load_model() before inference.")

    frames = torch.from_numpy(frames_batch).to(_device)
    # Inference only: don't build an autograd graph for the forward pass.
    # Calling the module (rather than .forward) also runs any registered hooks.
    with torch.no_grad():
        embed = _model(frames)
    return embed.cpu().numpy()
|
|
|
|
|
|
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5, rate=None):
    """
    Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
    partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
    spectrogram slices are returned, so as to make each partial utterance waveform correspond to
    its spectrogram. This function assumes that the mel spectrogram parameters used are those
    defined in params_data.py.

    The returned ranges may be indexing further than the length of the waveform. It is
    recommended that you pad the waveform with zeros up to wave_slices[-1].stop.

    :param n_samples: the number of samples in the waveform
    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
    utterance
    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
    utterance, this parameter is ignored so that the function always returns at least 1 slice.
    :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
    utterances are entirely disjoint. Ignored when <rate> is given.
    :param rate: optional. If given, the number of partial utterances started per second of
    audio (at the sampling_rate from params_data.py). It then sets the step between
    consecutive partial utterances instead of <overlap>.
    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
    respectively the waveform and the mel spectrogram with these slices to obtain the partial
    utterances.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    # Waveform samples per mel frame (mel_window_step is presumably in milliseconds, given /1000).
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)
    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
    if rate is not None:
        # Step, in frames, so that a new partial starts every 1/rate seconds.
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
    else:
        frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    assert 0 < frame_step, "The rate is too high"
    # Check against the partial length actually in use, not the module default.
    assert frame_step <= partial_utterance_n_frames, "The rate is too low, it should be %f at least" % \
        (sampling_rate / (samples_per_frame * partial_utterance_n_frames))

    # Compute the slices
    wav_slices, mel_slices = [], []
    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for start in range(0, steps, frame_step):
        stop = start + partial_utterance_n_frames
        mel_slices.append(slice(start, stop))
        wav_slices.append(slice(start * samples_per_frame, stop * samples_per_frame))

    # Evaluate whether extra padding is warranted or not
    last_wav_range = wav_slices[-1]
    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        mel_slices = mel_slices[:-1]
        wav_slices = wav_slices[:-1]

    return wav_slices, mel_slices
|
|
|
|
|
|
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes an embedding for a single utterance.

    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
    :param using_partials: if True, then the utterance is split in partial utterances of
    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
    normalized average. If False, the utterance is instead computed from feeding the entire
    spectrogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_slices()
    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If <using_partials> is simultaneously set to False, both these values will be None
    instead.
    """
    # TODO: handle multiple wavs to benefit from batching on GPU

    # Process the entire utterance if not using partials
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials, then zero-pad the waveform so the
    # last slice (which may extend past the end of the audio) can be fully indexed.
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length > len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings: mean, then L2-normalize.
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed
|
|
|
|
|
|
def embed_speaker(wavs, **kwargs):
    """
    Placeholder for computing one speaker embedding from several utterances. Not implemented.

    :raises NotImplementedError: always.
    """
    # `NotImplemented` is a comparison sentinel and is not callable; calling it raised TypeError.
    raise NotImplementedError("embed_speaker() is not implemented")
|
|
|
|
|
|
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    """
    Draws an embedding vector as a 2D heatmap with a colorbar.

    :param embed: the embedding as an array-like; it is reshaped to <shape> before plotting.
    :param ax: the matplotlib axes to draw on. Defaults to the current axes (plt.gca()).
    :param title: the title set on the axes.
    :param shape: the 2D shape to reshape <embed> into. If None, uses
    (int(sqrt(len(embed))), -1), so len(embed) must be divisible by that height.
    :param color_range: (vmin, vmax) limits applied to the colormap of the image and colorbar.
    """
    if ax is None:
        ax = plt.gca()

    embed = np.asarray(embed)
    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    # plt.get_cmap() with no name returns the default colormap; matplotlib.cm.get_cmap was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    # Apply the requested limits to the image itself; previously they were set on an unused
    # ScalarMappable and had no effect on the plot.
    mappable.set_clim(*color_range)
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
|