mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Fix #205
This commit is contained in:
parent
6c8f3f4515
commit
5c0e53a29a
|
@ -359,27 +359,6 @@ class Tacotron(nn.Module):
|
||||||
def r(self, value):
|
def r(self, value):
|
||||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||||
|
|
||||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
|
||||||
""" Compute global style token """
|
|
||||||
device = inputs.device
|
|
||||||
if isinstance(style_input, dict):
|
|
||||||
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
|
|
||||||
if speaker_embedding is not None:
|
|
||||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
|
||||||
|
|
||||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
|
||||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
|
||||||
for k_token, v_amplifier in style_input.items():
|
|
||||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
|
||||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
|
||||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
|
||||||
elif style_input is None:
|
|
||||||
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
|
|
||||||
else:
|
|
||||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
|
||||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||||
speaker_embeddings_ = speaker_embeddings.expand(
|
speaker_embeddings_ = speaker_embeddings.expand(
|
||||||
|
@ -486,7 +465,7 @@ class Tacotron(nn.Module):
|
||||||
speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
|
speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
|
||||||
speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
|
speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
|
||||||
else:
|
else:
|
||||||
speaker_embedding_style = torch.zeros(2, 1, self.speaker_embedding_size).to(device)
|
speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||||
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||||
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||||
# style_embed = style_embed.expand_as(encoder_seq)
|
# style_embed = style_embed.expand_as(encoder_seq)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user