Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)
Fix issues in training and preprocessing
This commit is contained in:
parent beec0b93ed
commit 3ce874ab46
10  .vscode/launch.json (vendored)
@@ -64,6 +64,14 @@
             "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
                 "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
             ]
-        }
+        },
+        {
+            "name": "Python: Vits Train",
+            "type": "python",
+            "request": "launch",
+            "program": "train.py",
+            "console": "integratedTerminal",
+            "args": ["--type", "vits"]
+        },
     ]
 }
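For reference, the added debug configuration simply launches the VITS training entry point with the type flag; an equivalent shell invocation from the repository root would be (a sketch, not part of the commit):

    python train.py --type vits
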
@@ -3,10 +3,10 @@ from utils.hparams import HParams
 hparams = HParams(
     ### Signal Processing (used in both synthesizer and vocoder)
     sample_rate = 16000,
-    n_fft = 800,
+    n_fft = 1024,  # filter_length
     num_mels = 80,
-    hop_size = 200,  # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
-    win_size = 800,  # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
+    hop_size = 256,  # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
+    win_size = 1024, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
     fmin = 55,
     min_level_db = -100,
     ref_level_db = 20,
@@ -67,7 +67,7 @@ hparams = HParams(
     use_lws = False,        # "Fast spectrogram phase recovery using local weighted sums"
     symmetric_mels = True,  # Sets mel range to [-max_abs_value, max_abs_value] if True,
                             # and [0, max_abs_value] if False
-    trim_silence = True,    # Use with sample_rate of 16000 for best results
+    trim_silence = False,   # Use with sample_rate of 16000 for best results

     ### SV2TTS
     speaker_embedding_size = 256, # Dimension for the speaker embedding
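With sample_rate fixed at 16000, the updated STFT settings correspond to a 16 ms frame shift and a 64 ms analysis window (the inline 12.5 ms / 50 ms comments describe the previous 200/800 values). A quick check, as a small Python sketch:

    sample_rate = 16000
    hop_size, win_size, n_fft = 256, 1024, 1024

    print(hop_size / sample_rate * 1000)  # 16.0 -> frame shift in milliseconds
    print(win_size / sample_rate * 1000)  # 64.0 -> frame length in milliseconds
    print(n_fft // 2 + 1)                 # 513  -> frequency bins per STFT frame
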
@@ -2,12 +2,12 @@ import math
 import torch
 from torch import nn
 from torch.nn import functional as F
+from loguru import logger

 from .sublayer.vits_modules import *
 import monotonic_align

-from .base import Base
-from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn import Conv1d, ConvTranspose1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 from utils.util import init_weights, get_padding, sequence_mask, rand_slice_segments, generate_path
@@ -386,7 +386,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):

         return y_d_rs, y_d_gs, fmap_rs, fmap_gs

-class Vits(Base):
+class Vits(nn.Module):
     """
     Synthesizer of Vits
     """
@@ -408,13 +408,12 @@ class Vits(Base):
                  upsample_rates,
                  upsample_initial_channel,
                  upsample_kernel_sizes,
-                 stop_threshold,
                  n_speakers=0,
                  gin_channels=0,
                  use_sdp=True,
                  **kwargs):

-        super().__init__(stop_threshold)
+        super().__init__()
         self.n_vocab = n_vocab
         self.spec_channels = spec_channels
         self.inter_channels = inter_channels
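Because the constructor no longer declares stop_threshold, a call site that still passes it by keyword has the value silently absorbed by **kwargs, while a positional stop_threshold would now bind to n_speakers. A runnable stand-in illustrating the signature change (this is not the real Vits constructor, just the tail of its parameter list):

    def new_init(upsample_kernel_sizes, n_speakers=0, gin_channels=0, use_sdp=True, **kwargs):
        # stand-in for the new Vits.__init__ tail; returns what it received
        return n_speakers, kwargs

    print(new_init([16, 16, 4, 4], stop_threshold=-3.4))  # (0, {'stop_threshold': -3.4}) -> ignored
    print(new_init([16, 16, 4, 4], -3.4))                 # (-3.4, {}) -> shifts into n_speakers
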
@@ -457,7 +456,7 @@ class Vits(Base):
         self.emb_g = nn.Embedding(n_speakers, gin_channels)

     def forward(self, x, x_lengths, y, y_lengths, sid=None, emo=None):
+        # logger.info(f'====> Forward: 1.1.0')
         x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emo)
         if self.n_speakers > 0:
             g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@@ -466,7 +465,7 @@ class Vits(Base):

         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
         z_p = self.flow(z, y_mask, g=g)
+        # logger.info(f'====> Forward: 1.1.1')
         with torch.no_grad():
             # negative cross-entropy
             s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
@@ -475,10 +474,11 @@ class Vits(Base):
             neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
             neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
             neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+            # logger.info(f'====> Forward: 1.1.1.1')
             attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
             attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

+        # logger.info(f'====> Forward: 1.1.2')
         w = attn.sum(2)
         if self.use_sdp:
             l_length = self.dp(x, x_mask, w, g=g)
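For context, neg_cent1 through neg_cent4 (the first two are defined just above this hunk, following the usual VITS formulation) add up to the log-density of each projected latent frame z_p under the text-side prior N(m_p, sigma_p^2), evaluated for every (text frame, spectrogram frame) pair; s_p_sq_r is 1/sigma_p^2. In LaTeX:

    \log \mathcal{N}\left(z_p \mid m_p, \sigma_p^2\right)
      = \underbrace{-\tfrac{1}{2}\log(2\pi) - \log \sigma_p}_{\text{neg\_cent1}}
        \underbrace{-\,\tfrac{z_p^2}{2\sigma_p^2}}_{\text{neg\_cent2}}
        \underbrace{+\,\tfrac{z_p m_p}{\sigma_p^2}}_{\text{neg\_cent3}}
        \underbrace{-\,\tfrac{m_p^2}{2\sigma_p^2}}_{\text{neg\_cent4}}

summed over the channel dimension; monotonic_align.maximum_path then searches for the monotonic alignment that maximizes this total.
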
@@ -487,7 +487,6 @@ class Vits(Base):
             logw_ = torch.log(w + 1e-6) * x_mask
             logw = self.dp(x, x_mask, g=g)
             l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging

         # expand prior
         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
@@ -497,7 +496,9 @@ class Vits(Base):
         return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

     def infer(self, x, x_lengths, sid=None, emo=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
+        # logger.info(f'====> Infer: 1.1.0')
         x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths,emo)
+        # logger.info(f'====> Infer: 1.1.1')
         if self.n_speakers > 0:
             g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
         else:
@@ -514,11 +515,14 @@ class Vits(Base):
         attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
         attn = generate_path(w_ceil, attn_mask)

+        # logger.info(f'====> Infer: 1.1.2')
         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']

         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
         o = self.dec((z * y_mask)[:,:,:max_len], g=g)

+        # logger.info(f'====> Infer: 1.1.3')
         return o, attn, y_mask, (z, z_p, m_p, logs_p)
@@ -20,8 +20,6 @@ device = 'cuda' if torch.cuda.is_available() else "cpu"
 model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
 processor = Wav2Vec2Processor.from_pretrained(model_name)
 model = EmotionExtractorModel.from_pretrained(model_name).to(device)
-embs = []
-wavnames = []

 def extract_emo(
     x: np.ndarray,
@@ -48,8 +46,6 @@ class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter):
 pinyin = Pinyin(PinyinConverter()).pinyin


-
-
 def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
                        skip_existing: bool, hparams, emotion_extract: bool):
     ## FOR REFERENCE:
@@ -67,9 +63,8 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
     # Skip existing utterances if needed
     mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
     wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
-    emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % basename)
-    skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists())
-    if skip_existing and mel_fpath.exists() and wav_fpath.exists() and skip_emo_extract:
+    if skip_existing and mel_fpath.exists() and wav_fpath.exists():
         return None

     # Trim silence
@@ -91,18 +86,14 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
     np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
     np.save(wav_fpath, wav, allow_pickle=False)

-    if not skip_emo_extract:
-        emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
-        np.save(emo_fpath, emo, allow_pickle=False)
-
     # Return a tuple describing this training example
-    return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
+    return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, wav, mel_frames, text


 def _split_on_silences(wav_fpath, words, hparams):
     # Load the audio waveform
     wav, _ = librosa.load(wav_fpath, sr= hparams.sample_rate)
-    wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=512)[0]
+    wav = librosa.effects.trim(wav, top_db= 40, frame_length=2048, hop_length=1024)[0]
     if hparams.rescale:
         wav = wav / np.abs(wav).max() * hparams.rescaling_max
     # denoise, we may not need it here.
@@ -132,6 +123,15 @@ def preprocess_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams,
                 continue
             sub_basename = "%s_%02d" % (wav_fpath.name, 0)
             wav, text = _split_on_silences(wav_fpath, words, hparams)
-            metadata.append(_process_utterance(wav, text, out_dir, sub_basename,
-                                               skip_existing, hparams, emotion_extract))
+            result = _process_utterance(wav, text, out_dir, sub_basename,
+                                        skip_existing, hparams, emotion_extract)
+            if result is None:
+                continue
+            wav_fpath_name, mel_fpath_name, embed_fpath_name, wav, mel_frames, text = result
+            emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % sub_basename)
+            skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists())
+            if not skip_emo_extract and wav is not None:
+                emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)
+                np.save(emo_fpath, emo.squeeze(0), allow_pickle=False)
+            metadata.append([wav_fpath_name, mel_fpath_name, embed_fpath_name, len(wav), mel_frames, text])
     return [m for m in metadata if m is not None]
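The emotion embedding saved here is what VitsDataset later reads back from the emo directory (see the dataset changes further down). A minimal round-trip sketch, assuming extract_emo returns an array with a leading batch dimension:

    import numpy as np
    import torch

    emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True)  # assumed shape (1, D)
    np.save(emo_fpath, emo.squeeze(0), allow_pickle=False)                # stored on disk as (D,)

    # later, when the dataset builds a batch:
    emo_tensor = torch.FloatTensor(np.load(emo_fpath))                    # shape (D,)
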
@@ -39,7 +39,7 @@ def new_train():
     parser.add_argument("--syn_dir", type=str, default="../audiodata/SV2TTS/synthesizer", help= \
         "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
         "the wavs, the emos and the embeds.")
-    parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits", help=\
+    parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits2", help=\
         "Path to the output directory that will contain the saved model weights and the logs.")
     parser.add_argument('--ckptG', type=str, required=False,
                         help='original VITS G checkpoint path')
@@ -65,7 +65,7 @@ def new_train():
         run(0, 1, hparams)


-def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False):
+def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False, epochs=10000):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
     iteration = checkpoint_dict['iteration']
@@ -89,8 +89,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False):
             try:
                 new_state_dict[k] = saved_state_dict[k]
             except:
-                logger.info("%s is not in the checkpoint" % k)
-                new_state_dict[k] = v
+                if k == 'step':
+                    new_state_dict[k] = iteration * epochs
+                else:
+                    logger.info("%s is not in the checkpoint" % k)
+                    new_state_dict[k] = v

         if hasattr(model, 'module'):
             model.module.load_state_dict(new_state_dict, strict=False)
         else:
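The practical effect of the new branch: when a key (in practice the model's step buffer) is absent from an older checkpoint, it is seeded from the stored iteration count scaled by the epochs argument instead of keeping its freshly initialized value. A condensed sketch of the same logic (the bare except is narrowed to KeyError here for clarity):

    for k, v in model.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            if k == 'step':
                new_state_dict[k] = iteration * epochs  # approximate resume step
            else:
                logger.info("%s is not in the checkpoint" % k)
                new_state_dict[k] = v
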
@@ -173,13 +177,13 @@ def run(rank, n_gpus, hps):
             print("加载原版VITS模型G记录点成功")
         else:
             _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
-                                                 optim_g)
+                                                 optim_g, epochs=hps.train.epochs)
         if ckptD is not None:
             _, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True)
             print("加载原版VITS模型D记录点成功")
         else:
             _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
-                                                 optim_d)
+                                                 optim_d, epochs=hps.train.epochs)
         global_step = (epoch_str - 1) * len(train_loader)
     except:
         epoch_str = 1
@@ -216,17 +220,17 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
     net_g.train()
     net_d.train()
     for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(train_loader):
-        logger.info(f'====> Step: 1 {batch_idx}')
-        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True)
-        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True)
-        speakers = speakers.cuda(rank, non_blocking=True)
-        emo = emo.cuda(rank, non_blocking=True)
+        # logger.info(f'====> Step: 1 {batch_idx}')
+        x, x_lengths = x.cuda(rank), x_lengths.cuda(rank)
+        spec, spec_lengths = spec.cuda(rank), spec_lengths.cuda(rank)
+        y, y_lengths = y.cuda(rank), y_lengths.cuda(rank)
+        speakers = speakers.cuda(rank)
+        emo = emo.cuda(rank)
+        # logger.info(f'====> Step: 1.0 {batch_idx}')
         with autocast(enabled=hps.train.fp16_run):
             y_hat, l_length, attn, ids_slice, x_mask, z_mask, \
             (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers, emo)
+            # logger.info(f'====> Step: 1.1 {batch_idx}')
             mel = spec_to_mel(
                 spec,
                 hps.data.filter_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
||||||
|
# logger.info(f'====> Step: 1.3 {batch_idx}')
|
||||||
# Discriminator
|
# Discriminator
|
||||||
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
|
@ -258,7 +262,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||||
scaler.unscale_(optim_d)
|
scaler.unscale_(optim_d)
|
||||||
grad_norm_d = clip_grad_value_(net_d.parameters(), None)
|
grad_norm_d = clip_grad_value_(net_d.parameters(), None)
|
||||||
scaler.step(optim_d)
|
scaler.step(optim_d)
|
||||||
logger.info(f'====> Step: 2 {batch_idx}')
|
|
||||||
|
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
# Generator
|
# Generator
|
||||||
|
@ -277,7 +280,6 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||||
grad_norm_g = clip_grad_value_(net_g.parameters(), None)
|
grad_norm_g = clip_grad_value_(net_g.parameters(), None)
|
||||||
scaler.step(optim_g)
|
scaler.step(optim_g)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
# logger.info(f'====> Step: 3 {batch_idx}')
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if global_step % hps.train.log_interval == 0:
|
if global_step % hps.train.log_interval == 0:
|
||||||
lr = optim_g.param_groups[0]['lr']
|
lr = optim_g.param_groups[0]['lr']
|
||||||
|
@ -339,6 +341,8 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
||||||
emo = emo[:1]
|
emo = emo[:1]
|
||||||
break
|
break
|
||||||
y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000)
|
y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000)
|
||||||
|
# y_hat, attn, mask, *_ = generator.infer(x, x_lengths, speakers, emo, max_len=1000) # for non DistributedDataParallel object
|
||||||
|
|
||||||
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
||||||
|
|
||||||
mel = spec_to_mel(
|
mel = spec_to_mel(
|
||||||
|
|
|
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 import torch.utils.data

-from utils.audio_utils import spectrogram, load_wav
+from utils.audio_utils import spectrogram1, load_wav_to_torch, spectrogram
 from utils.util import intersperse
 from models.synthesizer.utils.text import text_to_sequence
@@ -57,6 +57,8 @@ class VitsDataset(torch.utils.data.Dataset):
             if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                 # TODO: for magic data only
                 speaker_name = wav_fpath.split("_")[1]
+                # # TODO: for ai data only
+                # speaker_name = wav_fpath.split("-")[1][6:9]
                 if speaker_name not in spk_to_sid:
                     sid += 1
                     spk_to_sid[speaker_name] = sid
@@ -72,35 +74,44 @@ class VitsDataset(torch.utils.data.Dataset):
         wav_fpath, text, sid = audio_metadata[0], audio_metadata[5], audio_metadata[6]
         text = self.get_text(text)

-        spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}')
+        # TODO: add original audio data root for loading
+        file_name = wav_fpath.split("_00")[0].split('-')[1]
+        spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}..{os.sep}..{os.sep}magicdata{os.sep}train{os.sep}{"_".join(file_name.split("_")[:2])}{os.sep}{file_name}')
+
+        # spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}')
         sid = self.get_sid(sid)
         emo = torch.FloatTensor(np.load(f'{self.datasets_root}{os.sep}emo{os.sep}{wav_fpath.replace("audio", "emo")}'))
         return (text, spec, wav, sid, emo)

     def get_audio(self, filename):
-        # audio, sampling_rate = load_wav(filename)
-        # if sampling_rate != self.sampling_rate:
-        #     raise ValueError("{} {} SR doesn't match target {} SR".format(
-        #         sampling_rate, self.sampling_rate))
-        # audio = torch.load(filename)
-        audio = torch.FloatTensor(np.load(filename).astype(np.float32))
-        audio = audio.unsqueeze(0)
-        # audio_norm = audio / self.max_wav_value
-        # audio_norm = audio_norm.unsqueeze(0)
-        # spec_filename = filename.replace(".wav", ".spec.pt")
-        # if os.path.exists(spec_filename):
-        #     spec = torch.load(spec_filename)
-        # else:
-        #     spec = spectrogram(audio, self.filter_length,
-        #         self.sampling_rate, self.hop_length, self.win_length,
-        #         center=False)
-        #     spec = torch.squeeze(spec, 0)
-        #     torch.save(spec, spec_filename)
-        spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
+        audio, sampling_rate = load_wav_to_torch(filename)
+        if sampling_rate != self.sampling_rate:
+            raise ValueError("{} {} SR doesn't match target {} SR".format(
+                sampling_rate, self.sampling_rate))
+        audio_norm = audio / self.max_wav_value
+        audio_norm = audio_norm.unsqueeze(0)
+        spec = spectrogram(audio_norm, self.filter_length, self.hop_length, self.win_length,
             center=False)
         spec = torch.squeeze(spec, 0)
-        return spec, audio
+        return spec, audio_norm
+
+        # print("Loading", filename)
+        # # audio = torch.FloatTensor(np.load(filename).astype(np.float32))
+        # audio = audio.unsqueeze(0)
+        # audio_norm = audio / self.max_wav_value
+        # audio_norm = audio_norm.unsqueeze(0)
+        # # spec_filename = filename.replace(".wav", ".spec.pt")
+        # # if os.path.exists(spec_filename):
+        # #     spec = torch.load(spec_filename)
+        # # else:
+        # #     spec = spectrogram(audio, self.filter_length,self.hop_length, self.win_length,
+        # #         center=False)
+        # #     spec = torch.squeeze(spec, 0)
+        # #     torch.save(spec, spec_filename)
+        # spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length,
+        #     center=False)
+        # spec = torch.squeeze(spec, 0)
+        # return spec, audio

     def get_text(self, text):
         if self.cleaned_text:
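The new path construction assumes that everything after the first '-' in the metadata audio entry is the original magicdata file name, and that the first two underscore-separated fields name its parent directory. A worked example with a made-up entry (the real naming scheme depends on how the magicdata corpus was preprocessed and may differ):

    import os

    datasets_root = "../audiodata/SV2TTS/synthesizer"        # hypothetical
    wav_fpath = "audio-14_3466_20170826171404.wav_00.npy"    # hypothetical metadata entry

    file_name = wav_fpath.split("_00")[0].split('-')[1]      # "14_3466_20170826171404.wav"
    path = (f'{datasets_root}{os.sep}..{os.sep}..{os.sep}magicdata{os.sep}train{os.sep}'
            f'{"_".join(file_name.split("_")[:2])}{os.sep}{file_name}')
    print(path)  # ../audiodata/SV2TTS/synthesizer/../../magicdata/train/14_3466/14_3466_20170826171404.wav
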
@@ -17,6 +17,27 @@ def load_wav_to_torch(full_path):
     sampling_rate, data = read(full_path)
     return torch.FloatTensor(data.astype(np.float32)), sampling_rate


+def spectrogram1(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+    if torch.min(y) < -1.:
+        print('min value is ', torch.min(y))
+    if torch.max(y) > 1.:
+        print('max value is ', torch.max(y))
+
+    global hann_window
+    dtype_device = str(y.dtype) + '_' + str(y.device)
+    wnsize_dtype_device = str(win_size) + '_' + dtype_device
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
+
+
 def spectrogram(y, n_fft, hop_size, win_size, center=False):
     if torch.min(y) < -1.:
@@ -34,7 +55,7 @@ def spectrogram(y, n_fft, hop_size, win_size, center=False):
     y = y.squeeze(1)

     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True)
+                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)

     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
     return spec
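Both helpers now pass return_complex explicitly, which newer PyTorch releases expect. The two output layouts need different magnitude math: with return_complex=False the result is a real tensor whose trailing axis holds (real, imag), so pow(2).sum(-1) already gives the squared magnitude, whereas a complex-valued STFT would normally go through torch.view_as_real or abs before the same reduction. A sketch of that complex-output variant (not the code in this commit):

    spec_c = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                        window=hann_window[wnsize_dtype_device], center=center,
                        pad_mode='reflect', normalized=False, onesided=True,
                        return_complex=True)
    # magnitude from a complex STFT: square real and imaginary parts, then sum
    spec = torch.sqrt(torch.view_as_real(spec_c).pow(2).sum(-1) + 1e-6)
    # equivalently: spec = torch.sqrt(spec_c.abs().pow(2) + 1e-6)
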
233  vits.ipynb (vendored)
File diff suppressed because one or more lines are too long