Fix issue for training and preprocessing

This commit is contained in:
babysor00 2023-02-10 20:34:01 +08:00
parent beec0b93ed
commit 3ce874ab46
8 changed files with 157 additions and 274 deletions

10
.vscode/launch.json vendored
View File

@ -64,6 +64,14 @@
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml", "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\" "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
] ]
} },
{
"name": "Python: Vits Train",
"type": "python",
"request": "launch",
"program": "train.py",
"console": "integratedTerminal",
"args": ["--type", "vits"]
},
] ]
} }

View File

@ -3,10 +3,10 @@ from utils.hparams import HParams
hparams = HParams( hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder) ### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000, sample_rate = 16000,
n_fft = 800, n_fft = 1024, # filter_length
num_mels = 80, num_mels = 80,
hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125) hop_size = 256, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050) win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
fmin = 55, fmin = 55,
min_level_db = -100, min_level_db = -100,
ref_level_db = 20, ref_level_db = 20,
@ -67,7 +67,7 @@ hparams = HParams(
use_lws = False, # "Fast spectrogram phase recovery using local weighted sums" use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True, symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
# and [0, max_abs_value] if False # and [0, max_abs_value] if False
trim_silence = True, # Use with sample_rate of 16000 for best results trim_silence = False, # Use with sample_rate of 16000 for best results
### SV2TTS ### SV2TTS
speaker_embedding_size = 256, # Dimension for the speaker embedding speaker_embedding_size = 256, # Dimension for the speaker embedding

View File

@ -2,12 +2,12 @@ import math
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from loguru import logger
from .sublayer.vits_modules import * from .sublayer.vits_modules import *
import monotonic_align import monotonic_align
from .base import Base from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils.util import init_weights, get_padding, sequence_mask, rand_slice_segments, generate_path from utils.util import init_weights, get_padding, sequence_mask, rand_slice_segments, generate_path
@ -386,7 +386,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
return y_d_rs, y_d_gs, fmap_rs, fmap_gs return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class Vits(Base): class Vits(nn.Module):
""" """
Synthesizer of Vits Synthesizer of Vits
""" """
@ -408,13 +408,12 @@ class Vits(Base):
upsample_rates, upsample_rates,
upsample_initial_channel, upsample_initial_channel,
upsample_kernel_sizes, upsample_kernel_sizes,
stop_threshold,
n_speakers=0, n_speakers=0,
gin_channels=0, gin_channels=0,
use_sdp=True, use_sdp=True,
**kwargs): **kwargs):
super().__init__(stop_threshold) super().__init__()
self.n_vocab = n_vocab self.n_vocab = n_vocab
self.spec_channels = spec_channels self.spec_channels = spec_channels
self.inter_channels = inter_channels self.inter_channels = inter_channels
@ -457,7 +456,7 @@ class Vits(Base):
self.emb_g = nn.Embedding(n_speakers, gin_channels) self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None, emo=None): def forward(self, x, x_lengths, y, y_lengths, sid=None, emo=None):
# logger.info(f'====> Forward: 1.1.0')
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emo) x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emo)
if self.n_speakers > 0: if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@ -466,7 +465,7 @@ class Vits(Base):
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g) z_p = self.flow(z, y_mask, g=g)
# logger.info(f'====> Forward: 1.1.1')
with torch.no_grad(): with torch.no_grad():
# negative cross-entropy # negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
@ -475,10 +474,11 @@ class Vits(Base):
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
#logger.info(f'====> Forward: 1.1.1.1')
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
# logger.info(f'====> Forward: 1.1.2')
w = attn.sum(2) w = attn.sum(2)
if self.use_sdp: if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g) l_length = self.dp(x, x_mask, w, g=g)
@ -487,7 +487,6 @@ class Vits(Base):
logw_ = torch.log(w + 1e-6) * x_mask logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g) logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior # expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
@ -497,7 +496,9 @@ class Vits(Base):
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, emo=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): def infer(self, x, x_lengths, sid=None, emo=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
# logger.info(f'====> Infer: 1.1.0')
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths,emo) x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths,emo)
# logger.info(f'====> Infer: 1.1.1')
if self.n_speakers > 0: if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else: else:
@ -514,11 +515,14 @@ class Vits(Base):
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(w_ceil, attn_mask) attn = generate_path(w_ceil, attn_mask)
# logger.info(f'====> Infer: 1.1.2')
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True) z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g) o = self.dec((z * y_mask)[:,:,:max_len], g=g)
# logger.info(f'====> Infer: 1.1.3')
return o, attn, y_mask, (z, z_p, m_p, logs_p) return o, attn, y_mask, (z, z_p, m_p, logs_p)

View File

@ -20,8 +20,6 @@ device = 'cuda' if torch.cuda.is_available() else "cpu"
model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim' model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
processor = Wav2Vec2Processor.from_pretrained(model_name) processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionExtractorModel.from_pretrained(model_name).to(device) model = EmotionExtractorModel.from_pretrained(model_name).to(device)
embs = []
wavnames = []
def extract_emo( def extract_emo(
x: np.ndarray, x: np.ndarray,
@ -48,8 +46,6 @@ class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter):
pinyin = Pinyin(PinyinConverter()).pinyin pinyin = Pinyin(PinyinConverter()).pinyin
def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
skip_existing: bool, hparams, emotion_extract: bool): skip_existing: bool, hparams, emotion_extract: bool):
## FOR REFERENCE: ## FOR REFERENCE:
@ -67,9 +63,8 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
# Skip existing utterances if needed # Skip existing utterances if needed
mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename) mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename) wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % basename)
skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists()) if skip_existing and mel_fpath.exists() and wav_fpath.exists():
if skip_existing and mel_fpath.exists() and wav_fpath.exists() and skip_emo_extract:
return None return None
# Trim silence # Trim silence
@ -91,18 +86,14 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False) np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
np.save(wav_fpath, wav, allow_pickle=False) np.save(wav_fpath, wav, allow_pickle=False)
if not skip_emo_extract:
emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
np.save(emo_fpath, emo, allow_pickle=False)
# Return a tuple describing this training example # Return a tuple describing this training example
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, wav, mel_frames, text
def _split_on_silences(wav_fpath, words, hparams): def _split_on_silences(wav_fpath, words, hparams):
# Load the audio waveform # Load the audio waveform
wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate) wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0] wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=1024)[0]
if hparams.rescale: if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max wav = wav / np.abs(wav).max() * hparams.rescaling_max
# denoise, we may not need it here. # denoise, we may not need it here.
@ -132,6 +123,15 @@ def preprocess_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams,
continue continue
sub_basename = "%s_%02d" % (wav_fpath.name, 0) sub_basename = "%s_%02d" % (wav_fpath.name, 0)
wav, text = _split_on_silences(wav_fpath, words, hparams) wav, text = _split_on_silences(wav_fpath, words, hparams)
metadata.append(_process_utterance(wav, text, out_dir, sub_basename, result = _process_utterance(wav, text, out_dir, sub_basename,
skip_existing, hparams, emotion_extract)) skip_existing, hparams, emotion_extract)
if result is None:
continue
wav_fpath_name, mel_fpath_name, embed_fpath_name, wav, mel_frames, text = result
emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % sub_basename)
skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists())
if not skip_emo_extract and wav is not None:
emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
np.save(emo_fpath, emo.squeeze(0), allow_pickle=False)
metadata.append([wav_fpath_name, mel_fpath_name, embed_fpath_name, len(wav), mel_frames, text])
return [m for m in metadata if m is not None] return [m for m in metadata if m is not None]

View File

@ -39,7 +39,7 @@ def new_train():
parser.add_argument("--syn_dir", type=str, default="../audiodata/SV2TTS/synthesizer", help= \ parser.add_argument("--syn_dir", type=str, default="../audiodata/SV2TTS/synthesizer", help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, " "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs, the emos and the embeds.") "the wavs, the emos and the embeds.")
parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits", help=\ parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits2", help=\
"Path to the output directory that will contain the saved model weights and the logs.") "Path to the output directory that will contain the saved model weights and the logs.")
parser.add_argument('--ckptG', type=str, required=False, parser.add_argument('--ckptG', type=str, required=False,
help='original VITS G checkpoint path') help='original VITS G checkpoint path')
@ -65,7 +65,7 @@ def new_train():
run(0, 1, hparams) run(0, 1, hparams)
def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False): def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False, epochs=10000):
assert os.path.isfile(checkpoint_path) assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration'] iteration = checkpoint_dict['iteration']
@ -89,8 +89,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False):
try: try:
new_state_dict[k] = saved_state_dict[k] new_state_dict[k] = saved_state_dict[k]
except: except:
logger.info("%s is not in the checkpoint" % k) if k == 'step':
new_state_dict[k] = v new_state_dict[k] = iteration * epochs
else:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'): if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict, strict=False) model.module.load_state_dict(new_state_dict, strict=False)
else: else:
@ -173,13 +177,13 @@ def run(rank, n_gpus, hps):
print("加载原版VITS模型G记录点成功") print("加载原版VITS模型G记录点成功")
else: else:
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
optim_g) optim_g, epochs=hps.train.epochs)
if ckptD is not None: if ckptD is not None:
_, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True) _, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True)
print("加载原版VITS模型D记录点成功") print("加载原版VITS模型D记录点成功")
else: else:
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
optim_d) optim_d, epochs=hps.train.epochs)
global_step = (epoch_str - 1) * len(train_loader) global_step = (epoch_str - 1) * len(train_loader)
except: except:
epoch_str = 1 epoch_str = 1
@ -216,17 +220,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train() net_g.train()
net_d.train() net_d.train()
for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(train_loader): for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(train_loader):
logger.info(f'====> Step: 1 {batch_idx}') # logger.info(f'====> Step: 1 {batch_idx}')
x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) x, x_lengths = x.cuda(rank), x_lengths.cuda(rank)
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) spec, spec_lengths = spec.cuda(rank), spec_lengths.cuda(rank)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) y, y_lengths = y.cuda(rank), y_lengths.cuda(rank)
speakers = speakers.cuda(rank, non_blocking=True) speakers = speakers.cuda(rank)
emo = emo.cuda(rank, non_blocking=True) emo = emo.cuda(rank)
# logger.info(f'====> Step: 1.0 {batch_idx}')
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
y_hat, l_length, attn, ids_slice, x_mask, z_mask, \ y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers, emo) (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers, emo)
# logger.info(f'====> Step: 1.1 {batch_idx}')
mel = spec_to_mel( mel = spec_to_mel(
spec, spec,
hps.data.filter_length, hps.data.filter_length,
@ -247,7 +251,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
) )
y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
# logger.info(f'====> Step: 1.3 {batch_idx}')
# Discriminator # Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False): with autocast(enabled=False):
@ -258,7 +262,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
scaler.unscale_(optim_d) scaler.unscale_(optim_d)
grad_norm_d = clip_grad_value_(net_d.parameters(), None) grad_norm_d = clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d) scaler.step(optim_d)
logger.info(f'====> Step: 2 {batch_idx}')
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
# Generator # Generator
@ -277,7 +280,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
grad_norm_g = clip_grad_value_(net_g.parameters(), None) grad_norm_g = clip_grad_value_(net_g.parameters(), None)
scaler.step(optim_g) scaler.step(optim_g)
scaler.update() scaler.update()
# logger.info(f'====> Step: 3 {batch_idx}')
if rank == 0: if rank == 0:
if global_step % hps.train.log_interval == 0: if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]['lr'] lr = optim_g.param_groups[0]['lr']
@ -339,6 +341,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
emo = emo[:1] emo = emo[:1]
break break
y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000) y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000)
# y_hat, attn, mask, *_ = generator.infer(x, x_lengths, speakers, emo, max_len=1000) # for non DistributedDataParallel object
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
mel = spec_to_mel( mel = spec_to_mel(

View File

@ -4,7 +4,7 @@ import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
from utils.audio_utils import spectrogram, load_wav from utils.audio_utils import spectrogram1, load_wav_to_torch, spectrogram
from utils.util import intersperse from utils.util import intersperse
from models.synthesizer.utils.text import text_to_sequence from models.synthesizer.utils.text import text_to_sequence
@ -57,6 +57,8 @@ class VitsDataset(torch.utils.data.Dataset):
if self.min_text_len <= len(text) and len(text) <= self.max_text_len: if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
# TODO: for magic data only # TODO: for magic data only
speaker_name = wav_fpath.split("_")[1] speaker_name = wav_fpath.split("_")[1]
# # TODO: for ai data only
# speaker_name = wav_fpath.split("-")[1][6:9]
if speaker_name not in spk_to_sid: if speaker_name not in spk_to_sid:
sid += 1 sid += 1
spk_to_sid[speaker_name] = sid spk_to_sid[speaker_name] = sid
@ -72,35 +74,44 @@ class VitsDataset(torch.utils.data.Dataset):
wav_fpath, text, sid = audio_metadata[0], audio_metadata[5], audio_metadata[6] wav_fpath, text, sid = audio_metadata[0], audio_metadata[5], audio_metadata[6]
text = self.get_text(text) text = self.get_text(text)
spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}') # TODO: add original audio data root for loading
file_name = wav_fpath.split("_00")[0].split('-')[1]
spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}..{os.sep}..{os.sep}magicdata{os.sep}train{os.sep}{"_".join(file_name.split("_")[:2])}{os.sep}{file_name}')
# spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}')
sid = self.get_sid(sid) sid = self.get_sid(sid)
emo = torch.FloatTensor(np.load(f'{self.datasets_root}{os.sep}emo{os.sep}{wav_fpath.replace("audio", "emo")}')) emo = torch.FloatTensor(np.load(f'{self.datasets_root}{os.sep}emo{os.sep}{wav_fpath.replace("audio", "emo")}'))
return (text, spec, wav, sid, emo) return (text, spec, wav, sid, emo)
def get_audio(self, filename): def get_audio(self, filename):
# audio, sampling_rate = load_wav(filename) audio, sampling_rate = load_wav_to_torch(filename)
if sampling_rate != self.sampling_rate:
# if sampling_rate != self.sampling_rate: raise ValueError("{} {} SR doesn't match target {} SR".format(
# raise ValueError("{} {} SR doesn't match target {} SR".format( sampling_rate, self.sampling_rate))
# sampling_rate, self.sampling_rate)) audio_norm = audio / self.max_wav_value
# audio = torch.load(filename) audio_norm = audio_norm.unsqueeze(0)
audio = torch.FloatTensor(np.load(filename).astype(np.float32)) spec = spectrogram(audio_norm, self.filter_length, self.hop_length, self.win_length,
audio = audio.unsqueeze(0)
# audio_norm = audio / self.max_wav_value
# audio_norm = audio_norm.unsqueeze(0)
# spec_filename = filename.replace(".wav", ".spec.pt")
# if os.path.exists(spec_filename):
# spec = torch.load(spec_filename)
# else:
# spec = spectrogram(audio, self.filter_length,
# self.sampling_rate, self.hop_length, self.win_length,
# center=False)
# spec = torch.squeeze(spec, 0)
# torch.save(spec, spec_filename)
spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
center=False) center=False)
spec = torch.squeeze(spec, 0) spec = torch.squeeze(spec, 0)
return spec, audio return spec, audio_norm
# print("Loading", filename)
# # audio = torch.FloatTensor(np.load(filename).astype(np.float32))
# audio = audio.unsqueeze(0)
# audio_norm = audio / self.max_wav_value
# audio_norm = audio_norm.unsqueeze(0)
# # spec_filename = filename.replace(".wav", ".spec.pt")
# # if os.path.exists(spec_filename):
# # spec = torch.load(spec_filename)
# # else:
# # spec = spectrogram(audio, self.filter_length,self.hop_length, self.win_length,
# # center=False)
# # spec = torch.squeeze(spec, 0)
# # torch.save(spec, spec_filename)
# spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
# center=False)
# spec = torch.squeeze(spec, 0)
# return spec, audio
def get_text(self, text): def get_text(self, text):
if self.cleaned_text: if self.cleaned_text:

View File

@ -17,6 +17,27 @@ def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path) sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def spectrogram1(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def spectrogram(y, n_fft, hop_size, win_size, center=False): def spectrogram(y, n_fft, hop_size, win_size, center=False):
if torch.min(y) < -1.: if torch.min(y) < -1.:
@ -34,7 +55,7 @@ def spectrogram(y, n_fft, hop_size, win_size, center=False):
y = y.squeeze(1) y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True) center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec return spec

237
vits.ipynb vendored

File diff suppressed because one or more lines are too long