mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
add preprocess and training
This commit is contained in:
parent
379fd2b9fd
commit
19eaa68202
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -18,4 +18,5 @@
|
|||
*/saved_models
|
||||
!vocoder/saved_models/pretrained/**
|
||||
!encoder/saved_models/pretrained.pt
|
||||
wavs
|
||||
wavs
|
||||
log
|
|
@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None):
|
|||
_model.load_state_dict(checkpoint["model_state"])
|
||||
_model.eval()
|
||||
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
||||
return _model
|
||||
|
||||
|
||||
def set_model(model, device=None):
|
||||
global _model, _device
|
||||
_model = model
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_device = device
|
||||
_model.to(device)
|
||||
|
||||
def is_loaded():
|
||||
return _model is not None
|
||||
|
||||
|
@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
|
|||
|
||||
|
||||
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
||||
min_pad_coverage=0.75, overlap=0.5):
|
||||
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||
|
@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
|
|||
assert 0 <= overlap < 1
|
||||
assert 0 < min_pad_coverage <= 1
|
||||
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
if rate != None:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||
else:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
|
||||
assert 0 < frame_step, "The rate is too high"
|
||||
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
||||
(sampling_rate / (samples_per_frame * partials_n_frames))
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
|
|
|
@ -148,8 +148,9 @@ class MelDecoderMOLv2(AbsMelDecoder):
|
|||
decoder_inputs = self.reduce_proj(decoder_inputs)
|
||||
|
||||
# (B, num_mels, T_dec)
|
||||
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
|
||||
mel_outputs, predicted_stop, alignments = self.decoder(
|
||||
decoder_inputs, speech, feature_lengths//int(self.encoder_down_factor))
|
||||
decoder_inputs, speech, T_dec)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
from ppg_extractor import load_model
|
||||
import encoder.inference as Encoder
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import audio
|
||||
from utils.f0_utils import compute_f0
|
||||
|
||||
from torch.multiprocessing import Pool, cpu_count
|
||||
|
@ -37,17 +38,32 @@ def _compute_f0_from_wav(wav, output_fpath):
|
|||
f0 = compute_f0(wav, SAMPLE_RATE)
|
||||
np.save(output_fpath, f0, allow_pickle=False)
|
||||
|
||||
def _compute_spkEmbed(wav, output_fpath):
|
||||
embed = Encoder.embed_utterance(wav)
|
||||
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
|
||||
Encoder.set_model(encoder_model_local)
|
||||
# Compute where to split the utterance into partials and pad if necessary
|
||||
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
|
||||
max_wave_length = wave_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials
|
||||
frames = audio.wav_to_mel_spectrogram(wav)
|
||||
frames_batch = np.array([frames[s] for s in mel_slices])
|
||||
partial_embeds = Encoder.embed_frames_batch(frames_batch)
|
||||
|
||||
# Compute the utterance embedding from the partial embeddings
|
||||
raw_embed = np.mean(partial_embeds, axis=0)
|
||||
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
||||
|
||||
np.save(output_fpath, embed, allow_pickle=False)
|
||||
|
||||
def preprocess_one(wav_path, out_dir, device, ppg_model_local):
|
||||
wav = preprocess_wav(wav_path)
|
||||
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
|
||||
wav = preprocess_wav(wav_path, 24000, False, False)
|
||||
utt_id = os.path.basename(wav_path).rstrip(".wav")
|
||||
|
||||
_compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
|
||||
_compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
|
||||
_compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", wav=wav)
|
||||
_compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
|
||||
|
||||
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
|
||||
# Glob wav files
|
||||
|
@ -57,27 +73,28 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
|
|||
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
|
||||
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
||||
Encoder.load_model(speaker_encoder_model, "cpu")
|
||||
# ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
||||
ppg_model_local = None
|
||||
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
|
||||
if n_processes is None:
|
||||
n_processes = cpu_count()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, device=device)
|
||||
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
|
||||
job = Pool(n_processes).imap(func, wav_file_list)
|
||||
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
|
||||
|
||||
# t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
||||
# d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
||||
# e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
||||
# for file in wav_file_list:
|
||||
# id = os.path.basename(file).rstrip(".wav")
|
||||
# if id.endswith("1"):
|
||||
# d_fid_file.write(id + "\n")
|
||||
# elif id.endswith("9"):
|
||||
# e_fid_file.write(id + "\n")
|
||||
# else:
|
||||
# t_fid_file.write(id + "\n")
|
||||
# t_fid_file.close()
|
||||
# d_fid_file.close()
|
||||
# e_fid_file.close()
|
||||
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
||||
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
||||
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
||||
for file in wav_file_list:
|
||||
id = os.path.basename(file).rstrip(".wav")
|
||||
if id.endswith("1"):
|
||||
d_fid_file.write(id + "\n")
|
||||
elif id.endswith("9"):
|
||||
e_fid_file.write(id + "\n")
|
||||
else:
|
||||
t_fid_file.write(id + "\n")
|
||||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
||||
|
|
1
ppg2mel/train/__init__.py
Normal file
1
ppg2mel/train/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
#
|
|
@ -34,7 +34,7 @@ if __name__ == "__main__":
|
|||
# "Preprocess audio without trimming silences (not recommended).")
|
||||
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
|
||||
"Path your trained ppg encoder model.")
|
||||
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained.pt", help=\
|
||||
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
|
||||
"Path your trained speaker encoder model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
67
train.py
Normal file
67
train.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ckpt/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -49,8 +49,7 @@ def mel_spectrogram(
|
|||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
||||
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
||||
mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
|
||||
mel_spec = _spectral_normalize_torch(mel_spec)
|
||||
|
|
|
@ -3,7 +3,9 @@ import numpy as np
|
|||
import torch
|
||||
from utils.f0_utils import get_cont_lf0
|
||||
import resampy
|
||||
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram, normalize
|
||||
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
|
||||
from librosa.util import normalize
|
||||
import os
|
||||
|
||||
|
||||
def read_fids(fid_list_f):
|
||||
|
@ -11,7 +13,6 @@ def read_fids(fid_list_f):
|
|||
fids = [l.strip().split()[0] for l in f if l.strip()]
|
||||
return fids
|
||||
|
||||
|
||||
class OneshotVcDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -61,17 +62,17 @@ class OneshotVcDataset(torch.utils.data.Dataset):
|
|||
return len(self.fid_list)
|
||||
|
||||
def get_spk_dvec(self, fid):
|
||||
spk_name = fid.split("_")[0]
|
||||
spk_name = fid
|
||||
if spk_name.startswith("p"):
|
||||
spk_dvec_path = f"{self.vctk_spk_dvec_dir}/{spk_name}.npy"
|
||||
spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
|
||||
else:
|
||||
spk_dvec_path = f"{self.libri_spk_dvec_dir}/{spk_name}.npy"
|
||||
spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
|
||||
return torch.from_numpy(np.load(spk_dvec_path))
|
||||
|
||||
def compute_mel(self, wav_path):
|
||||
audio, sr = load_wav(wav_path)
|
||||
if sr != 16000:
|
||||
audio = resampy.resample(audio, sr, 16000)
|
||||
if sr != 24000:
|
||||
audio = resampy.resample(audio, sr, 24000)
|
||||
audio = audio / MAX_WAV_VALUE
|
||||
audio = normalize(audio) * 0.95
|
||||
audio = torch.FloatTensor(audio).unsqueeze(0)
|
||||
|
@ -79,11 +80,11 @@ class OneshotVcDataset(torch.utils.data.Dataset):
|
|||
audio,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
sampling_rate=16000,
|
||||
hop_size=200,
|
||||
win_size=800,
|
||||
sampling_rate=24000,
|
||||
hop_size=240,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=7600,
|
||||
fmax=8000,
|
||||
)
|
||||
return melspec.squeeze(0).numpy().T
|
||||
|
||||
|
@ -98,14 +99,16 @@ class OneshotVcDataset(torch.utils.data.Dataset):
|
|||
# 1. Load features
|
||||
if fid.startswith("p"):
|
||||
# vctk
|
||||
ppg = np.load(f"{self.vctk_ppg_dir}/{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.vctk_f0_dir}/{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.vctk_wav_dir}/{fid}.{self.wav_file_ext}")
|
||||
sub = fid.split("_")[0]
|
||||
ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
|
||||
else:
|
||||
# libritts
|
||||
ppg = np.load(f"{self.libri_ppg_dir}/{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.libri_f0_dir}/{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.libri_wav_dir}/{fid}.{self.wav_file_ext}")
|
||||
# aidatatang
|
||||
sub = fid[5:10]
|
||||
ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
|
||||
if self.min_max_norm_mel:
|
||||
mel = self.bin_level_min_max_norm(mel)
|
||||
|
||||
|
@ -148,7 +151,6 @@ class OneshotVcDataset(torch.utils.data.Dataset):
|
|||
mel = mel[:min_len]
|
||||
return f0, ppg, mel
|
||||
|
||||
|
||||
class MultiSpkVcCollate():
|
||||
"""Zero-pads model inputs and targets based on number of frames per step
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user