2022-07-17 09:58:17 +08:00
|
|
|
|
2021-08-07 11:56:00 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2022-07-17 09:58:17 +08:00
|
|
|
from .sublayer.global_style_token import GlobalStyleToken
|
|
|
|
from .sublayer.pre_net import PreNet
|
|
|
|
from .sublayer.cbhg import CBHG
|
|
|
|
from .sublayer.lsa import LSA
|
|
|
|
from .base import Base
|
2021-10-23 10:28:32 +08:00
|
|
|
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
2021-11-10 23:23:13 +08:00
|
|
|
from synthesizer.hparams import hparams
|
2021-08-07 11:56:00 +08:00
|
|
|
|
|
|
|
class Encoder(nn.Module):
    """Text encoder: embedding lookup, then PreNet, then CBHG.

    Args:
        embed_dims: Size of each embedding vector.
        num_chars: Number of rows in the embedding table.
        encoder_dims: Width used for both PreNet layers and for every
            channel setting handed to the CBHG.
        K: Forwarded to ``CBHG`` as ``K``.
        num_highways: Forwarded to ``CBHG`` as ``num_highways``.
        dropout: Forwarded to ``PreNet`` as ``dropout``.
    """

    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
        super().__init__()
        # Submodules are created in the order embedding -> pre_net -> cbhg.
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(
            embed_dims,
            fc1_dims=encoder_dims,
            fc2_dims=encoder_dims,
            dropout=dropout,
        )
        self.cbhg = CBHG(
            K=K,
            in_channels=encoder_dims,
            channels=encoder_dims,
            proj_channels=[encoder_dims, encoder_dims],
            num_highways=num_highways,
        )

    def forward(self, x):
        """Encode ``x`` (indices accepted by ``nn.Embedding``).

        The PreNet output has its dims 1 and 2 swapped before it is handed
        to the CBHG; whatever the CBHG returns is returned unchanged.
        """
        embedded = self.embedding(x)
        projected = self.pre_net(embedded)
        channels_first = projected.transpose(1, 2)
        return self.cbhg(channels_first)
|
2021-08-07 11:56:00 +08:00
|
|
|
|
|
|
|
class Decoder(nn.Module):
    """Single-step attention decoder producing up to ``max_r`` mel frames per call.

    One call to :meth:`forward` runs the PreNet, the attention GRU cell, the
    attention module, two residual LSTM cells and the mel / stop projections.

    Args:
        n_mels: Number of mel channels per frame.
        encoder_dims: Feature size of the text encoder output (before any
            speaker / style embedding is appended).
        decoder_dims: Hidden size of the attention GRU cell; the PreNet
            layers are twice this wide.
        lstm_dims: Hidden size of both residual LSTM cells.
        dropout: Forwarded to ``PreNet`` as ``dropout``.
        speaker_embedding_size: Size of the speaker embedding appended to the
            encoder output; grown by ``gst_hp.E`` when ``hparams.use_gst``.
    """

    # Class variable because its value doesn't change between instances,
    # yet ought to be scoped by class because it's a property of a Decoder
    max_r = 20

    def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
                 dropout, speaker_embedding_size):
        super().__init__()
        # Reduction factor, stored as a buffer so it travels with the state dict.
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels
        prenet_dims = (decoder_dims * 2, decoder_dims * 2)
        self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                             dropout=dropout)
        self.attn_net = LSA(decoder_dims)
        if hparams.use_gst:
            # The style embedding is concatenated next to the speaker embedding.
            speaker_embedding_size += gst_hp.E
        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        # Always project to max_r frames; forward() slices down to the current r.
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

    def zoneout(self, prev, current, device, p=0.1):
        """Randomly keep elements of ``prev`` instead of taking ``current``.

        Args:
            prev: State from the previous step.
            current: Freshly computed state, same shape as ``prev``.
            device: Device on which the random mask is allocated.
            p: Probability that an element keeps its previous value.

        Returns:
            Element-wise mix: ``prev`` where the mask is 1, ``current`` elsewhere.
        """
        # zeros_like keeps the mask in prev's dtype, so a half-precision state
        # is not silently promoted to float32 by the blend below.
        mask = torch.zeros_like(prev, device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, t, chars):
        """Run one decoder step.

        Args:
            encoder_seq: Encoder output; dim 0 is the batch.
            encoder_seq_proj: Projected encoder output passed to the attention module.
            prenet_in: Previous mel frame (or the <GO> frame) fed to the PreNet.
            hidden_states: Tuple ``(attn_hidden, rnn1_hidden, rnn2_hidden)``.
            cell_states: Tuple ``(rnn1_cell, rnn2_cell)``.
            context_vec: Context vector from the previous step.
            t: Current decoder step, forwarded to the attention module.
            chars: Input text ids, forwarded to the attention module.

        Returns:
            ``(mels, scores, hidden_states, cell_states, context_vec, stop_tokens)``
            where ``mels`` has shape ``(batch, n_mels, r)`` and ``stop_tokens``
            is the sigmoid of a single-unit projection.
        """
        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)
        device = encoder_seq.device

        # Unpack the hidden and cell states
        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # PreNet for the Attention RNN
        prenet_out = self.prenet(prenet_in)

        # Compute the Attention RNN hidden state
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)

        # Compute the attention scores
        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)

        # Dot product to create the context vector
        context_vec = scores @ encoder_seq
        context_vec = context_vec.squeeze(1)

        # Concat Attention RNN output w. Context Vector & project
        x = torch.cat([context_vec, attn_hidden], dim=1)
        x = self.rnn_input(x)

        # Compute first Residual RNN, training with fixed zoneout rate 0.1
        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
        if self.training:
            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
        else:
            rnn1_hidden = rnn1_hidden_next
        x = x + rnn1_hidden

        # Compute second Residual RNN
        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
        if self.training:
            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
        else:
            rnn2_hidden = rnn2_hidden_next
        x = x + rnn2_hidden

        # Project Mels: max_r frames are always computed, only the first r are kept
        mels = self.mel_proj(x)
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop token prediction
        s = torch.cat((x, context_vec), dim=1)
        s = self.stop_proj(s)
        stop_tokens = torch.sigmoid(s)

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
|
|
|
|
2022-07-17 09:58:17 +08:00
|
|
|
class Tacotron(Base):
    """Tacotron synthesizer conditioned on a speaker embedding, with optional GST.

    The text encoder output is extended with the speaker embedding (and, when
    ``hparams.use_gst`` is set, a style embedding), then decoded step by step
    into mel frames which a CBHG post-net turns into a linear spectrogram.

    Args:
        embed_dims, num_chars, encoder_dims, encoder_K, num_highways, dropout:
            Forwarded to :class:`Encoder`.
        decoder_dims, lstm_dims: Forwarded to :class:`Decoder`.
        n_mels: Number of mel channels.
        fft_bins: Output size of the final linear projection.
        postnet_dims, postnet_K: Post-net CBHG channel count and ``K``.
        stop_threshold: Forwarded to the ``Base`` constructor.
        speaker_embedding_size: Size of the speaker embedding vector.
    """

    def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
                 fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
                 dropout, stop_threshold, speaker_embedding_size):
        super().__init__(stop_threshold)
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                               encoder_K, num_highways, dropout)
        # The projection input is the encoder output plus every appended embedding.
        project_dims = encoder_dims + speaker_embedding_size
        if hparams.use_gst:
            project_dims += gst_hp.E
        self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
        if hparams.use_gst:
            self.gst = GlobalStyleToken(speaker_embedding_size)
        self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                               dropout, speaker_embedding_size)
        self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                            [postnet_dims, fft_bins], num_highways)
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

    @staticmethod
    def _concat_speaker_embedding(outputs, speaker_embeddings):
        """Append ``speaker_embeddings`` to every position of ``outputs``.

        ``speaker_embeddings`` is broadcast over the first two dims of
        ``outputs`` and concatenated along the last dim.
        """
        speaker_embeddings_ = speaker_embeddings.expand(
            outputs.size(0), outputs.size(1), -1)
        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
        return outputs

    @staticmethod
    def _add_speaker_embedding(x, speaker_embedding):
        """Concatenate the speaker embedding to every char of the encoder output.

        Args:
            x: Encoder output, a 3D tensor ``(batch_size, num_chars, tts_embed_dims)``.
            speaker_embedding: Either ``(batch_size, speaker_embedding_size)``
                or a single 1D embedding ``(speaker_embedding_size,)``. A
                single embedding is broadcast over the whole batch.

        Returns:
            Tensor ``(batch_size, num_chars, tts_embed_dims + speaker_embedding_size)``.
        """
        # SV2TTS
        batch_size = x.size(0)
        num_chars = x.size(1)
        speaker_embedding_size = speaker_embedding.size(-1)

        # Insert a length-1 char axis and broadcast it (and, for a single
        # embedding, the batch axis) instead of materialising repeated copies.
        e = speaker_embedding.reshape(-1, 1, speaker_embedding_size)
        e = e.expand(batch_size, num_chars, speaker_embedding_size)

        # Concatenate the tiled speaker embedding with the encoder output
        return torch.cat((x, e), 2)

    def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
        """Run the full model.

        In training mode the decoder is teacher-forced with ``mels`` and
        ``steps`` is taken from ``mels``; otherwise decoding is autoregressive
        for at most ``steps`` frames and may stop early on the stop token.

        Args:
            texts: Input text ids ``(batch_size, num_chars)``.
            mels: Target mels ``(batch_size, n_mels, steps)``; only read in
                training mode.
            speaker_embedding: Optional speaker embedding appended to the
                encoder output. NOTE(review): the GST branches dereference it,
                so it looks required whenever ``hparams.use_gst`` is set.
            steps: Maximum number of frames to decode at inference.
            style_idx: At inference with GST, index of the style token to use
                when in ``[0, 10)``; any other value derives the style from
                ``speaker_embedding`` with an all-zero style input.
            min_stop_token: Inference stops once ``stop_tokens * 10`` exceeds
                this for the whole batch (and ``t > 10``).

        Returns:
            ``(mel_outputs, linear, attn_scores, stop_outputs)``.
        """
        device = texts.device  # use same device as parameters

        if self.training:
            self.step += 1
            batch_size, _, steps = mels.size()
        else:
            batch_size, _ = texts.size()

        # Initialise all hidden states and pack into tuple
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Initialise all lstm cell states and pack into tuple
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        size = self.encoder_dims + self.speaker_embedding_size
        if hparams.use_gst:
            size += gst_hp.E
        context_vec = torch.zeros(batch_size, size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(texts)

        if speaker_embedding is not None:
            encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)

        if hparams.use_gst and self.gst is not None:
            if self.training:
                # For training, the speaker embedding serves as both the style
                # input and the reference.
                style_embed = self.gst(speaker_embedding, speaker_embedding)
            elif 0 <= style_idx < 10:
                # Allocate the query directly on the model's device: `.cuda()`
                # would target the default CUDA device rather than `device`.
                query = torch.zeros(1, 1, self.gst.stl.attention.num_units, device=device)
                gst_embed = torch.tanh(self.gst.stl.embed)
                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
                style_embed = self.gst.stl.attention(query, key)
            else:
                speaker_embedding_style = torch.zeros(
                    speaker_embedding.size()[0], 1, self.speaker_embedding_size, device=device)
                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            if self.training:
                # Teacher forcing: feed the ground-truth frame before step t.
                prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
            else:
                # Autoregressive: feed the last frame that was just generated.
                prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, texts)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            # One stop prediction per decoder call, repeated for each of its r frames.
            stop_outputs.extend([stop_tokens] * self.r)
            if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10:
                break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        # NOTE(review): looks redundant (the module is already in training
        # mode here); kept so every submodule's mode is re-set exactly as before.
        if self.training:
            self.train()

        return mel_outputs, linear, attn_scores, stop_outputs

    def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
        """Synthesize from text ids ``x`` in inference mode.

        Switches the module to eval mode (it is not switched back afterwards)
        and runs :meth:`forward` without target mels.

        Returns:
            ``(mel_outputs, linear, attn_scores)``.
        """
        self.eval()
        mel_outputs, linear, attn_scores, _ = self.forward(
            x, None, speaker_embedding, steps, style_idx, min_stop_token)
        return mel_outputs, linear, attn_scores
|