From 19eaa682026a160286bb353cdeeb48dbb6d71e32 Mon Sep 17 00:00:00 2001 From: babysor00 Date: Sun, 13 Feb 2022 11:28:41 +0800 Subject: [PATCH] add preprocess and training --- .gitignore | 3 +- encoder/inference.py | 27 +++++++++++++--- ppg2mel/__init__.py | 3 +- ppg2mel/preprocess.py | 61 ++++++++++++++++++++++------------- ppg2mel/train/__init__.py | 1 + pre4ppg.py | 2 +- train.py | 67 +++++++++++++++++++++++++++++++++++++++ utils/audio_utils.py | 3 +- utils/data_load.py | 40 ++++++++++++----------- 9 files changed, 156 insertions(+), 51 deletions(-) create mode 100644 ppg2mel/train/__init__.py create mode 100644 train.py diff --git a/.gitignore b/.gitignore index 18d33d3..7df88c7 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ */saved_models !vocoder/saved_models/pretrained/** !encoder/saved_models/pretrained.pt -wavs \ No newline at end of file +wavs +log \ No newline at end of file diff --git a/encoder/inference.py b/encoder/inference.py index 4ca417b..af9a529 100644 --- a/encoder/inference.py +++ b/encoder/inference.py @@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None): _model.load_state_dict(checkpoint["model_state"]) _model.eval() print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) + return _model - +def set_model(model, device=None): + global _model, _device + _model = model + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _device = device + _model.to(device) + def is_loaded(): return _model is not None @@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch): def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, - min_pad_coverage=0.75, overlap=0.5): + min_pad_coverage=0.75, overlap=0.5, rate=None): """ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain partial utterances of each. Both the waveform and the mel @@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram assert 0 <= overlap < 1 assert 0 < min_pad_coverage <= 1 - samples_per_frame = int((sampling_rate * mel_window_step / 1000)) - n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) - frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + if rate != None: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + else: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ + (sampling_rate / (samples_per_frame * partials_n_frames)) # Compute the slices wav_slices, mel_slices = [], [] diff --git a/ppg2mel/__init__.py b/ppg2mel/__init__.py index 09ce34b..53ee3b2 100644 --- a/ppg2mel/__init__.py +++ b/ppg2mel/__init__.py @@ -148,8 +148,9 @@ class MelDecoderMOLv2(AbsMelDecoder): decoder_inputs = self.reduce_proj(decoder_inputs) # (B, num_mels, T_dec) + T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor') mel_outputs, predicted_stop, alignments = self.decoder( - decoder_inputs, speech, feature_lengths//int(self.encoder_down_factor)) + decoder_inputs, speech, T_dec) ## Post-processing mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2) mel_outputs_postnet = mel_outputs + mel_outputs_postnet diff --git a/ppg2mel/preprocess.py b/ppg2mel/preprocess.py index 1ca0016..06e6158 100644 --- a/ppg2mel/preprocess.py +++ b/ppg2mel/preprocess.py @@ -8,6 +8,7 @@ from pathlib import Path from ppg_extractor import load_model import encoder.inference as Encoder from encoder.audio import preprocess_wav +from encoder import audio from utils.f0_utils import compute_f0 from torch.multiprocessing import Pool, cpu_count @@ -37,17 +38,32 @@ def _compute_f0_from_wav(wav, output_fpath): f0 = compute_f0(wav, SAMPLE_RATE) np.save(output_fpath, f0, allow_pickle=False) -def _compute_spkEmbed(wav, output_fpath): - embed = Encoder.embed_utterance(wav) +def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): + Encoder.set_model(encoder_model_local) + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = Encoder.embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + np.save(output_fpath, embed, allow_pickle=False) -def preprocess_one(wav_path, out_dir, device, ppg_model_local): - wav = preprocess_wav(wav_path) +def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local): + wav = preprocess_wav(wav_path, 24000, False, False) utt_id = os.path.basename(wav_path).rstrip(".wav") _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) - _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", wav=wav) + _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model): # Glob wav files @@ -57,27 +73,28 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True) out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True) out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True) - ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") - Encoder.load_model(speaker_encoder_model, "cpu") + # ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") + ppg_model_local = None + encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu") if n_processes is None: n_processes = cpu_count() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, device=device) + func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device) job = Pool(n_processes).imap(func, wav_file_list) list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav")) - # t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") - # d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") - # e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") - # for file in wav_file_list: - # id = os.path.basename(file).rstrip(".wav") - # if id.endswith("1"): - # d_fid_file.write(id + "\n") - # elif id.endswith("9"): - # e_fid_file.write(id + "\n") - # else: - # t_fid_file.write(id + "\n") - # t_fid_file.close() - # d_fid_file.close() - # e_fid_file.close() + t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") + d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") + e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") + for file in wav_file_list: + id = os.path.basename(file).rstrip(".wav") + if id.endswith("1"): + d_fid_file.write(id + "\n") + elif id.endswith("9"): + e_fid_file.write(id + "\n") + else: + t_fid_file.write(id + "\n") + t_fid_file.close() + d_fid_file.close() + e_fid_file.close() diff --git a/ppg2mel/train/__init__.py b/ppg2mel/train/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/ppg2mel/train/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/pre4ppg.py b/pre4ppg.py index 87bacd2..fcfa0fa 100644 --- a/pre4ppg.py +++ b/pre4ppg.py @@ -34,7 +34,7 @@ if __name__ == "__main__": # "Preprocess audio without trimming silences (not recommended).") parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\ "Path your trained ppg encoder model.") - parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained.pt", help=\ + parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\ "Path your trained speaker encoder model.") args = parser.parse_args() diff --git a/train.py b/train.py new file mode 100644 index 0000000..fed7501 --- /dev/null +++ b/train.py @@ -0,0 +1,67 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ckpt/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + parser.add_argument('--warm_start', action='store_true', + help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + parser.add_argument('--no-pin', action='store_true', + help='Disable pin-memory for dataloader') + parser.add_argument('--test', action='store_true', help='Test the model.') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + parser.add_argument('--finetune', action='store_true', help='Finetune model') + parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') + parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') + parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/utils/audio_utils.py b/utils/audio_utils.py index 58c6129..1dbeddb 100644 --- a/utils/audio_utils.py +++ b/utils/audio_utils.py @@ -49,8 +49,7 @@ def mel_spectrogram( y = y.squeeze(1) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], - center=center, pad_mode='reflect', normalized=False, onesided=True) - + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) mel_spec = _spectral_normalize_torch(mel_spec) diff --git a/utils/data_load.py b/utils/data_load.py index 67785df..063fa66 100644 --- a/utils/data_load.py +++ b/utils/data_load.py @@ -3,7 +3,9 @@ import numpy as np import torch from utils.f0_utils import get_cont_lf0 import resampy -from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram, normalize +from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram +from librosa.util import normalize +import os def read_fids(fid_list_f): @@ -11,7 +13,6 @@ def read_fids(fid_list_f): fids = [l.strip().split()[0] for l in f if l.strip()] return fids - class OneshotVcDataset(torch.utils.data.Dataset): def __init__( self, @@ -61,17 +62,17 @@ class OneshotVcDataset(torch.utils.data.Dataset): return len(self.fid_list) def get_spk_dvec(self, fid): - spk_name = fid.split("_")[0] + spk_name = fid if spk_name.startswith("p"): - spk_dvec_path = f"{self.vctk_spk_dvec_dir}/{spk_name}.npy" + spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy" else: - spk_dvec_path = f"{self.libri_spk_dvec_dir}/{spk_name}.npy" + spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy" return torch.from_numpy(np.load(spk_dvec_path)) def compute_mel(self, wav_path): audio, sr = load_wav(wav_path) - if sr != 16000: - audio = resampy.resample(audio, sr, 16000) + if sr != 24000: + audio = resampy.resample(audio, sr, 24000) audio = audio / MAX_WAV_VALUE audio = normalize(audio) * 0.95 audio = torch.FloatTensor(audio).unsqueeze(0) @@ -79,11 +80,11 @@ class OneshotVcDataset(torch.utils.data.Dataset): audio, n_fft=1024, num_mels=80, - sampling_rate=16000, - hop_size=200, - win_size=800, + sampling_rate=24000, + hop_size=240, + win_size=1024, fmin=0, - fmax=7600, + fmax=8000, ) return melspec.squeeze(0).numpy().T @@ -98,14 +99,16 @@ class OneshotVcDataset(torch.utils.data.Dataset): # 1. Load features if fid.startswith("p"): # vctk - ppg = np.load(f"{self.vctk_ppg_dir}/{fid}.{self.ppg_file_ext}") - f0 = np.load(f"{self.vctk_f0_dir}/{fid}.{self.f0_file_ext}") - mel = self.compute_mel(f"{self.vctk_wav_dir}/{fid}.{self.wav_file_ext}") + sub = fid.split("_")[0] + ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") else: - # libritts - ppg = np.load(f"{self.libri_ppg_dir}/{fid}.{self.ppg_file_ext}") - f0 = np.load(f"{self.libri_f0_dir}/{fid}.{self.f0_file_ext}") - mel = self.compute_mel(f"{self.libri_wav_dir}/{fid}.{self.wav_file_ext}") + # aidatatang + sub = fid[5:10] + ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") if self.min_max_norm_mel: mel = self.bin_level_min_max_norm(mel) @@ -148,7 +151,6 @@ class OneshotVcDataset(torch.utils.data.Dataset): mel = mel[:min_len] return f0, ppg, mel - class MultiSpkVcCollate(): """Zero-pads model inputs and targets based on number of frames per step """