From 26fe4a047d7102d8ea528430a32602885b0864d8 Mon Sep 17 00:00:00 2001
From: babysor00
Date: Thu, 18 Nov 2021 22:55:13 +0800
Subject: [PATCH] Differentiate GST token

---
 synthesizer/models/tacotron.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 431fcb2..5c3fce6 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -471,15 +471,13 @@ class Tacotron(nn.Module):
         # put after encoder
         if hparams.use_gst and self.gst is not None:
             if style_idx >= 0 and style_idx < 10:
-                gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
-                gst_embed = np.tile(gst_embed, (1, 8))
-                scale = np.zeros(512)
-                scale[:] = 0.3
-                speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
-                speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
             else:
                 speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-            style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
         encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         # style_embed = style_embed.expand_as(encoder_seq)
         # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
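
For context, a minimal, self-contained sketch of what the new branch computes: the selected GST token is passed through tanh and used as the sole key/value of the style-token attention, queried with an all-zero vector, so each style_idx yields a distinct, deterministic style embedding. The stand-in module below is an assumption for illustration only (single-head scaled dot-product attention with made-up sizes), not the repository's self.gst.stl implementation:

import torch
import torch.nn as nn

num_tokens, token_dim, num_units = 10, 256, 256     # assumed sizes, not from the repo

# Stand-in for the learned style-token table (analogous to self.gst.stl.embed).
token_table = nn.Parameter(torch.randn(num_tokens, token_dim))

# Stand-in for self.gst.stl.attention: single-head scaled dot-product attention.
q_proj = nn.Linear(num_units, num_units, bias=False)
k_proj = nn.Linear(token_dim, num_units, bias=False)
v_proj = nn.Linear(token_dim, num_units, bias=False)

def attend(query, key):
    # query: [B, 1, num_units], key: [B, T, token_dim] -> [B, 1, num_units]
    q, k, v = q_proj(query), k_proj(key), v_proj(key)
    weights = torch.softmax(q @ k.transpose(1, 2) / num_units ** 0.5, dim=-1)
    return weights @ v

style_idx = 3                                        # token chosen by the caller
query = torch.zeros(1, 1, num_units)                 # all-zero query, as in the patch
gst_embed = torch.tanh(token_table)                  # bounded token values
key = gst_embed[style_idx].view(1, 1, -1)            # a single token as key/value
style_embed = attend(query, key)                     # shape [1, 1, num_units]

Since there is only one key, the softmax weight is 1 and the result reduces to a learned projection of tanh(token_table[style_idx]); the query only matters when several tokens are attended over, as in the unchanged else branch that still blends tokens via self.gst(...).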