diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 0904c8d..0ed665f 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from synthesizer.models.global_style_token import GlobalStyleToken
+from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
 
 
 class HighwayNetwork(nn.Module):
@@ -254,12 +255,12 @@ class Decoder(nn.Module):
         self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
         self.attn_net = LSA(decoder_dims)
-        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
-        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
+        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size + gst_hp.E, decoder_dims)
+        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size + gst_hp.E, lstm_dims)
         self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
-        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
+        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims + gst_hp.E, 1)
 
     def zoneout(self, prev, current, p=0.1):
         device = next(self.parameters()).device  # Use same device as parameters
@@ -336,7 +337,7 @@ class Tacotron(nn.Module):
         self.speaker_embedding_size = speaker_embedding_size
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims, encoder_K, num_highways,
                                dropout)
-        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
+        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
         self.gst = GlobalStyleToken()
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
@@ -379,7 +380,7 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)
 
         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
 
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
@@ -388,7 +389,7 @@ class Tacotron(nn.Module):
         if self.gst is not None:
             style_embed = self.gst(speaker_embedding)
             style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = encoder_seq + style_embed
+            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
@@ -440,23 +441,24 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)
 
         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
+        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
 
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(x, speaker_embedding)
 
         # put after encoder
-        if self.gst is not None and style_idx >= 0 and style_idx < 10:
-            gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
-            gst_embed = np.tile(gst_embed, (1, 8))
-            scale = np.zeros(512)
-            scale[:] = 0.3
-            speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
-            speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
+        if self.gst is not None:
+            if style_idx >= 0 and style_idx < 10:
+                gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
+                gst_embed = np.tile(gst_embed, (1, 8))
+                scale = np.zeros(512)
+                scale[:] = 0.3
+                speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
+                speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
             style_embed = self.gst(speaker_embedding)
             style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = encoder_seq + style_embed
+            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
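Net effect of the patch: the GST style embedding is concatenated onto the encoder sequence along the feature axis instead of being added to it, so every layer that consumes the encoder sequence (`encoder_proj`, `attn_rnn`, `rnn_input`, `stop_proj`, and the initial `context_vec`) grows its input width by `gst_hp.E`. Below is a minimal sketch of the shape arithmetic, not part of the patch; the dimension values are assumptions based on the repo's usual defaults (`encoder_dims = speaker_embedding_size = 256`, `gst_hp.E = 512`, GST output shaped `[N, 1, E]`):

```python
import torch

# Assumed illustrative dimensions (not taken from the diff itself):
N, T = 2, 50            # batch size, encoded text length
enc_feat = 256 + 256    # encoder_dims + speaker_embedding_size
E = 512                 # gst_hp.E, the style-embedding width

encoder_seq = torch.randn(N, T, enc_feat)   # encoder output + speaker embedding
style_embed = torch.randn(N, 1, E)          # one GST style vector per utterance

# Before the patch: additive conditioning. expand_as only type-checks
# because E happens to equal enc_feat (512 == 512), and no layer widths change.
added = encoder_seq + style_embed.expand_as(encoder_seq)
assert added.shape == (N, T, enc_feat)

# After the patch: concatenation along dim 2. The style vector is broadcast
# over time steps and the feature width grows by E, which is why the diff
# adds "+ gst_hp.E" to encoder_proj, attn_rnn, rnn_input, stop_proj,
# and context_vec.
catted = torch.cat((encoder_seq, style_embed.expand(N, T, E)), 2)
assert catted.shape == (N, T, enc_feat + E)
```

On the new `style_idx` branch in `generate`: if the repo's GST hyperparameters follow the common setup (`token_num = 10`, `num_heads = 8`), the token table `self.gst.stl.embed` is `[10, E // 8] = [10, 64]`, so `np.tile(gst_embed, (1, 8))` stretches each token to the full 512-dim width and the `0.3` scale damps it into a pseudo speaker embedding; that would also explain the hard-coded `style_idx < 10` bound and `np.zeros(512)`. Those constants are read off the diff, not verified against the hyperparameter file.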