mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Differentiate GST token
This commit is contained in:
parent
aff1b5313b
commit
26fe4a047d
|
@ -471,15 +471,13 @@ class Tacotron(nn.Module):
|
|||
# put after encoder
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
if style_idx >= 0 and style_idx < 10:
|
||||
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
||||
gst_embed = np.tile(gst_embed, (1, 8))
|
||||
scale = np.zeros(512)
|
||||
scale[:] = 0.3
|
||||
speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
|
||||
speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
|
||||
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
|
||||
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||
style_embed = self.gst.stl.attention(query, key)
|
||||
else:
|
||||
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
# style_embed = style_embed.expand_as(encoder_seq)
|
||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user