diff --git a/ppg2mel/preprocess.py b/ppg2mel/preprocess.py index 06e6158..1009ffd 100644 --- a/ppg2mel/preprocess.py +++ b/ppg2mel/preprocess.py @@ -4,6 +4,8 @@ import torch import numpy as np from tqdm import tqdm from pathlib import Path +import soundfile +import resampy from ppg_extractor import load_model import encoder.inference as Encoder @@ -32,11 +34,13 @@ def _compute_bnf( bnf = ppg_model_local(wav_tensor, wav_length) bnf_npy = bnf.squeeze(0).cpu().numpy() np.save(output_fpath, bnf_npy, allow_pickle=False) + return bnf_npy, len(bnf_npy) def _compute_f0_from_wav(wav, output_fpath): """Compute merged f0 values.""" f0 = compute_f0(wav, SAMPLE_RATE) np.save(output_fpath, f0, allow_pickle=False) + return f0, len(f0) def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): Encoder.set_model(encoder_model_local) @@ -56,14 +60,22 @@ def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): embed = raw_embed / np.linalg.norm(raw_embed, 2) np.save(output_fpath, embed, allow_pickle=False) + return embed, len(embed) def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local): - wav = preprocess_wav(wav_path, 24000, False, False) + # wav = preprocess_wav(wav_path) + # try: + wav, sr = soundfile.read(wav_path) + if len(wav) < sr: + return None, sr, len(wav) + if sr != SAMPLE_RATE: + wav = resampy.resample(wav, sr, SAMPLE_RATE) + sr = SAMPLE_RATE utt_id = os.path.basename(wav_path).rstrip(".wav") - _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) - _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) - _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) + _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) + _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) + _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model): # Glob wav files @@ -73,8 +85,7 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True) out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True) out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True) - # ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") - ppg_model_local = None + ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu") if n_processes is None: n_processes = cpu_count() @@ -84,14 +95,15 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder job = Pool(n_processes).imap(func, wav_file_list) list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav")) + # finish processing and mark t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") - for file in wav_file_list: - id = os.path.basename(file).rstrip(".wav") - if id.endswith("1"): + for file in sorted(out_dir.joinpath("f0").glob("*.npy")): + id = os.path.basename(file).rstrip(".f0.npy") + if id.endswith("01"): d_fid_file.write(id + "\n") - elif id.endswith("9"): + elif id.endswith("09"): e_fid_file.write(id + "\n") else: t_fid_file.write(id + "\n") diff --git a/pre4ppg.py b/pre4ppg.py index fcfa0fa..408cb05 100644 --- a/pre4ppg.py +++ b/pre4ppg.py @@ -23,7 +23,7 @@ if __name__ == "__main__": parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ "Path to the output directory that will contain the mel spectrograms, the audios and the " "embeds. Defaults to /PPGVC/ppg2mel/") - parser.add_argument("-n", "--n_processes", type=int, default=8, help=\ + parser.add_argument("-n", "--n_processes", type=int, default=16, help=\ "Number of processes in parallel.") # parser.add_argument("-s", "--skip_existing", action="store_true", help=\ # "Whether to overwrite existing files with the same name. Useful if the preprocessing was " diff --git a/run.py b/run.py index 08019bf..b57ef4a 100644 --- a/run.py +++ b/run.py @@ -39,7 +39,7 @@ def convert(args): # Build models print("Load PPG-model, PPG2Mel-model, Vocoder-model...") ppg_model = load_model( - './ppg_extractor/saved_models/24epoch.pt', + Path('./ppg_extractor/saved_models/24epoch.pt'), device, ) ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device) diff --git a/utils/data_load.py b/utils/data_load.py index 063fa66..90212b5 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -8,6 +8,8 @@ from librosa.util import normalize import os +SAMPLE_RATE=16000 + def read_fids(fid_list_f): with open(fid_list_f, 'r') as f: fids = [l.strip().split()[0] for l in f if l.strip()] @@ -71,8 +73,8 @@ class OneshotVcDataset(torch.utils.data.Dataset): def compute_mel(self, wav_path): audio, sr = load_wav(wav_path) - if sr != 24000: - audio = resampy.resample(audio, sr, 24000) + if sr != SAMPLE_RATE: + audio = resampy.resample(audio, sr, SAMPLE_RATE) audio = audio / MAX_WAV_VALUE audio = normalize(audio) * 0.95 audio = torch.FloatTensor(audio).unsqueeze(0) @@ -80,9 +82,9 @@ class OneshotVcDataset(torch.utils.data.Dataset): audio, n_fft=1024, num_mels=80, - sampling_rate=24000, - hop_size=240, - win_size=1024, + sampling_rate=SAMPLE_RATE, + hop_size=200, + win_size=800, fmin=0, fmax=8000, ) @@ -112,7 +114,7 @@ class OneshotVcDataset(torch.utils.data.Dataset): if self.min_max_norm_mel: mel = self.bin_level_min_max_norm(mel) - f0, ppg, mel = self._adjust_lengths(f0, ppg, mel) + f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid) spk_dvec = self.get_spk_dvec(fid) # 2. Convert f0 to continuous log-f0 and u/v flags @@ -132,15 +134,15 @@ class OneshotVcDataset(torch.utils.data.Dataset): return (ppg, lf0_uv, mel, spk_dvec, fid) - def check_lengths(self, f0, ppg, mel): + def check_lengths(self, f0, ppg, mel, fid): LEN_THRESH = 10 assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \ - f"{abs(len(ppg) - len(f0))}" + f"{abs(len(ppg) - len(f0))}: for file {fid}" assert abs(len(mel) - len(f0)) <= LEN_THRESH, \ - f"{abs(len(mel) - len(f0))}" + f"{abs(len(mel) - len(f0))}: for file {fid}" - def _adjust_lengths(self, f0, ppg, mel): - self.check_lengths(f0, ppg, mel) + def _adjust_lengths(self, f0, ppg, mel, fid): + self.check_lengths(f0, ppg, mel, fid) min_len = min( len(f0), len(ppg),