From c3590bffb21016b25eea80fafcdc4704c4bf6b25 Mon Sep 17 00:00:00 2001 From: babysor00 Date: Sun, 17 Jul 2022 11:53:50 +0800 Subject: [PATCH] Add description for --- synthesizer/models/sublayer/lsa.py | 4 +- synthesizer/models/sublayer/pre_net.py | 9 ++ synthesizer/models/tacotron.py | 152 +++++++++++++++---------- 3 files changed, 103 insertions(+), 62 deletions(-) diff --git a/synthesizer/models/sublayer/lsa.py b/synthesizer/models/sublayer/lsa.py index 9a32913..cf2dfa5 100644 --- a/synthesizer/models/sublayer/lsa.py +++ b/synthesizer/models/sublayer/lsa.py @@ -18,9 +18,9 @@ class LSA(nn.Module): self.cumulative = torch.zeros(b, t, device=device) self.attention = torch.zeros(b, t, device=device) - def forward(self, encoder_seq_proj, query, t, chars): + def forward(self, encoder_seq_proj, query, times, chars): - if t == 0: self.init_attention(encoder_seq_proj) + if times == 0: self.init_attention(encoder_seq_proj) processed_query = self.W(query).unsqueeze(1) diff --git a/synthesizer/models/sublayer/pre_net.py b/synthesizer/models/sublayer/pre_net.py index 3c8ebb8..886646a 100644 --- a/synthesizer/models/sublayer/pre_net.py +++ b/synthesizer/models/sublayer/pre_net.py @@ -9,6 +9,15 @@ class PreNet(nn.Module): self.p = dropout def forward(self, x): + """forward + + Args: + x (3D tensor with size `[batch_size, num_chars, tts_embed_dims]`): input texts list + + Returns: + 3D tensor with size `[batch_size, num_chars, encoder_dims]` + + """ x = self.fc1(x) x = F.relu(x) x = F.dropout(x, self.p, training=True) diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py index 1c12649..f8b01bb 100644 --- a/synthesizer/models/tacotron.py +++ b/synthesizer/models/tacotron.py @@ -9,52 +9,80 @@ from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp from synthesizer.hparams import hparams class Encoder(nn.Module): - def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout): + def __init__(self, num_chars, embed_dims=512, encoder_dims=256, K=5, num_highways=4, dropout=0.5): + """ Encoder for SV2TTS + + Args: + num_chars (int): length of symbols + embed_dims (int, optional): embedding dim for input texts. Defaults to 512. + encoder_dims (int, optional): output dim for encoder. Defaults to 256. + K (int, optional): _description_. Defaults to 5. + num_highways (int, optional): _description_. Defaults to 4. + dropout (float, optional): _description_. Defaults to 0.5. + """ super().__init__() - prenet_dims = (encoder_dims, encoder_dims) - cbhg_channels = encoder_dims self.embedding = nn.Embedding(num_chars, embed_dims) - self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], + self.pre_net = PreNet(embed_dims, fc1_dims=encoder_dims, fc2_dims=encoder_dims, dropout=dropout) - self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels, - proj_channels=[cbhg_channels, cbhg_channels], + self.cbhg = CBHG(K=K, in_channels=encoder_dims, channels=encoder_dims, + proj_channels=[encoder_dims, encoder_dims], num_highways=num_highways) def forward(self, x): - x = self.embedding(x) - x = self.pre_net(x) - x.transpose_(1, 2) - return self.cbhg(x) + """forward pass for encoder + + Args: + x (2D tensor with size `[batch_size, text_num_chars]`): input texts list + + Returns: + 3D tensor with size `[batch_size, text_num_chars, encoder_dims]` + + """ + x = self.embedding(x) # return: [batch_size, text_num_chars, tts_embed_dims] + x = self.pre_net(x) # return: [batch_size, text_num_chars, encoder_dims] + x.transpose_(1, 2) # return: [batch_size, encoder_dims, text_num_chars] + return self.cbhg(x) # return: [batch_size, text_num_chars, encoder_dims] class Decoder(nn.Module): # Class variable because its value doesn't change between classes # yet ought to be scoped by class because its a property of a Decoder max_r = 20 - def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims, + def __init__(self, n_mels, input_dims, decoder_dims, lstm_dims, dropout, speaker_embedding_size): super().__init__() self.register_buffer("r", torch.tensor(1, dtype=torch.int)) self.n_mels = n_mels - prenet_dims = (decoder_dims * 2, decoder_dims * 2) - self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], + self.prenet = PreNet(n_mels, fc1_dims=decoder_dims * 2, fc2_dims=decoder_dims * 2, dropout=dropout) self.attn_net = LSA(decoder_dims) if hparams.use_gst: speaker_embedding_size += gst_hp.E - self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims) - self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims) + self.attn_rnn = nn.GRUCell(input_dims + decoder_dims * 2, decoder_dims) + self.rnn_input = nn.Linear(input_dims + decoder_dims, lstm_dims) self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims) self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims) self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) - self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1) + self.stop_proj = nn.Linear(input_dims + lstm_dims, 1) def zoneout(self, prev, current, device, p=0.1): mask = torch.zeros(prev.size(),device=device).bernoulli_(p) return prev * mask + current * (1 - mask) def forward(self, encoder_seq, encoder_seq_proj, prenet_in, - hidden_states, cell_states, context_vec, t, chars): + hidden_states, cell_states, context_vec, times, chars): + """_summary_ + Args: + encoder_seq (3D tensor `[batch_size, text_num_chars, project_dim(default to 512)]`): _description_ + encoder_seq_proj (3D tensor `[batch_size, text_num_chars, decoder_dims(default to 128)]`): _description_ + prenet_in (2D tensor `[batch_size, n_mels]`): _description_ + hidden_states (_type_): _description_ + cell_states (_type_): _description_ + context_vec (2D tensor `[batch_size, project_dim(default to 512)]`): _description_ + times (int): the number of times runned + chars (2D tensor with size `[batch_size, text_num_chars]`): original texts list input + + """ # Need this for reshaping mels batch_size = encoder_seq.size(0) device = encoder_seq.device @@ -63,25 +91,25 @@ class Decoder(nn.Module): rnn1_cell, rnn2_cell = cell_states # PreNet for the Attention RNN - prenet_out = self.prenet(prenet_in) + prenet_out = self.prenet(prenet_in) # return: `[batch_size, decoder_dims * 2(256)]` # Compute the Attention RNN hidden state - attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) - attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) + attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) # `[batch_size, project_dim + decoder_dims * 2 (768)]` + attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) # `[batch_size, decoder_dims (128)]` # Compute the attention scores - scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars) + scores = self.attn_net(encoder_seq_proj, attn_hidden, times, chars) # Dot product to create the context vector context_vec = scores @ encoder_seq context_vec = context_vec.squeeze(1) # Concat Attention RNN output w. Context Vector & project - x = torch.cat([context_vec, attn_hidden], dim=1) - x = self.rnn_input(x) + x = torch.cat([context_vec, attn_hidden], dim=1) # `[batch_size, project_dim + decoder_dims (630)]` + x = self.rnn_input(x) # `[batch_size, lstm_dims(1024)]` # Compute first Residual RNN, training with fixed zoneout rate 0.1 - rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) + rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) # `[batch_size, lstm_dims(1024)]` if self.training: rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device) else: @@ -89,7 +117,7 @@ class Decoder(nn.Module): x = x + rnn1_hidden # Compute second Residual RNN - rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) + rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) # `[batch_size, lstm_dims(1024)]` if self.training: rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device) else: @@ -97,8 +125,8 @@ class Decoder(nn.Module): x = x + rnn2_hidden # Project Mels - mels = self.mel_proj(x) - mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] + mels = self.mel_proj(x) # `[batch_size, 1600]` + mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] # `[batch_size, n_mels, r]` hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) cell_states = (rnn1_cell, rnn2_cell) @@ -119,15 +147,15 @@ class Tacotron(Base): self.encoder_dims = encoder_dims self.decoder_dims = decoder_dims self.speaker_embedding_size = speaker_embedding_size - self.encoder = Encoder(embed_dims, num_chars, encoder_dims, + self.encoder = Encoder(num_chars, embed_dims, encoder_dims, encoder_K, num_highways, dropout) - project_dims = encoder_dims + speaker_embedding_size + self.project_dims = encoder_dims + speaker_embedding_size if hparams.use_gst: - project_dims += gst_hp.E - self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False) + self.project_dims += gst_hp.E + self.encoder_proj = nn.Linear(self.project_dims, decoder_dims, bias=False) if hparams.use_gst: self.gst = GlobalStyleToken(speaker_embedding_size) - self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims, + self.decoder = Decoder(n_mels, self.project_dims, decoder_dims, lstm_dims, dropout, speaker_embedding_size) self.postnet = CBHG(postnet_K, n_mels, postnet_dims, [postnet_dims, fft_bins], num_highways) @@ -142,36 +170,43 @@ class Tacotron(Base): @staticmethod def _add_speaker_embedding(x, speaker_embedding): - # SV2TTS - # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims) - # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size) - # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size)) - # This concats the speaker embedding for each char in the encoder output + """Add speaker embedding + This concats the speaker embedding for each char in the encoder output + Args: + x (3D tensor with size `[batch_size, text_num_chars, encoder_dims]`): the encoder output + speaker_embedding (2D tensor `[batch_size, speaker_embedding_size]`): the speaker embedding + Returns: + 3D tensor with size `[batch_size, text_num_chars, encoder_dims+speaker_embedding_size]` + """ # Save the dimensions as human-readable names batch_size = x.size()[0] - num_chars = x.size()[1] - - if speaker_embedding.dim() == 1: - idx = 0 - else: - idx = 1 + text_num_chars = x.size()[1] # Start by making a copy of each speaker embedding to match the input text length - # The output of this has size (batch_size, num_chars * speaker_embedding_size) - speaker_embedding_size = speaker_embedding.size()[idx] - e = speaker_embedding.repeat_interleave(num_chars, dim=idx) + # The output of this has size (batch_size, text_num_chars * speaker_embedding_size) + speaker_embedding_size = speaker_embedding.size()[1] + e = speaker_embedding.repeat_interleave(text_num_chars, dim=1) # Reshape it and transpose - e = e.reshape(batch_size, speaker_embedding_size, num_chars) + e = e.reshape(batch_size, speaker_embedding_size, text_num_chars) e = e.transpose(1, 2) # Concatenate the tiled speaker embedding with the encoder output x = torch.cat((x, e), 2) return x - def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5): - + def forward(self, texts, mels, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5): + """Forward pass for Tacotron + + Args: + texts (`[batch_size, text_num_chars]`): input texts list + mels (`[batch_size, varied_mel_lengths, steps]`): mels for comparison (training only) + speaker_embedding (`[batch_size, speaker_embedding_size(default to 256)]`): referring embedding. + steps (int, optional): . Defaults to 2000. + style_idx (int, optional): GST style selected. Defaults to 0. + min_stop_token (int, optional): decoder min_stop_token. Defaults to 5. + """ device = texts.device # use same device as parameters if self.training: @@ -194,18 +229,11 @@ class Tacotron(Base): # Frame for start of decoder loop go_frame = torch.zeros(batch_size, self.n_mels, device=device) - # Need an initial context vector - size = self.encoder_dims + self.speaker_embedding_size - if hparams.use_gst: - size += gst_hp.E - context_vec = torch.zeros(batch_size, size, device=device) - # SV2TTS: Run the encoder with the speaker embedding # The projection avoids unnecessary matmuls in the decoder loop encoder_seq = self.encoder(texts) - if speaker_embedding is not None: - encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding) + encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding) if hparams.use_gst and self.gst is not None: if self.training: @@ -222,12 +250,16 @@ class Tacotron(Base): else: speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device) style_embed = self.gst(speaker_embedding_style, speaker_embedding) - encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed) - encoder_seq_proj = self.encoder_proj(encoder_seq) + encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed) # return: [batch_size, text_num_chars, project_dims] + + encoder_seq_proj = self.encoder_proj(encoder_seq) # return: [batch_size, text_num_chars, decoder_dims] # Need a couple of lists for outputs mel_outputs, attn_scores, stop_outputs = [], [], [] + # Need an initial context vector + context_vec = torch.zeros(batch_size, self.project_dims, device=device) + # Run the decoder loop for t in range(0, steps, self.r): if self.training: @@ -260,7 +292,7 @@ class Tacotron(Base): return mel_outputs, linear, attn_scores, stop_outputs - def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5): + def generate(self, x, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5): self.eval() mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token) return mel_outputs, linear, attn_scores