mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Differentiate GST token
This commit is contained in:
parent
aff1b5313b
commit
26fe4a047d
|
@ -471,12 +471,10 @@ class Tacotron(nn.Module):
|
||||||
# put after encoder
|
# put after encoder
|
||||||
if hparams.use_gst and self.gst is not None:
|
if hparams.use_gst and self.gst is not None:
|
||||||
if style_idx >= 0 and style_idx < 10:
|
if style_idx >= 0 and style_idx < 10:
|
||||||
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
|
||||||
gst_embed = np.tile(gst_embed, (1, 8))
|
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||||
scale = np.zeros(512)
|
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||||
scale[:] = 0.3
|
style_embed = self.gst.stl.attention(query, key)
|
||||||
speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
|
|
||||||
speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
|
|
||||||
else:
|
else:
|
||||||
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user