Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)
Use speaker embedding anyway even with default style
parent 80aaf32164
commit 3674d8b5c6
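In brief, the change threads the speaker embedding through the global-style-token (GST) path: GlobalStyleToken.forward and STL.__init__ gain an optional speaker_embedding argument, and the Tacotron model now calls the GST layer with the speaker embedding even when only the default style is used. Below is a minimal, self-contained sketch of the concatenation step the first hunk introduces; the dimensions (hp.E = 256, so a 128-dim reference encoding, and a 256-dim speaker embedding) are illustrative assumptions, not the repository's actual hyperparameters, and ReferenceEncoder/STL themselves are not reproduced.

# Sketch of the concatenation added in GlobalStyleToken.forward (see hunk below).
# Assumption: hp.E = 256, so the reference encoding is hp.E // 2 = 128-dim, and the
# speaker embedding is 256-dim; both values are illustrative, not the repo's.
import torch

batch_size = 2
ref_dim = 128                 # ReferenceEncoder output size (hp.E // 2)
speaker_embedding_dim = 256   # speaker-encoder output size

enc_out = torch.randn(batch_size, ref_dim)                  # style reference encoding
speaker_embedding = torch.randn(batch_size, speaker_embedding_dim)

# With a speaker embedding attached, the STL attention query grows by
# speaker_embedding_dim, which is why STL.__init__ now adds it to d_q.
query = torch.cat([enc_out, speaker_embedding], dim=-1)
print(query.shape)  # torch.Size([2, 384])

The approach follows the mozilla/TTS gst_layers reference linked in the code comment inside the first hunk.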
@@ -6,15 +6,22 @@ from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
 
 
 class GlobalStyleToken(nn.Module):
-
-    def __init__(self):
+    """
+    inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
+    speaker_embedding: speaker mel spectrograms [batch_size, num_spec_frames, num_mel]
+    outputs: [batch_size, embedding_dim]
+    """
+    def __init__(self, speaker_embedding_dim=None):
 
         super().__init__()
         self.encoder = ReferenceEncoder()
-        self.stl = STL()
+        self.stl = STL(speaker_embedding_dim)
 
-    def forward(self, inputs):
+    def forward(self, inputs, speaker_embedding=None):
         enc_out = self.encoder(inputs)
+        # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
+        if speaker_embedding is not None:
+            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
         style_embed = self.stl(enc_out)
 
         return style_embed
@@ -73,13 +80,15 @@ class STL(nn.Module):
     inputs --- [N, E//2]
     '''
 
-    def __init__(self):
+    def __init__(self, speaker_embedding_dim=None):
 
         super().__init__()
         self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
         d_q = hp.E // 2
         d_k = hp.E // hp.num_heads
         # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
+        if speaker_embedding_dim:
+            d_q += speaker_embedding_dim
         self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
 
         init.normal_(self.embed, mean=0, std=0.5)
@@ -338,7 +338,7 @@ class Tacotron(nn.Module):
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                                encoder_K, num_highways, dropout)
         self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
-        self.gst = GlobalStyleToken()
+        self.gst = GlobalStyleToken(speaker_embedding_size)
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
@@ -359,6 +359,34 @@ class Tacotron(nn.Module):
     def r(self, value):
         self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
 
+    def compute_gst(self, inputs, style_input, speaker_embedding=None):
+        """ Compute global style token """
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            if speaker_embedding is not None:
+                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
+
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+        else:
+            gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
+        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
+        return inputs
+
+    @staticmethod
+    def _concat_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
+        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
+        return outputs
+
     def forward(self, texts, mels, speaker_embedding):
         device = next(self.parameters()).device # use same device as parameters
 
@@ -387,9 +415,10 @@ class Tacotron(nn.Module):
         encoder_seq = self.encoder(texts, speaker_embedding)
         # put after encoder
         if self.gst is not None:
-            style_embed = self.gst(speaker_embedding)
-            style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
+            # style_embed = style_embed.expand_as(encoder_seq)
+            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
@@ -454,11 +483,14 @@ class Tacotron(nn.Module):
             gst_embed = np.tile(gst_embed, (1, 8))
             scale = np.zeros(512)
             scale[:] = 0.3
-            speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
-            speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
-            style_embed = self.gst(speaker_embedding)
-            style_embed = style_embed.expand_as(encoder_seq)
-            encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            speaker_embedding_style = (gst_embed[style_idx] * scale).astype(np.float32)
+            speaker_embedding_style = torch.from_numpy(np.tile(speaker_embedding_style, (x.shape[0], 1))).to(device)
+        else:
+            speaker_embedding_style = torch.zeros(2, 1, self.speaker_embedding_size).to(device)
+        style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+        encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+        # style_embed = style_embed.expand_as(encoder_seq)
+        # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
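The last hunk is what the commit title refers to: at inference time the GST layer is now queried unconditionally, with a zero "default style" reference but the real speaker embedding. A rough sketch of that control flow follows; the guard on style_idx is not visible in the hunk and is paraphrased here, speaker_embedding_size = 256 is an illustrative value, and the zero-tensor shape is simplified relative to the repository's (2, 1, self.speaker_embedding_size).

# Rough, stand-alone sketch of the inference-side flow in the last hunk.
# Assumptions: 256-dim speaker embedding, simplified default-style tensor shape,
# and a paraphrased style_idx guard; none of these are taken from the repo verbatim.
import torch

speaker_embedding_size = 256
speaker_embedding = torch.randn(1, speaker_embedding_size)  # always available

style_idx = -1  # default style: no explicit style token requested
if 0 <= style_idx < 10:  # the actual guard sits above the hunk and is not shown there
    # a stored style vector would be loaded and scaled here (see the hunk above)
    speaker_embedding_style = torch.randn(1, 1, speaker_embedding_size)
else:
    speaker_embedding_style = torch.zeros(1, 1, speaker_embedding_size)

# The GST call now happens in both branches, so the style tokens stay
# conditioned on the speaker even with the default (zero) style reference:
# style_embed = self.gst(speaker_embedding_style, speaker_embedding)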