add preprocess and training

This commit is contained in:
babysor00 2022-02-13 11:28:41 +08:00
parent 379fd2b9fd
commit 19eaa68202
9 changed files with 156 additions and 51 deletions

1
.gitignore vendored
View File

@ -19,3 +19,4 @@
!vocoder/saved_models/pretrained/**
!encoder/saved_models/pretrained.pt
wavs
log

View File

@ -34,7 +34,15 @@ def load_model(weights_fpath: Path, device=None):
_model.load_state_dict(checkpoint["model_state"])
_model.eval()
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
return _model
def set_model(model, device=None):
global _model, _device
_model = model
if device is None:
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_device = device
_model.to(device)
def is_loaded():
return _model is not None
@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
min_pad_coverage=0.75, overlap=0.5):
min_pad_coverage=0.75, overlap=0.5, rate=None):
"""
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
@ -85,10 +93,19 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
assert 0 <= overlap < 1
assert 0 < min_pad_coverage <= 1
if rate != None:
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
else:
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
assert 0 < frame_step, "The rate is too high"
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
(sampling_rate / (samples_per_frame * partials_n_frames))
# Compute the slices
wav_slices, mel_slices = [], []
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)

View File

@ -148,8 +148,9 @@ class MelDecoderMOLv2(AbsMelDecoder):
decoder_inputs = self.reduce_proj(decoder_inputs)
# (B, num_mels, T_dec)
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
mel_outputs, predicted_stop, alignments = self.decoder(
decoder_inputs, speech, feature_lengths//int(self.encoder_down_factor))
decoder_inputs, speech, T_dec)
## Post-processing
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet

View File

@ -8,6 +8,7 @@ from pathlib import Path
from ppg_extractor import load_model
import encoder.inference as Encoder
from encoder.audio import preprocess_wav
from encoder import audio
from utils.f0_utils import compute_f0
from torch.multiprocessing import Pool, cpu_count
@ -37,17 +38,32 @@ def _compute_f0_from_wav(wav, output_fpath):
f0 = compute_f0(wav, SAMPLE_RATE)
np.save(output_fpath, f0, allow_pickle=False)
def _compute_spkEmbed(wav, output_fpath):
embed = Encoder.embed_utterance(wav)
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
Encoder.set_model(encoder_model_local)
# Compute where to split the utterance into partials and pad if necessary
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
max_wave_length = wave_slices[-1].stop
if max_wave_length >= len(wav):
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
# Split the utterance into partials
frames = audio.wav_to_mel_spectrogram(wav)
frames_batch = np.array([frames[s] for s in mel_slices])
partial_embeds = Encoder.embed_frames_batch(frames_batch)
# Compute the utterance embedding from the partial embeddings
raw_embed = np.mean(partial_embeds, axis=0)
embed = raw_embed / np.linalg.norm(raw_embed, 2)
np.save(output_fpath, embed, allow_pickle=False)
def preprocess_one(wav_path, out_dir, device, ppg_model_local):
wav = preprocess_wav(wav_path)
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
wav = preprocess_wav(wav_path, 24000, False, False)
utt_id = os.path.basename(wav_path).rstrip(".wav")
_compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
_compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
_compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", wav=wav)
_compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
# Glob wav files
@ -57,27 +73,28 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
Encoder.load_model(speaker_encoder_model, "cpu")
# ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
ppg_model_local = None
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
if n_processes is None:
n_processes = cpu_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, device=device)
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
job = Pool(n_processes).imap(func, wav_file_list)
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
# t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
# d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
# e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
# for file in wav_file_list:
# id = os.path.basename(file).rstrip(".wav")
# if id.endswith("1"):
# d_fid_file.write(id + "\n")
# elif id.endswith("9"):
# e_fid_file.write(id + "\n")
# else:
# t_fid_file.write(id + "\n")
# t_fid_file.close()
# d_fid_file.close()
# e_fid_file.close()
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
for file in wav_file_list:
id = os.path.basename(file).rstrip(".wav")
if id.endswith("1"):
d_fid_file.write(id + "\n")
elif id.endswith("9"):
e_fid_file.write(id + "\n")
else:
t_fid_file.write(id + "\n")
t_fid_file.close()
d_fid_file.close()
e_fid_file.close()

View File

@ -0,0 +1 @@
#

View File

@ -34,7 +34,7 @@ if __name__ == "__main__":
# "Preprocess audio without trimming silences (not recommended).")
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
"Path your trained ppg encoder model.")
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained.pt", help=\
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
"Path your trained speaker encoder model.")
args = parser.parse_args()

67
train.py Normal file
View File

@ -0,0 +1,67 @@
import sys
import torch
import argparse
import numpy as np
from utils.load_yaml import HpsYaml
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
# For reproducibility, comment these may speed up training
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# Arguments
parser = argparse.ArgumentParser(description=
'Training PPG2Mel VC model.')
parser.add_argument('--config', type=str,
help='Path to experiment config, e.g., config/vc.yaml')
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
parser.add_argument('--logdir', default='log/', type=str,
help='Logging path.', required=False)
parser.add_argument('--ckpdir', default='ckpt/', type=str,
help='Checkpoint path.', required=False)
parser.add_argument('--outdir', default='result/', type=str,
help='Decode output path.', required=False)
parser.add_argument('--load', default=None, type=str,
help='Load pre-trained model (for training only)', required=False)
parser.add_argument('--warm_start', action='store_true',
help='Load model weights only, ignore specified layers.')
parser.add_argument('--seed', default=0, type=int,
help='Random seed for reproducable results.', required=False)
parser.add_argument('--njobs', default=8, type=int,
help='Number of threads for dataloader/decoding.', required=False)
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
parser.add_argument('--no-pin', action='store_true',
help='Disable pin-memory for dataloader')
parser.add_argument('--test', action='store_true', help='Test the model.')
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
parser.add_argument('--finetune', action='store_true', help='Finetune model')
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
###
paras = parser.parse_args()
setattr(paras, 'gpu', not paras.cpu)
setattr(paras, 'pin_memory', not paras.no_pin)
setattr(paras, 'verbose', not paras.no_msg)
# Make the config dict dot visitable
config = HpsYaml(paras.config)
np.random.seed(paras.seed)
torch.manual_seed(paras.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(paras.seed)
print(">>> OneShot VC training ...")
mode = "train"
solver = Solver(config, paras, mode)
solver.load_data()
solver.set_model()
solver.exec()
print(">>> Oneshot VC train finished!")
sys.exit(0)
if __name__ == "__main__":
main()

View File

@ -49,8 +49,7 @@ def mel_spectrogram(
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True)
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
mel_spec = _spectral_normalize_torch(mel_spec)

View File

@ -3,7 +3,9 @@ import numpy as np
import torch
from utils.f0_utils import get_cont_lf0
import resampy
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram, normalize
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
from librosa.util import normalize
import os
def read_fids(fid_list_f):
@ -11,7 +13,6 @@ def read_fids(fid_list_f):
fids = [l.strip().split()[0] for l in f if l.strip()]
return fids
class OneshotVcDataset(torch.utils.data.Dataset):
def __init__(
self,
@ -61,17 +62,17 @@ class OneshotVcDataset(torch.utils.data.Dataset):
return len(self.fid_list)
def get_spk_dvec(self, fid):
spk_name = fid.split("_")[0]
spk_name = fid
if spk_name.startswith("p"):
spk_dvec_path = f"{self.vctk_spk_dvec_dir}/{spk_name}.npy"
spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
else:
spk_dvec_path = f"{self.libri_spk_dvec_dir}/{spk_name}.npy"
spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
return torch.from_numpy(np.load(spk_dvec_path))
def compute_mel(self, wav_path):
audio, sr = load_wav(wav_path)
if sr != 16000:
audio = resampy.resample(audio, sr, 16000)
if sr != 24000:
audio = resampy.resample(audio, sr, 24000)
audio = audio / MAX_WAV_VALUE
audio = normalize(audio) * 0.95
audio = torch.FloatTensor(audio).unsqueeze(0)
@ -79,11 +80,11 @@ class OneshotVcDataset(torch.utils.data.Dataset):
audio,
n_fft=1024,
num_mels=80,
sampling_rate=16000,
hop_size=200,
win_size=800,
sampling_rate=24000,
hop_size=240,
win_size=1024,
fmin=0,
fmax=7600,
fmax=8000,
)
return melspec.squeeze(0).numpy().T
@ -98,14 +99,16 @@ class OneshotVcDataset(torch.utils.data.Dataset):
# 1. Load features
if fid.startswith("p"):
# vctk
ppg = np.load(f"{self.vctk_ppg_dir}/{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.vctk_f0_dir}/{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.vctk_wav_dir}/{fid}.{self.wav_file_ext}")
sub = fid.split("_")[0]
ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
else:
# libritts
ppg = np.load(f"{self.libri_ppg_dir}/{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.libri_f0_dir}/{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.libri_wav_dir}/{fid}.{self.wav_file_ext}")
# aidatatang
sub = fid[5:10]
ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
if self.min_max_norm_mel:
mel = self.bin_level_min_max_norm(mel)
@ -148,7 +151,6 @@ class OneshotVcDataset(torch.utils.data.Dataset):
mel = mel[:min_len]
return f0, ppg, mel
class MultiSpkVcCollate():
"""Zero-pads model inputs and targets based on number of frames per step
"""