mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
114 lines
4.5 KiB
Python
114 lines
4.5 KiB
Python
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from pathlib import Path
|
|
import soundfile
|
|
import resampy
|
|
|
|
from ppg_extractor import load_model
|
|
import encoder.inference as Encoder
|
|
from encoder.audio import preprocess_wav
|
|
from encoder import audio
|
|
from utils.f0_utils import compute_f0
|
|
|
|
from torch.multiprocessing import Pool, cpu_count
|
|
from functools import partial
|
|
|
|
SAMPLE_RATE=16000
|
|
|
|
def _compute_bnf(
|
|
wav: any,
|
|
output_fpath: str,
|
|
device: torch.device,
|
|
ppg_model_local: any,
|
|
):
|
|
"""
|
|
Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
|
|
"""
|
|
ppg_model_local.to(device)
|
|
wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
|
|
wav_length = torch.LongTensor([wav.shape[0]]).to(device)
|
|
with torch.no_grad():
|
|
bnf = ppg_model_local(wav_tensor, wav_length)
|
|
bnf_npy = bnf.squeeze(0).cpu().numpy()
|
|
np.save(output_fpath, bnf_npy, allow_pickle=False)
|
|
return bnf_npy, len(bnf_npy)
|
|
|
|
def _compute_f0_from_wav(wav, output_fpath):
|
|
"""Compute merged f0 values."""
|
|
f0 = compute_f0(wav, SAMPLE_RATE)
|
|
np.save(output_fpath, f0, allow_pickle=False)
|
|
return f0, len(f0)
|
|
|
|
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
|
|
Encoder.set_model(encoder_model_local)
|
|
# Compute where to split the utterance into partials and pad if necessary
|
|
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
|
|
max_wave_length = wave_slices[-1].stop
|
|
if max_wave_length >= len(wav):
|
|
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
|
|
|
# Split the utterance into partials
|
|
frames = audio.wav_to_mel_spectrogram(wav)
|
|
frames_batch = np.array([frames[s] for s in mel_slices])
|
|
partial_embeds = Encoder.embed_frames_batch(frames_batch)
|
|
|
|
# Compute the utterance embedding from the partial embeddings
|
|
raw_embed = np.mean(partial_embeds, axis=0)
|
|
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
|
|
|
np.save(output_fpath, embed, allow_pickle=False)
|
|
return embed, len(embed)
|
|
|
|
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
|
|
# wav = preprocess_wav(wav_path)
|
|
# try:
|
|
wav, sr = soundfile.read(wav_path)
|
|
if len(wav) < sr:
|
|
return None, sr, len(wav)
|
|
if sr != SAMPLE_RATE:
|
|
wav = resampy.resample(wav, sr, SAMPLE_RATE)
|
|
sr = SAMPLE_RATE
|
|
utt_id = os.path.basename(wav_path).rstrip(".wav")
|
|
|
|
_, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
|
|
_, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
|
|
_, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
|
|
|
|
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
|
|
# Glob wav files
|
|
wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
|
|
print(f"Globbed {len(wav_file_list)} wav files.")
|
|
|
|
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
|
|
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
|
|
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
|
|
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
|
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
|
|
if n_processes is None:
|
|
n_processes = cpu_count()
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
|
|
job = Pool(n_processes).imap(func, wav_file_list)
|
|
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
|
|
|
|
# finish processing and mark
|
|
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
|
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
|
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
|
for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
|
|
id = os.path.basename(file).split(".f0.npy")[0]
|
|
if id.endswith("01"):
|
|
d_fid_file.write(id + "\n")
|
|
elif id.endswith("09"):
|
|
e_fid_file.write(id + "\n")
|
|
else:
|
|
t_fid_file.write(id + "\n")
|
|
t_fid_file.close()
|
|
d_fid_file.close()
|
|
e_fid_file.close()
|
|
return len(wav_file_list)
|