"""Preprocess a wav dataset into BNF, F0 and speaker-embedding features.

For every ``*.wav`` file found under the dataset directory this module writes
three ``.npy`` files (bottle-neck features, f0 and a speaker embedding) and
finally generates train/dev/eval file-id lists from the produced f0 files.
"""
import os
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
import resampy
import soundfile
import torch
from torch.multiprocessing import Pool, cpu_count
from tqdm import tqdm

# NOTE(review): the speaker encoder is imported from the top-level ``encoder``
# package while the other encoder helpers come from ``models.encoder`` --
# confirm both import roots resolve in the deployed layout.
import encoder.inference as Encoder
from models.encoder import audio
from models.encoder.audio import preprocess_wav
from models.ppg_extractor import load_model
from utils.f0_utils import compute_f0

# Sample rate (Hz) every waveform is resampled to before feature extraction.
SAMPLE_RATE = 16000

_WAV_SUFFIX = ".wav"


def _compute_bnf(
    wav: np.ndarray,
    output_fpath: str,
    device: torch.device,
    ppg_model_local: Any,
):
    """Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).

    Args:
        wav: Waveform as a numpy array; its first axis is used as the length.
            Assumes a 1-D mono signal -- TODO confirm against callers.
        output_fpath: Destination ``.npy`` path for the features.
        device: Device the model and the input tensors are moved to.
        ppg_model_local: Callable model taking ``(wav_tensor, wav_length)``.

    Returns:
        Tuple ``(bnf, n)`` where ``bnf`` is the saved numpy array and ``n`` is
        ``len(bnf)``.
    """
    ppg_model_local.to(device)
    # Add a leading batch axis of size 1 for the model call.
    wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
    wav_length = torch.LongTensor([wav.shape[0]]).to(device)
    with torch.no_grad():
        bnf = ppg_model_local(wav_tensor, wav_length)
    bnf_npy = bnf.squeeze(0).cpu().numpy()
    np.save(output_fpath, bnf_npy, allow_pickle=False)
    return bnf_npy, len(bnf_npy)


def _compute_f0_from_wav(wav, output_fpath):
    """Compute merged f0 values and save them to ``output_fpath``.

    Returns:
        Tuple ``(f0, len(f0))``.
    """
    f0 = compute_f0(wav, SAMPLE_RATE)
    np.save(output_fpath, f0, allow_pickle=False)
    return f0, len(f0)


def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
    """Compute an utterance-level speaker embedding and save it.

    The utterance is split into partial windows, each window is embedded, and
    the mean of the partial embeddings is divided by its L2 norm.

    Args:
        wav: Waveform as a numpy array.
        output_fpath: Destination ``.npy`` path for the embedding.
        encoder_model_local: Speaker encoder handed to ``Encoder.set_model``.
        device: Accepted for signature symmetry; not used by this function.

    Returns:
        Tuple ``(embed, len(embed))``.
    """
    Encoder.set_model(encoder_model_local)

    # Compute where to split the utterance into partials and pad if necessary.
    wave_slices, mel_slices = Encoder.compute_partial_slices(
        len(wav), rate=1.3, min_pad_coverage=0.75
    )
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials.
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = Encoder.embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings.
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    np.save(output_fpath, embed, allow_pickle=False)
    return embed, len(embed)


def _utt_id_from_path(wav_path):
    """Return the file name of ``wav_path`` without a trailing ``.wav``.

    ``str.rstrip(".wav")`` must not be used here: it strips any trailing run of
    the characters ``.``, ``w``, ``a`` and ``v`` (``"nova.wav"`` -> ``"no"``).
    """
    name = os.path.basename(wav_path)
    if name.endswith(_WAV_SUFFIX):
        return name[: -len(_WAV_SUFFIX)]
    return name


def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
    """Extract and save BNF, f0 and speaker-embedding features for one file.

    Args:
        wav_path: Path of the input wav file.
        out_dir: Output root; features go to ``bnf/``, ``f0/`` and ``embed/``
            beneath it.
        device: Device forwarded to the feature extractors.
        ppg_model_local: Model used for the bottle-neck features.
        encoder_model_local: Speaker encoder model.

    Returns:
        ``(None, sr, n_samples)`` when the file holds fewer than ``sr`` samples
        (less than one second) and is therefore skipped; ``None`` once all
        three feature files have been written.
    """
    wav, sr = soundfile.read(wav_path)
    if len(wav) < sr:
        return None, sr, len(wav)
    if sr != SAMPLE_RATE:
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE
    utt_id = _utt_id_from_path(wav_path)

    _compute_bnf(
        output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy",
        wav=wav,
        device=device,
        ppg_model_local=ppg_model_local,
    )
    _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
    _compute_spkEmbed(
        output_fpath=f"{out_dir}/embed/{utt_id}.npy",
        device=device,
        encoder_model_local=encoder_model_local,
        wav=wav,
    )
    return None


def _split_for_fid(fid):
    """Map a file id to its split: ids ending in ``01`` -> dev, ``09`` -> eval."""
    if fid.endswith("01"):
        return "dev"
    if fid.endswith("09"):
        return "eval"
    return "train"


def _write_fid_lists(out_dir):
    """Write ``{train,dev,eval}_fidlist.txt`` from the f0 files in ``out_dir``."""
    fids_by_split = {"train": [], "dev": [], "eval": []}
    for f0_path in sorted(out_dir.joinpath("f0").glob("*.npy")):
        fid = os.path.basename(f0_path).split(".f0.npy")[0]
        fids_by_split[_split_for_fid(fid)].append(fid)

    # Every list file is (re)written, even when its split is empty.
    for split, fids in fids_by_split.items():
        out_dir.joinpath(f"{split}_fidlist.txt").write_text(
            "".join(f"{fid}\n" for fid in fids), encoding="utf-8"
        )


def preprocess_dataset(datasets_root, dataset, out_dir, n_processes,
                       ppg_encoder_model_fpath, speaker_encoder_model):
    """Preprocess every wav file of a dataset and write the file-id lists.

    Args:
        datasets_root: Directory that contains the dataset folder.
        dataset: Name of the dataset folder, searched recursively for
            ``*.wav`` files.
        out_dir: Output root directory (``Path`` or string).
        n_processes: Number of worker processes; ``None`` uses ``cpu_count()``.
        ppg_encoder_model_fpath: Checkpoint passed to the PPG ``load_model``.
        speaker_encoder_model: Checkpoint passed to ``Encoder.load_model``.

    Returns:
        Number of wav files found (including ones skipped as too short).
    """
    out_dir = Path(out_dir)

    # Glob wav files.
    wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
    print(f"Globbed {len(wav_file_list)} wav files.")

    for sub_dir in ("bnf", "f0", "embed"):
        out_dir.joinpath(sub_dir).mkdir(exist_ok=True, parents=True)

    # Models are loaded on CPU here; _compute_bnf moves the PPG model to
    # ``device`` inside each worker.
    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
    if n_processes is None:
        n_processes = cpu_count()

    # NOTE(review): CUDA combined with fork-started worker processes is a known
    # source of runtime errors -- confirm the multiprocessing start method.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    func = partial(
        preprocess_one,
        out_dir=out_dir,
        ppg_model_local=ppg_model_local,
        encoder_model_local=encoder_model_local,
        device=device,
    )

    # Drain the iterator inside the ``with`` block: leaving it terminates the
    # pool, so all results must have been consumed by then.
    with Pool(n_processes) as pool:
        job = pool.imap(func, wav_file_list)
        list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

    # Finish processing and mark the train/dev/eval membership.
    _write_fid_lists(out_dir)
    return len(wav_file_list)