Add description for

babysor00 2022-07-17 11:53:50 +08:00
parent efbdb21b70
commit c3590bffb2
3 changed files with 103 additions and 62 deletions

View File

@@ -18,9 +18,9 @@ class LSA(nn.Module):
         self.cumulative = torch.zeros(b, t, device=device)
         self.attention = torch.zeros(b, t, device=device)

-    def forward(self, encoder_seq_proj, query, t, chars):
+    def forward(self, encoder_seq_proj, query, times, chars):

-        if t == 0: self.init_attention(encoder_seq_proj)
+        if times == 0: self.init_attention(encoder_seq_proj)
         processed_query = self.W(query).unsqueeze(1)
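Aside from the added descriptions, the only behavioural content in this hunk is renaming the decoder step counter from `t` to `times`. For orientation, a minimal standalone sketch of the state handling shown in the context lines; `AttentionState` and all sizes are illustrative stand-ins, not the repo's LSA module:

import torch

class AttentionState:
    # Toy stand-in for the buffers shown above, not the repo's LSA module.
    def init_attention(self, encoder_seq_proj):
        b, t, _ = encoder_seq_proj.size()
        device = encoder_seq_proj.device
        self.cumulative = torch.zeros(b, t, device=device)  # running sum of past attention weights
        self.attention = torch.zeros(b, t, device=device)   # attention weights from the previous step

    def step(self, encoder_seq_proj, times):
        if times == 0:  # first decoder step of an utterance: reset the state
            self.init_attention(encoder_seq_proj)
        return self.cumulative, self.attention

state = AttentionState()
proj = torch.randn(2, 7, 128)                # [batch_size, text_num_chars, attn_dims]
cum, att = state.step(proj, times=0)
print(cum.shape, att.shape)                  # torch.Size([2, 7]) torch.Size([2, 7])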

View File

@@ -9,6 +9,15 @@ class PreNet(nn.Module):
         self.p = dropout

     def forward(self, x):
+        """forward
+
+        Args:
+            x (3D tensor with size `[batch_size, num_chars, tts_embed_dims]`): input texts list
+
+        Returns:
+            3D tensor with size `[batch_size, num_chars, encoder_dims]`
+        """
         x = self.fc1(x)
         x = F.relu(x)
         x = F.dropout(x, self.p, training=True)
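For reference, a minimal sketch of a PreNet matching the shape contract in the docstring above: two linear layers with ReLU and dropout that stays active even at inference, as the `training=True` call in the context lines suggests. The layer names come from those lines; the second layer and concrete sizes are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PreNetSketch(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=256, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        # [batch_size, num_chars, in_dims] -> [batch_size, num_chars, fc2_dims]
        x = F.relu(self.fc1(x))
        x = F.dropout(x, self.p, training=True)   # dropout kept on even in eval mode
        x = F.relu(self.fc2(x))
        return F.dropout(x, self.p, training=True)

x = torch.randn(4, 12, 512)                        # [batch_size, num_chars, tts_embed_dims]
print(PreNetSketch(512)(x).shape)                  # torch.Size([4, 12, 256])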

View File

@@ -9,52 +9,80 @@ from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
 from synthesizer.hparams import hparams


 class Encoder(nn.Module):
-    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
+    def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5):
+        """Encoder for SV2TTS
+
+        Args:
+            num_chars (int): number of input symbols (vocabulary size)
+            embed_dims (int, optional): embedding dim for input texts. Defaults to 512.
+            encoder_dims (int, optional): output dim for encoder. Defaults to 256.
+            K (int, optional): number of kernel sizes in the CBHG convolution bank. Defaults to 5.
+            num_highways (int, optional): number of highway layers in the CBHG. Defaults to 4.
+            dropout (float, optional): dropout rate for the PreNet. Defaults to 0.5.
+        """
         super().__init__()
-        prenet_dims = (encoder_dims, encoder_dims)
-        cbhg_channels = encoder_dims
         self.embedding = nn.Embedding(num_chars, embed_dims)
-        self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
+        self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims,
                               dropout=dropout)
-        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
-                         proj_channels=[cbhg_channels, cbhg_channels],
+        self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims,
+                         proj_channels=[encoder_dims, encoder_dims],
                          num_highways=num_highways)

     def forward(self, x):
-        x = self.embedding(x)
-        x = self.pre_net(x)
-        x.transpose_(1, 2)
-        return self.cbhg(x)
+        """Forward pass for the encoder
+
+        Args:
+            x (2D tensor with size `[batch_size, text_num_chars]`): input texts list
+
+        Returns:
+            3D tensor with size `[batch_size, text_num_chars, encoder_dims]`
+        """
+        x = self.embedding(x)    # return: [batch_size, text_num_chars, tts_embed_dims]
+        x = self.pre_net(x)      # return: [batch_size, text_num_chars, encoder_dims]
+        x.transpose_(1, 2)       # return: [batch_size, encoder_dims, text_num_chars]
+        return self.cbhg(x)      # return: [batch_size, text_num_chars, encoder_dims]


 class Decoder(nn.Module):
     # Class variable because its value doesn't change between classes
     # yet ought to be scoped by class because its a property of a Decoder
     max_r = 20

-    def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
+    def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims,
                  dropout, speaker_embedding_size):
         super().__init__()
         self.register_buffer("r", torch.tensor(1, dtype=torch.int))
         self.n_mels = n_mels
-        prenet_dims = (decoder_dims * 2, decoder_dims * 2)
-        self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
+        self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2,
                              dropout=dropout)
         self.attn_net = LSA(decoder_dims)
         if hparams.use_gst:
             speaker_embedding_size += gst_hp.E
-        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
-        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
+        self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims)
+        self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims)
         self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
-        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
+        self.stop_proj = nn.Linear(input_dims + lstm_dims, 1)

     def zoneout(self, prev, current, device, p=0.1):
         mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
         return prev * mask + current * (1 - mask)

     def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
-                hidden_states, cell_states, context_vec, t, chars):
+                hidden_states, cell_states, context_vec, times, chars):
+        """Run one step of the decoder
+
+        Args:
+            encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): encoder output with the speaker/style embedding concatenated
+            encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): projection of encoder_seq used by the attention network
+            prenet_in (2D tensor `[batch_size, n_mels]`): previous mel frame fed to the decoder PreNet
+            hidden_states (tuple): (attn_hidden, rnn1_hidden, rnn2_hidden) from the previous step
+            cell_states (tuple): (rnn1_cell, rnn2_cell) from the previous step
+            context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): attention context from the previous step
+            times (int): current decoder step index
+            chars (2D tensor with size `[batch_size, text_num_chars]`): original input texts
+        """
         # Need this for reshaping mels
         batch_size = encoder_seq.size(0)
         device = encoder_seq.device
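With every dimension now a keyword default, the encoder can be built from just the symbol count. A rough shape walk-through of the flow documented in Encoder.forward, using a plain Linear as a stand-in for the PreNet and stopping before the CBHG; all sizes other than the signature defaults are made up:

import torch
import torch.nn as nn

batch_size, text_num_chars = 2, 37
num_chars, embed_dims, encoder_dims = 70, 512, 256    # defaults from the new signature

texts = torch.randint(0, num_chars, (batch_size, text_num_chars))

x = nn.Embedding(num_chars, embed_dims)(texts)        # [2, 37, 512] -> [batch_size, text_num_chars, tts_embed_dims]
x = nn.Linear(embed_dims, encoder_dims)(x)            # PreNet stand-in: [2, 37, 256]
x = x.transpose(1, 2)                                 # [2, 256, 37] -> channels-first for the CBHG convolutions
print(x.shape)                                        # torch.Size([2, 256, 37]); the CBHG maps this back to [batch, chars, encoder_dims]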
@@ -63,25 +91,25 @@ class Decoder(nn.Module):
         rnn1_cell, rnn2_cell = cell_states

         # PreNet for the Attention RNN
-        prenet_out = self.prenet(prenet_in)
+        prenet_out = self.prenet(prenet_in)  # return: `[batch_size, decoder_dims * 2 (256)]`

         # Compute the Attention RNN hidden state
-        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
-        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
+        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)  # `[batch_size, project_dim + decoder_dims * 2 (768)]`
+        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)  # `[batch_size, decoder_dims (128)]`

         # Compute the attention scores
-        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
+        scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars)

         # Dot product to create the context vector
         context_vec = scores @ encoder_seq
         context_vec = context_vec.squeeze(1)

         # Concat Attention RNN output w. Context Vector & project
-        x = torch.cat([context_vec, attn_hidden], dim=1)
-        x = self.rnn_input(x)
+        x = torch.cat([context_vec, attn_hidden], dim=1)  # `[batch_size, project_dim + decoder_dims (640)]`
+        x = self.rnn_input(x)  # `[batch_size, lstm_dims (1024)]`

         # Compute first Residual RNN, training with fixed zoneout rate 0.1
-        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
+        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))  # `[batch_size, lstm_dims (1024)]`
         if self.training:
             rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
         else:
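Both residual LSTMs are regularized with zoneout at a fixed rate of 0.1 during training: each hidden unit keeps its previous value with probability p and takes the freshly computed value otherwise. A standalone sketch of the same computation as the zoneout method shown above:

import torch

def zoneout(prev, current, p=0.1):
    # With probability p, keep the previous hidden value; otherwise take the new one.
    mask = torch.zeros_like(prev).bernoulli_(p)
    return prev * mask + current * (1 - mask)

prev, current = torch.zeros(2, 4), torch.ones(2, 4)
print(zoneout(prev, current, p=0.5))   # roughly half the entries stay at 0 (zoned out), the rest become 1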
@@ -89,7 +117,7 @@ class Decoder(nn.Module):
         x = x + rnn1_hidden

         # Compute second Residual RNN
-        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
+        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))  # `[batch_size, lstm_dims (1024)]`
         if self.training:
             rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
         else:

@@ -97,8 +125,8 @@ class Decoder(nn.Module):
         x = x + rnn2_hidden

         # Project Mels
-        mels = self.mel_proj(x)
-        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
+        mels = self.mel_proj(x)  # `[batch_size, 1600]`
+        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]  # `[batch_size, n_mels, r]`

         hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
         cell_states = (rnn1_cell, rnn2_cell)
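The new shape comments make the reduction-factor bookkeeping explicit: with the usual n_mels = 80 and max_r = 20, mel_proj emits 1600 values per decoder step, which are reshaped and truncated to the current r frames. A standalone sketch of just that reshape, with concrete sizes assumed from the comments:

import torch
import torch.nn as nn

batch_size, lstm_dims, n_mels, max_r, r = 2, 1024, 80, 20, 2

mel_proj = nn.Linear(lstm_dims, n_mels * max_r, bias=False)
x = torch.randn(batch_size, lstm_dims)

mels = mel_proj(x)                                        # [2, 1600]
mels = mels.view(batch_size, n_mels, max_r)[:, :, :r]     # keep only the first r of the max_r frames
print(mels.shape)                                         # torch.Size([2, 80, 2]) == [batch_size, n_mels, r]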
@@ -119,15 +147,15 @@ class Tacotron(Base):
         self.encoder_dims = encoder_dims
         self.decoder_dims = decoder_dims
         self.speaker_embedding_size = speaker_embedding_size
-        self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
+        self.encoder = Encoder(num_chars, embed_dims, encoder_dims,
                                encoder_K, num_highways, dropout)
-        project_dims = encoder_dims + speaker_embedding_size
+        self.project_dims = encoder_dims + speaker_embedding_size
         if hparams.use_gst:
-            project_dims += gst_hp.E
-        self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
+            self.project_dims += gst_hp.E
+        self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False)
         if hparams.use_gst:
             self.gst = GlobalStyleToken(speaker_embedding_size)
-        self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
+        self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                             [postnet_dims, fft_bins], num_highways)
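Storing the projected width as self.project_dims lets the decoder and the initial context vector share a single definition. The arithmetic, in rough terms; the GST token size gst_hp.E is an assumption, and the 512 case matches the `project_dim(default to 512)` comments in the decoder docstring:

encoder_dims = 256              # Encoder output width (default in the new signature)
speaker_embedding_size = 256    # SV2TTS speaker embedding width
gst_E = 256                     # stand-in for gst_hp.E -- assumed value
use_gst = False

project_dims = encoder_dims + speaker_embedding_size
if use_gst:
    project_dims += gst_E

print(project_dims)             # 512 without GST; with GST enabled it grows by gst_hp.E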
@@ -142,36 +170,43 @@ class Tacotron(Base):

     @staticmethod
     def _add_speaker_embedding(x, speaker_embedding):
-        # SV2TTS
-        # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
-        # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
-        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
-        # This concats the speaker embedding for each char in the encoder output
+        """Add speaker embedding
+        This concats the speaker embedding for each char in the encoder output
+
+        Args:
+            x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output
+            speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding
+
+        Returns:
+            3D tensor with size `[batch_size, text_num_chars, encoder_dims + speaker_embedding_size]`
+        """
         # Save the dimensions as human-readable names
         batch_size = x.size()[0]
-        num_chars = x.size()[1]
-
-        if speaker_embedding.dim() == 1:
-            idx = 0
-        else:
-            idx = 1
+        text_num_chars = x.size()[1]

         # Start by making a copy of each speaker embedding to match the input text length
-        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
-        speaker_embedding_size = speaker_embedding.size()[idx]
-        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
+        # The output of this has size (batch_size, text_num_chars * speaker_embedding_size)
+        speaker_embedding_size = speaker_embedding.size()[1]
+        e = speaker_embedding.repeat_interleave(text_num_chars, dim=1)

         # Reshape it and transpose
-        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
+        e = e.reshape(batch_size, speaker_embedding_size, text_num_chars)
         e = e.transpose(1, 2)

         # Concatenate the tiled speaker embedding with the encoder output
         x = torch.cat((x, e), 2)
         return x

-    def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
+    def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
+        """Forward pass for Tacotron
+
+        Args:
+            texts (`[batch_size, text_num_chars]`): input texts list
+            mels (`[batch_size, varied_mel_lengths, steps]`): target mels for comparison (training only)
+            speaker_embedding (`[batch_size, speaker_embedding_size (default to 256)]`): reference speaker embedding
+            steps (int, optional): maximum number of decoder steps. Defaults to 2000.
+            style_idx (int, optional): GST style selected. Defaults to 0.
+            min_stop_token (int, optional): decoder min_stop_token. Defaults to 5.
+        """
         device = texts.device  # use same device as parameters

         if self.training:
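With the dim() branch removed, _add_speaker_embedding always expects a 2D batch of embeddings. The tiling is easy to check in isolation; a minimal sketch with toy sizes (all numbers made up):

import torch

batch_size, text_num_chars, encoder_dims, speaker_embedding_size = 2, 5, 256, 256

x = torch.randn(batch_size, text_num_chars, encoder_dims)      # encoder output
speaker_embedding = torch.randn(batch_size, speaker_embedding_size)

# Copy each speaker embedding once per character, then line it up with the text axis
e = speaker_embedding.repeat_interleave(text_num_chars, dim=1)         # [2, 5 * 256]
e = e.reshape(batch_size, speaker_embedding_size, text_num_chars)      # [2, 256, 5]
e = e.transpose(1, 2)                                                  # [2, 5, 256]

out = torch.cat((x, e), 2)
print(out.shape)   # torch.Size([2, 5, 512]) == [batch_size, text_num_chars, encoder_dims + speaker_embedding_size]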
@@ -194,18 +229,11 @@ class Tacotron(Base):

         # <GO> Frame for start of decoder loop
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)

-        # Need an initial context vector
-        size = self.encoder_dims + self.speaker_embedding_size
-        if hparams.use_gst:
-            size += gst_hp.E
-        context_vec = torch.zeros(batch_size, size, device=device)
-
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts)
-        if speaker_embedding is not None:
-            encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)
+        encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)

         if hparams.use_gst and self.gst is not None:
             if self.training:

@@ -222,12 +250,16 @@ class Tacotron(Base):
             else:
                 speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
             style_embed = self.gst(speaker_embedding_style, speaker_embedding)
-            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
-        encoder_seq_proj = self.encoder_proj(encoder_seq)
+            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)  # return: [batch_size, text_num_chars, project_dims]
+        encoder_seq_proj = self.encoder_proj(encoder_seq)  # return: [batch_size, text_num_chars, decoder_dims]

         # Need a couple of lists for outputs
         mel_outputs, attn_scores, stop_outputs = [], [], []

+        # Need an initial context vector
+        context_vec = torch.zeros(batch_size, self.project_dims, device=device)
+
         # Run the decoder loop
         for t in range(0, steps, self.r):
             if self.training:
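Moving the initial context vector after the encoder block lets it be sized directly from self.project_dims, and the loop then advances by the reduction factor r, producing r mel frames per decoder call. A tiny sketch of that bookkeeping, with sizes assumed:

import torch

batch_size, project_dims, steps, r = 2, 512, 10, 2

context_vec = torch.zeros(batch_size, project_dims)   # initial attention context, one per batch item
print(context_vec.shape)                               # torch.Size([2, 512])
print(list(range(0, steps, r)))                        # [0, 2, 4, 6, 8] -- decoder step indices, r frames each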
@@ -260,7 +292,7 @@ class Tacotron(Base):

         return mel_outputs, linear, attn_scores, stop_outputs

-    def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
+    def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
         mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
         return mel_outputs, linear, attn_scores
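Since speaker_embedding is now a required positional argument of both forward and generate, inference always runs the SV2TTS path. A hedged usage sketch; the model construction is omitted and the tensor sizes follow the defaults documented above:

import torch

texts = torch.randint(0, 70, (1, 42))        # [batch_size, text_num_chars], token ids for one utterance
speaker_embedding = torch.randn(1, 256)      # [batch_size, speaker_embedding_size]

# Assuming `model` is a constructed Tacotron instance:
# mel_outputs, linear, attn_scores = model.generate(
#     texts, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5)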