mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Concat GST output instead of adding directly with original output
This commit is contained in:
parent
724194a4de
commit
7c58fe01d1
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
||||
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
|
@ -254,12 +255,12 @@ class Decoder(nn.Module):
|
|||
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
||||
dropout=dropout)
|
||||
self.attn_net = LSA(decoder_dims)
|
||||
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
||||
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size + gst_hp.E, decoder_dims)
|
||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size + gst_hp.E, lstm_dims)
|
||||
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
||||
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
||||
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims + gst_hp.E, 1)
|
||||
|
||||
def zoneout(self, prev, current, p=0.1):
|
||||
device = next(self.parameters()).device # Use same device as parameters
|
||||
|
@ -336,7 +337,7 @@ class Tacotron(nn.Module):
|
|||
self.speaker_embedding_size = speaker_embedding_size
|
||||
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
||||
encoder_K, num_highways, dropout)
|
||||
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
|
||||
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
|
||||
self.gst = GlobalStyleToken()
|
||||
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
||||
dropout, speaker_embedding_size)
|
||||
|
@ -379,7 +380,7 @@ class Tacotron(nn.Module):
|
|||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
|
@ -388,7 +389,7 @@ class Tacotron(nn.Module):
|
|||
if self.gst is not None:
|
||||
style_embed = self.gst(speaker_embedding)
|
||||
style_embed = style_embed.expand_as(encoder_seq)
|
||||
encoder_seq = encoder_seq + style_embed
|
||||
encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
|
@ -440,14 +441,15 @@ class Tacotron(nn.Module):
|
|||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# Need an initial context vector
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
|
||||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(x, speaker_embedding)
|
||||
|
||||
# put after encoder
|
||||
if self.gst is not None and style_idx >= 0 and style_idx < 10:
|
||||
if self.gst is not None:
|
||||
if style_idx >= 0 and style_idx < 10:
|
||||
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
||||
gst_embed = np.tile(gst_embed, (1, 8))
|
||||
scale = np.zeros(512)
|
||||
|
@ -456,7 +458,7 @@ class Tacotron(nn.Module):
|
|||
speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
|
||||
style_embed = self.gst(speaker_embedding)
|
||||
style_embed = style_embed.expand_as(encoder_seq)
|
||||
encoder_seq = encoder_seq + style_embed
|
||||
encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
|
|
Loading…
Reference in New Issue
Block a user