From a37b26a89cea70719717ac45ef1255093c7a4174 Mon Sep 17 00:00:00 2001
From: babysor00
Date: Wed, 10 Nov 2021 23:23:13 +0800
Subject: [PATCH] Compatibility Enhancement of Pretrained Models and code base #209
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 synthesizer/hparams.py                   |  2 ++
 synthesizer/models/global_style_token.py |  5 ++--
 synthesizer/models/tacotron.py           | 30 +++++++++++++++++-------
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py
index 629e144..e2d6a0f 100644
--- a/synthesizer/hparams.py
+++ b/synthesizer/hparams.py
@@ -91,4 +91,6 @@ hparams = HParams(
         speaker_embedding_size = 256,     # Dimension for the speaker embedding
         silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
         utterance_min_duration = 1.6,     # Duration in seconds below which utterances are discarded
+        use_gst = True,                   # Whether to use global style token
+        use_ser_for_gst = False,          # Whether to use speaker embedding referenced for global style token
         )

diff --git a/synthesizer/models/global_style_token.py b/synthesizer/models/global_style_token.py
index cef3009..229b9ef 100644
--- a/synthesizer/models/global_style_token.py
+++ b/synthesizer/models/global_style_token.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.init as init
 import torch.nn.functional as tFunctional
 from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
+from synthesizer.hparams import hparams


 class GlobalStyleToken(nn.Module):
@@ -20,7 +21,7 @@ class GlobalStyleToken(nn.Module):
     def forward(self, inputs, speaker_embedding=None):
         enc_out = self.encoder(inputs)
         # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
-        if speaker_embedding is not None:
+        if hparams.use_ser_for_gst and speaker_embedding is not None:
             enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)

         style_embed = self.stl(enc_out)
@@ -87,7 +88,7 @@ class STL(nn.Module):
         d_q = hp.E // 2
         d_k = hp.E // hp.num_heads
         # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
-        if speaker_embedding_dim:
+        if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
             d_q += speaker_embedding_dim
         self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)

diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index e83ab60..eaa7818 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from synthesizer.models.global_style_token import GlobalStyleToken
 from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
+from synthesizer.hparams import hparams


 class HighwayNetwork(nn.Module):
@@ -255,12 +256,14 @@ class Decoder(nn.Module):
         self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
         self.attn_net = LSA(decoder_dims)
-        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size + gst_hp.E, decoder_dims)
-        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size + gst_hp.E, lstm_dims)
+        if hparams.use_gst:
+            speaker_embedding_size += gst_hp.E
+        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
+        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
         self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
-        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims + gst_hp.E, 1)
+        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

     def zoneout(self, prev, current, p=0.1):
         device = next(self.parameters()).device  # Use same device as parameters
@@ -337,8 +340,11 @@ class Tacotron(nn.Module):
         self.speaker_embedding_size = speaker_embedding_size
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims, encoder_K, num_highways,
                                dropout)
-        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
-        self.gst = GlobalStyleToken(speaker_embedding_size)
+        project_dims = encoder_dims + speaker_embedding_size
+        if hparams.use_gst:
+            project_dims += gst_hp.E
+            self.gst = GlobalStyleToken(speaker_embedding_size)
+        self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims, dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
@@ -387,13 +393,16 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)

         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
+        size = self.encoder_dims + self.speaker_embedding_size
+        if hparams.use_gst:
+            size += gst_hp.E
+        context_vec = torch.zeros(batch_size, size, device=device)

         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts, speaker_embedding)
         # put after encoder
-        if self.gst is not None:
+        if hparams.use_gst and self.gst is not None:
             style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
             # style_embed = style_embed.expand_as(encoder_seq)
             # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
@@ -449,14 +458,17 @@ class Tacotron(nn.Module):
         go_frame = torch.zeros(batch_size, self.n_mels, device=device)

         # Need an initial context vector
-        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
+        size = self.encoder_dims + self.speaker_embedding_size
+        if hparams.use_gst:
+            size += gst_hp.E
+        context_vec = torch.zeros(batch_size, size, device=device)

         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(x, speaker_embedding)

         # put after encoder
-        if self.gst is not None:
+        if hparams.use_gst and self.gst is not None:
             if style_idx >= 0 and style_idx < 10:
                 gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
                 gst_embed = np.tile(gst_embed, (1, 8))
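As a quick illustration of the dimension bookkeeping this patch makes conditional, the sketch below (not part of the patch) mirrors how the width of the initial context vector and of the encoder_proj input changes with hparams.use_gst. The helper name context_vec_size is hypothetical; speaker_embedding_size = 256 comes from hparams.py above, while the encoder_dims and gst_hp.E values are assumed for illustration only.

# Sketch only: reproduces the conditional size computation introduced in
# Tacotron.__init__ and the forward passes; numeric defaults are assumptions.
def context_vec_size(encoder_dims, speaker_embedding_size, use_gst, gst_E):
    """Width of the initial context vector / encoder_proj input."""
    size = encoder_dims + speaker_embedding_size
    if use_gst:
        size += gst_E  # GST style embedding is concatenated only when enabled
    return size

print(context_vec_size(256, 256, use_gst=True,  gst_E=256))   # 768: GST-enabled layout
print(context_vec_size(256, 256, use_gst=False, gst_E=256))   # 512: pre-GST layout

With use_gst = False the layer shapes fall back to the pre-GST sizes, which is what allows older pretrained synthesizer checkpoints to load without shape mismatches.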