Differentiate GST token

babysor00 2021-11-18 22:55:13 +08:00
parent aff1b5313b
commit 26fe4a047d


@@ -471,15 +471,13 @@ class Tacotron(nn.Module):
         # put after encoder
         if hparams.use_gst and self.gst is not None:
             if style_idx >= 0 and style_idx < 10:
-                gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
-                gst_embed = np.tile(gst_embed, (1, 8))
-                scale = np.zeros(512)
-                scale[:] = 0.3
-                speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
-                speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
             else:
                 speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-            style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
             encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
             # style_embed = style_embed.expand_as(encoder_seq)
             # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
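
In the new branch, the style embedding for a chosen token is produced by attending over just that one token: the token bank is squashed with tanh, the selected row becomes the only key/value, and the query is all zeros, so the attention output is effectively a fixed projection of that single token. Below is a minimal, self-contained sketch of the same idea, using torch.nn.MultiheadAttention as a stand-in for the repo's self.gst.stl.attention; the token count, dimensions, and variable names are illustrative assumptions, not the project's actual hparams.

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the real values come from the project's hparams.
    token_num, token_dim, num_units = 10, 32, 256

    # Stand-in for self.gst.stl.embed: a bank of learned style-token embeddings.
    token_bank = nn.Parameter(torch.randn(token_num, token_dim))

    # Stand-in for self.gst.stl.attention (the repo defines its own multi-head
    # attention module; torch.nn.MultiheadAttention is used here for brevity).
    attention = nn.MultiheadAttention(embed_dim=num_units, num_heads=8,
                                      kdim=token_dim, vdim=token_dim,
                                      batch_first=True)

    style_idx = 3                         # which of the 10 tokens to use
    query = torch.zeros(1, 1, num_units)  # all-zero query, as in the commit
    key = torch.tanh(token_bank)[style_idx].unsqueeze(0).expand(1, -1, -1)  # [1, 1, token_dim]

    # With a single key/value, the softmax attention weight is 1, so the output
    # is a fixed linear projection of the selected token: a per-token style embedding.
    style_embed, _ = attention(query, key, key)
    print(style_embed.shape)              # torch.Size([1, 1, 256])

Because there is only one key, changing style_idx simply swaps in a different token's projection, which is what lets each GST token yield a distinct style embedding at inference time; when style_idx is outside 0..9, the code falls back to the original GST path driven by the speaker embedding.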