add preprocess and training

babysor00 2022-02-13 11:28:41 +08:00
parent 379fd2b9fd
commit 19eaa68202
9 changed files with 156 additions and 51 deletions

.gitignore (vendored)
View File

@@ -19,3 +19,4 @@
 !vocoder/saved_models/pretrained/**
 !encoder/saved_models/pretrained.pt
 wavs
+log

View File

@@ -34,7 +34,15 @@ def load_model(weights_fpath: Path, device=None):
     _model.load_state_dict(checkpoint["model_state"])
     _model.eval()
     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+    return _model
+
+def set_model(model, device=None):
+    global _model, _device
+    _model = model
+    if device is None:
+        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    _device = device
+    _model.to(device)
 
 def is_loaded():
     return _model is not None
@@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
 def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
-                           min_pad_coverage=0.75, overlap=0.5):
+                           min_pad_coverage=0.75, overlap=0.5, rate=None):
     """
     Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
     partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
@@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
     assert 0 <= overlap < 1
     assert 0 < min_pad_coverage <= 1
 
-    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
-    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
-    frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+    if rate != None:
+        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+    else:
+        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+        frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+    assert 0 < frame_step, "The rate is too high"
+    assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
+        (sampling_rate / (samples_per_frame * partials_n_frames))
 
     # Compute the slices
     wav_slices, mel_slices = [], []
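Note: a small worked example of what the new rate argument does, assuming the encoder defaults used elsewhere in this repo (16 kHz sampling rate, 10 ms mel window step, 160-frame partials); the numbers are illustrative only.

    # With rate given, one partial utterance starts every (sampling_rate / rate) samples
    # instead of being derived from `overlap`.
    sampling_rate, mel_window_step, partials_n_frames = 16000, 10, 160  # assumed defaults
    rate = 1.3  # partial utterances per second, as used by the preprocessing below
    samples_per_frame = int(sampling_rate * mel_window_step / 1000)      # 160 samples per mel frame
    frame_step = int(round((sampling_rate / rate) / samples_per_frame))  # ~77 frames between partial starts
    # The new asserts then require 1 <= frame_step <= partials_n_frames (here 77 <= 160, so rate=1.3 is valid).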

View File

@@ -148,8 +148,9 @@ class MelDecoderMOLv2(AbsMelDecoder):
         decoder_inputs = self.reduce_proj(decoder_inputs)
 
         # (B, num_mels, T_dec)
+        T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
         mel_outputs, predicted_stop, alignments = self.decoder(
-            decoder_inputs, speech, feature_lengths//int(self.encoder_down_factor))
+            decoder_inputs, speech, T_dec)
 
         ## Post-processing
         mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
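The switch from // to torch.div(..., rounding_mode='floor') keeps the same floor division while avoiding the integer-tensor floor-division deprecation warning on recent PyTorch. A minimal sketch with made-up lengths:

    import torch

    feature_lengths = torch.tensor([400, 320])   # illustrative values
    encoder_down_factor = 4                      # illustrative value
    # Old form (warns on PyTorch >= 1.8 for integer tensors):
    #   T_dec = feature_lengths // encoder_down_factor
    # New form, identical result, no deprecation warning:
    T_dec = torch.div(feature_lengths, encoder_down_factor, rounding_mode='floor')
    print(T_dec)  # tensor([100, 80])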

View File

@@ -8,6 +8,7 @@ from pathlib import Path
 from ppg_extractor import load_model
 import encoder.inference as Encoder
 from encoder.audio import preprocess_wav
+from encoder import audio
 from utils.f0_utils import compute_f0
 from torch.multiprocessing import Pool, cpu_count
@@ -37,17 +38,32 @@ def _compute_f0_from_wav(wav, output_fpath):
     f0 = compute_f0(wav, SAMPLE_RATE)
     np.save(output_fpath, f0, allow_pickle=False)
 
-def _compute_spkEmbed(wav, output_fpath):
-    embed = Encoder.embed_utterance(wav)
+def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
+    Encoder.set_model(encoder_model_local)
+    # Compute where to split the utterance into partials and pad if necessary
+    wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
+    max_wave_length = wave_slices[-1].stop
+    if max_wave_length >= len(wav):
+        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+    # Split the utterance into partials
+    frames = audio.wav_to_mel_spectrogram(wav)
+    frames_batch = np.array([frames[s] for s in mel_slices])
+    partial_embeds = Encoder.embed_frames_batch(frames_batch)
+
+    # Compute the utterance embedding from the partial embeddings
+    raw_embed = np.mean(partial_embeds, axis=0)
+    embed = raw_embed / np.linalg.norm(raw_embed, 2)
     np.save(output_fpath, embed, allow_pickle=False)
 
-def preprocess_one(wav_path, out_dir, device, ppg_model_local):
-    wav = preprocess_wav(wav_path)
+def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
+    wav = preprocess_wav(wav_path, 24000, False, False)
     utt_id = os.path.basename(wav_path).rstrip(".wav")
     _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
     _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
-    _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", wav=wav)
+    _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
 
 def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
     # Glob wav files
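The positional call preprocess_wav(wav_path, 24000, False, False) above is easier to read with keywords. Assuming the encoder.audio signature preprocess_wav(fpath_or_wav, source_sr=None, normalize=True, trim_silence=True) (an assumption; check the local module), it would be equivalent to:

    # Assumed-equivalent keyword form of the positional call above
    wav = preprocess_wav(wav_path, source_sr=24000, normalize=False, trim_silence=False)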
@@ -57,27 +73,28 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
     out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
     out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
     out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
-    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
-    Encoder.load_model(speaker_encoder_model, "cpu")
+    # ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
+    ppg_model_local = None
+    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
     if n_processes is None:
         n_processes = cpu_count()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, device=device)
+    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
     job = Pool(n_processes).imap(func, wav_file_list)
     list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
 
-    # t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
-    # d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
-    # e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
-    # for file in wav_file_list:
-    #     id = os.path.basename(file).rstrip(".wav")
-    #     if id.endswith("1"):
-    #         d_fid_file.write(id + "\n")
-    #     elif id.endswith("9"):
-    #         e_fid_file.write(id + "\n")
-    #     else:
-    #         t_fid_file.write(id + "\n")
-    # t_fid_file.close()
-    # d_fid_file.close()
-    # e_fid_file.close()
+    t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
+    d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
+    e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
+    for file in wav_file_list:
+        id = os.path.basename(file).rstrip(".wav")
+        if id.endswith("1"):
+            d_fid_file.write(id + "\n")
+        elif id.endswith("9"):
+            e_fid_file.write(id + "\n")
+        else:
+            t_fid_file.write(id + "\n")
+    t_fid_file.close()
+    d_fid_file.close()
+    e_fid_file.close()
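The now-active block writes the file-id lists consumed by training; utterance ids are routed by their last digit, giving roughly an 8:1:1 train/dev/eval split. A small sketch of the same rule:

    def split_of(utt_id: str) -> str:
        # Mirrors the id-suffix rule above: "...1" -> dev, "...9" -> eval, everything else -> train
        if utt_id.endswith("1"):
            return "dev"
        elif utt_id.endswith("9"):
            return "eval"
        return "train"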

View File

@@ -0,0 +1 @@
+#

View File

@@ -34,7 +34,7 @@ if __name__ == "__main__":
     #     "Preprocess audio without trimming silences (not recommended).")
     parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
         "Path your trained ppg encoder model.")
-    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained.pt", help=\
+    parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
         "Path your trained speaker encoder model.")
     args = parser.parse_args()

train.py (new file)
View File

@@ -0,0 +1,67 @@
+import sys
+import torch
+import argparse
+import numpy as np
+from utils.load_yaml import HpsYaml
+from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+
+# For reproducibility; commenting these out may speed up training
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(description=
+        'Training PPG2Mel VC model.')
+    parser.add_argument('--config', type=str,
+                        help='Path to experiment config, e.g., config/vc.yaml')
+    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
+    parser.add_argument('--logdir', default='log/', type=str,
+                        help='Logging path.', required=False)
+    parser.add_argument('--ckpdir', default='ckpt/', type=str,
+                        help='Checkpoint path.', required=False)
+    parser.add_argument('--outdir', default='result/', type=str,
+                        help='Decode output path.', required=False)
+    parser.add_argument('--load', default=None, type=str,
+                        help='Load pre-trained model (for training only)', required=False)
+    parser.add_argument('--warm_start', action='store_true',
+                        help='Load model weights only, ignore specified layers.')
+    parser.add_argument('--seed', default=0, type=int,
+                        help='Random seed for reproducible results.', required=False)
+    parser.add_argument('--njobs', default=8, type=int,
+                        help='Number of threads for dataloader/decoding.', required=False)
+    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
+    parser.add_argument('--no-pin', action='store_true',
+                        help='Disable pin-memory for dataloader')
+    parser.add_argument('--test', action='store_true', help='Test the model.')
+    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
+    parser.add_argument('--finetune', action='store_true', help='Finetune model')
+    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
+    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
+    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+    ###
+
+    paras = parser.parse_args()
+    setattr(paras, 'gpu', not paras.cpu)
+    setattr(paras, 'pin_memory', not paras.no_pin)
+    setattr(paras, 'verbose', not paras.no_msg)
+    # Make the config dict dot visitable
+    config = HpsYaml(paras.config)
+
+    np.random.seed(paras.seed)
+    torch.manual_seed(paras.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(paras.seed)
+
+    print(">>> OneShot VC training ...")
+    mode = "train"
+    solver = Solver(config, paras, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+    sys.exit(0)
+
+
+if __name__ == "__main__":
+    main()
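HpsYaml is what makes the config "dot visitable": the loaded YAML can be read both as dict items and as attributes. An illustrative sketch (the key and value below are made up, not taken from an actual config file):

    from utils.load_yaml import HpsYaml

    config = HpsYaml("config/vc.yaml")   # e.g. a YAML containing data_dir: ./preprocessed
    print(config["data_dir"])            # dict-style access
    print(config.data_dir)               # attribute-style access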

View File

@@ -49,8 +49,7 @@ def mel_spectrogram(
     y = y.squeeze(1)
 
     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True)
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
     spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
 
     mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
     mel_spec = _spectral_normalize_torch(mel_spec)
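Passing return_complex=False keeps the old output layout (a trailing real/imag dimension), which the following spec.pow(2).sum(-1) magnitude computation relies on; recent PyTorch versions require the flag to be given explicitly. For reference, a hedged sketch of the equivalent complex-output variant:

    # Same magnitudes via a complex STFT (the rest of the function would stay unchanged)
    spec_c = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                        window=hann_window[str(y.device)], center=center,
                        pad_mode='reflect', normalized=False, onesided=True,
                        return_complex=True)
    spec = torch.sqrt(spec_c.abs().pow(2) + 1e-9)  # |re + j*im|^2 reproduces pow(2).sum(-1)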

View File

@@ -3,7 +3,9 @@ import numpy as np
 import torch
 from utils.f0_utils import get_cont_lf0
 import resampy
-from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram, normalize
+from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
+from librosa.util import normalize
+import os
 
 
 def read_fids(fid_list_f):
@@ -11,7 +13,6 @@ def read_fids(fid_list_f):
         fids = [l.strip().split()[0] for l in f if l.strip()]
     return fids
 
-
 class OneshotVcDataset(torch.utils.data.Dataset):
     def __init__(
         self,
@@ -61,17 +62,17 @@ class OneshotVcDataset(torch.utils.data.Dataset):
         return len(self.fid_list)
 
     def get_spk_dvec(self, fid):
-        spk_name = fid.split("_")[0]
+        spk_name = fid
         if spk_name.startswith("p"):
-            spk_dvec_path = f"{self.vctk_spk_dvec_dir}/{spk_name}.npy"
+            spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
         else:
-            spk_dvec_path = f"{self.libri_spk_dvec_dir}/{spk_name}.npy"
+            spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
         return torch.from_numpy(np.load(spk_dvec_path))
 
     def compute_mel(self, wav_path):
         audio, sr = load_wav(wav_path)
-        if sr != 16000:
-            audio = resampy.resample(audio, sr, 16000)
+        if sr != 24000:
+            audio = resampy.resample(audio, sr, 24000)
         audio = audio / MAX_WAV_VALUE
         audio = normalize(audio) * 0.95
         audio = torch.FloatTensor(audio).unsqueeze(0)
@@ -79,11 +80,11 @@ class OneshotVcDataset(torch.utils.data.Dataset):
             audio,
             n_fft=1024,
             num_mels=80,
-            sampling_rate=16000,
-            hop_size=200,
-            win_size=800,
+            sampling_rate=24000,
+            hop_size=240,
+            win_size=1024,
             fmin=0,
-            fmax=7600,
+            fmax=8000,
         )
         return melspec.squeeze(0).numpy().T
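These new mel settings move the front-end from 16 kHz to 24 kHz audio; a quick arithmetic check of what actually changes:

    # Old vs. new hop in milliseconds (mel frame rate goes from 80 to 100 frames/s):
    old_hop_ms = 200 / 16000 * 1000   # 12.5 ms per mel frame at 16 kHz
    new_hop_ms = 240 / 24000 * 1000   # 10.0 ms per mel frame at 24 kHz
    new_win_ms = 1024 / 24000 * 1000  # ~42.7 ms analysis window; fmax 8000 Hz stays below the 12 kHz Nyquist limit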
@@ -98,14 +99,16 @@ class OneshotVcDataset(torch.utils.data.Dataset):
         # 1. Load features
         if fid.startswith("p"):
             # vctk
-            ppg = np.load(f"{self.vctk_ppg_dir}/{fid}.{self.ppg_file_ext}")
-            f0 = np.load(f"{self.vctk_f0_dir}/{fid}.{self.f0_file_ext}")
-            mel = self.compute_mel(f"{self.vctk_wav_dir}/{fid}.{self.wav_file_ext}")
+            sub = fid.split("_")[0]
+            ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
+            f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
+            mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
         else:
-            # libritts
-            ppg = np.load(f"{self.libri_ppg_dir}/{fid}.{self.ppg_file_ext}")
-            f0 = np.load(f"{self.libri_f0_dir}/{fid}.{self.f0_file_ext}")
-            mel = self.compute_mel(f"{self.libri_wav_dir}/{fid}.{self.wav_file_ext}")
+            # aidatatang
+            sub = fid[5:10]
+            ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
+            f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
+            mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
 
         if self.min_max_norm_mel:
             mel = self.bin_level_min_max_norm(mel)
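How the speaker sub-directory is derived from a file id above, assuming VCTK ids like p225_001 and aidatatang ids like T0055G0013S0001 (the example ids are assumptions for illustration, not taken from this commit):

    fid_vctk = "p225_001"                  # hypothetical VCTK utterance id
    sub_vctk = fid_vctk.split("_")[0]      # "p225"  -> {vctk_wav_dir}/p225/p225_001.wav
    fid_adt = "T0055G0013S0001"            # hypothetical aidatatang utterance id
    sub_adt = fid_adt[5:10]                # "G0013" -> {libri_wav_dir}/G0013/T0055G0013S0001.wav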
@@ -148,7 +151,6 @@ class OneshotVcDataset(torch.utils.data.Dataset):
             mel = mel[:min_len]
         return f0, ppg, mel
 
-
 class MultiSpkVcCollate():
     """Zero-pads model inputs and targets based on number of frames per step
     """