Fix issue for training and preprocessing

This commit is contained in:
babysor00 2023-02-10 20:34:01 +08:00
parent beec0b93ed
commit 3ce874ab46
8 changed files with 157 additions and 274 deletions

10
.vscode/launch.json vendored
View File

@ -64,6 +64,14 @@
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
]
}
},
{
"name": "Python: Vits Train",
"type": "python",
"request": "launch",
"program": "train.py",
"console": "integratedTerminal",
"args": ["--type", "vits"]
},
]
}

View File

@ -3,10 +3,10 @@ from utils.hparams import HParams
hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000,
n_fft = 800,
n_fft = 1024, # filter_length
num_mels = 80,
hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
fmin = 55,
min_level_db = -100,
ref_level_db = 20,
@ -67,7 +67,7 @@ hparams = HParams(
use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
# and [0, max_abs_value] if False
trim_silence = True, # Use with sample_rate of 16000 for best results
trim_silence = False, # Use with sample_rate of 16000 for best results
### SV2TTS
speaker_embedding_size = 256, # Dimension for the speaker embedding

View File

@ -2,12 +2,12 @@ import math
import torch
from torch import nn
from torch.nn import functional as F
from loguru import logger
from .sublayer.vits_modules import *
import monotonic_align
from .base import Base
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils.util import init_weights, get_padding, sequence_mask, rand_slice_segments, generate_path
@ -386,7 +386,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class Vits(Base):
class Vits(nn.Module):
"""
Synthesizer of Vits
"""
@ -408,13 +408,12 @@ class Vits(Base):
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
stop_threshold,
n_speakers=0,
gin_channels=0,
use_sdp=True,
**kwargs):
super().__init__(stop_threshold)
super().__init__()
self.n_vocab = n_vocab
self.spec_channels = spec_channels
self.inter_channels = inter_channels
@ -457,7 +456,7 @@ class Vits(Base):
self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None, emo=None):
# logger.info(f'====> Forward: 1.1.0')
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emo)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@ -466,7 +465,7 @@ class Vits(Base):
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
# logger.info(f'====> Forward: 1.1.1')
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
@ -475,10 +474,11 @@ class Vits(Base):
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
#logger.info(f'====> Forward: 1.1.1.1')
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
# logger.info(f'====> Forward: 1.1.2')
w = attn.sum(2)
if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g)
@ -487,7 +487,6 @@ class Vits(Base):
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
@ -497,7 +496,9 @@ class Vits(Base):
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, emo=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
# logger.info(f'====> Infer: 1.1.0')
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths,emo)
# logger.info(f'====> Infer: 1.1.1')
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
@ -514,11 +515,14 @@ class Vits(Base):
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(w_ceil, attn_mask)
# logger.info(f'====> Infer: 1.1.2')
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
# logger.info(f'====> Infer: 1.1.3')
return o, attn, y_mask, (z, z_p, m_p, logs_p)

View File

@ -20,8 +20,6 @@ device = 'cuda' if torch.cuda.is_available() else "cpu"
model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionExtractorModel.from_pretrained(model_name).to(device)
embs = []
wavnames = []
def extract_emo(
x: np.ndarray,
@ -48,8 +46,6 @@ class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter):
pinyin = Pinyin(PinyinConverter()).pinyin
def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
skip_existing: bool, hparams, emotion_extract: bool):
## FOR REFERENCE:
@ -67,9 +63,8 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
# Skip existing utterances if needed
mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % basename)
skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists())
if skip_existing and mel_fpath.exists() and wav_fpath.exists() and skip_emo_extract:
if skip_existing and mel_fpath.exists() and wav_fpath.exists():
return None
# Trim silence
@ -91,18 +86,14 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
np.save(wav_fpath, wav, allow_pickle=False)
if not skip_emo_extract:
emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
np.save(emo_fpath, emo, allow_pickle=False)
# Return a tuple describing this training example
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, wav, mel_frames, text
def _split_on_silences(wav_fpath, words, hparams):
# Load the audio waveform
wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0]
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=1024)[0]
if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max
# denoise, we may not need it here.
@ -132,6 +123,15 @@ def preprocess_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams,
continue
sub_basename = "%s_%02d" % (wav_fpath.name, 0)
wav, text = _split_on_silences(wav_fpath, words, hparams)
metadata.append(_process_utterance(wav, text, out_dir, sub_basename,
skip_existing, hparams, emotion_extract))
result = _process_utterance(wav, text, out_dir, sub_basename,
skip_existing, hparams, emotion_extract)
if result is None:
continue
wav_fpath_name, mel_fpath_name, embed_fpath_name, wav, mel_frames, text = result
emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % sub_basename)
skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists())
if not skip_emo_extract and wav is not None:
emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
np.save(emo_fpath, emo.squeeze(0), allow_pickle=False)
metadata.append([wav_fpath_name, mel_fpath_name, embed_fpath_name, len(wav), mel_frames, text])
return [m for m in metadata if m is not None]

View File

@ -39,7 +39,7 @@ def new_train():
parser.add_argument("--syn_dir", type=str, default="../audiodata/SV2TTS/synthesizer", help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs, the emos and the embeds.")
parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits", help=\
parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits2", help=\
"Path to the output directory that will contain the saved model weights and the logs.")
parser.add_argument('--ckptG', type=str, required=False,
help='original VITS G checkpoint path')
@ -65,7 +65,7 @@ def new_train():
run(0, 1, hparams)
def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False):
def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False, epochs=10000):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
@ -89,8 +89,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False):
try:
new_state_dict[k] = saved_state_dict[k]
except:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if k == 'step':
new_state_dict[k] = iteration * epochs
else:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict, strict=False)
else:
@ -173,13 +177,13 @@ def run(rank, n_gpus, hps):
print("加载原版VITS模型G记录点成功")
else:
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
optim_g)
optim_g, epochs=hps.train.epochs)
if ckptD is not None:
_, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True)
print("加载原版VITS模型D记录点成功")
else:
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
optim_d)
optim_d, epochs=hps.train.epochs)
global_step = (epoch_str - 1) * len(train_loader)
except:
epoch_str = 1
@ -216,17 +220,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train()
net_d.train()
for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(train_loader):
logger.info(f'====> Step: 1 {batch_idx}')
x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
speakers = speakers.cuda(rank, non_blocking=True)
emo = emo.cuda(rank, non_blocking=True)
# logger.info(f'====> Step: 1 {batch_idx}')
x, x_lengths = x.cuda(rank), x_lengths.cuda(rank)
spec, spec_lengths = spec.cuda(rank), spec_lengths.cuda(rank)
y, y_lengths = y.cuda(rank), y_lengths.cuda(rank)
speakers = speakers.cuda(rank)
emo = emo.cuda(rank)
# logger.info(f'====> Step: 1.0 {batch_idx}')
with autocast(enabled=hps.train.fp16_run):
y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers, emo)
# logger.info(f'====> Step: 1.1 {batch_idx}')
mel = spec_to_mel(
spec,
hps.data.filter_length,
@ -247,7 +251,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
# logger.info(f'====> Step: 1.3 {batch_idx}')
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
@ -258,7 +262,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
scaler.unscale_(optim_d)
grad_norm_d = clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d)
logger.info(f'====> Step: 2 {batch_idx}')
with autocast(enabled=hps.train.fp16_run):
# Generator
@ -277,7 +280,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
grad_norm_g = clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g)
scaler.update()
# logger.info(f'====> Step: 3 {batch_idx}')
if rank == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr']
@ -339,6 +341,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
emo = emo[:1]
break
y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000)
# y_hat, attn, mask, *_ = generator.infer(x, x_lengths, speakers, emo, max_len=1000) # for non DistributedDataParallel object
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel(

View File

@ -4,7 +4,7 @@ import numpy as np
import torch
import torch.utils.data
from utils.audio_utils import spectrogram, load_wav
from utils.audio_utils import spectrogram1, load_wav_to_torch, spectrogram
from utils.util import intersperse
from models.synthesizer.utils.text import text_to_sequence
@ -57,6 +57,8 @@ class VitsDataset(torch.utils.data.Dataset):
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
# TODO: for magic data only
speaker_name = wav_fpath.split("_")[1]
# # TODO: for ai data only
# speaker_name = wav_fpath.split("-")[1][6:9]
if speaker_name not in spk_to_sid:
sid += 1
spk_to_sid[speaker_name] = sid
@ -71,36 +73,45 @@ class VitsDataset(torch.utils.data.Dataset):
# separate filename, speaker_id and text
wav_fpath, text, sid = audio_metadata[0], audio_metadata[5], audio_metadata[6]
text = self.get_text(text)
spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}')
# TODO: add original audio data root for loading
file_name = wav_fpath.split("_00")[0].split('-')[1]
spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}..{os.sep}..{os.sep}magicdata{os.sep}train{os.sep}{"_".join(file_name.split("_")[:2])}{os.sep}{file_name}')
# spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}')
sid = self.get_sid(sid)
emo = torch.FloatTensor(np.load(f'{self.datasets_root}{os.sep}emo{os.sep}{wav_fpath.replace("audio", "emo")}'))
return (text, spec, wav, sid, emo)
def get_audio(self, filename):
# audio, sampling_rate = load_wav(filename)
# if sampling_rate != self.sampling_rate:
# raise ValueError("{} {} SR doesn't match target {} SR".format(
# sampling_rate, self.sampling_rate))
# audio = torch.load(filename)
audio = torch.FloatTensor(np.load(filename).astype(np.float32))
audio = audio.unsqueeze(0)
# audio_norm = audio / self.max_wav_value
# audio_norm = audio_norm.unsqueeze(0)
# spec_filename = filename.replace(".wav", ".spec.pt")
# if os.path.exists(spec_filename):
# spec = torch.load(spec_filename)
# else:
# spec = spectrogram(audio, self.filter_length,
# self.sampling_rate, self.hop_length, self.win_length,
# center=False)
# spec = torch.squeeze(spec, 0)
# torch.save(spec, spec_filename)
spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
audio, sampling_rate = load_wav_to_torch(filename)
if sampling_rate != self.sampling_rate:
raise ValueError("{} {} SR doesn't match target {} SR".format(
sampling_rate, self.sampling_rate))
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram(audio_norm, self.filter_length, self.hop_length, self.win_length,
center=False)
spec = torch.squeeze(spec, 0)
return spec, audio
return spec, audio_norm
# print("Loading", filename)
# # audio = torch.FloatTensor(np.load(filename).astype(np.float32))
# audio = audio.unsqueeze(0)
# audio_norm = audio / self.max_wav_value
# audio_norm = audio_norm.unsqueeze(0)
# # spec_filename = filename.replace(".wav", ".spec.pt")
# # if os.path.exists(spec_filename):
# # spec = torch.load(spec_filename)
# # else:
# # spec = spectrogram(audio, self.filter_length,self.hop_length, self.win_length,
# # center=False)
# # spec = torch.squeeze(spec, 0)
# # torch.save(spec, spec_filename)
# spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
# center=False)
# spec = torch.squeeze(spec, 0)
# return spec, audio
def get_text(self, text):
if self.cleaned_text:

View File

@ -17,8 +17,7 @@ def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def spectrogram(y, n_fft, hop_size, win_size, center=False):
def spectrogram1(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
@ -34,7 +33,29 @@ def spectrogram(y, n_fft, hop_size, win_size, center=False):
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True)
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spectrogram(y, n_fft, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec

237
vits.ipynb vendored

File diff suppressed because one or more lines are too long