mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
模型兼容问题加强 Compatibility Enhance of Pretrained Models and code base #209
This commit is contained in:
parent
902e1eb537
commit
a37b26a89c
|
@ -91,4 +91,6 @@ hparams = HParams(
|
||||||
speaker_embedding_size = 256, # Dimension for the speaker embedding
|
speaker_embedding_size = 256, # Dimension for the speaker embedding
|
||||||
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
|
silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
|
||||||
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
|
utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
|
||||||
|
use_gst = True, # Whether to use global style token
|
||||||
|
use_ser_for_gst = False, # Whether to use speaker embedding referenced for global style token
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import torch.nn as nn
|
||||||
import torch.nn.init as init
|
import torch.nn.init as init
|
||||||
import torch.nn.functional as tFunctional
|
import torch.nn.functional as tFunctional
|
||||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
|
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
|
||||||
|
from synthesizer.hparams import hparams
|
||||||
|
|
||||||
|
|
||||||
class GlobalStyleToken(nn.Module):
|
class GlobalStyleToken(nn.Module):
|
||||||
|
@ -20,7 +21,7 @@ class GlobalStyleToken(nn.Module):
|
||||||
def forward(self, inputs, speaker_embedding=None):
|
def forward(self, inputs, speaker_embedding=None):
|
||||||
enc_out = self.encoder(inputs)
|
enc_out = self.encoder(inputs)
|
||||||
# concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
|
# concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
|
||||||
if speaker_embedding is not None:
|
if hparams.use_ser_for_gst and speaker_embedding is not None:
|
||||||
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
||||||
style_embed = self.stl(enc_out)
|
style_embed = self.stl(enc_out)
|
||||||
|
|
||||||
|
@ -87,7 +88,7 @@ class STL(nn.Module):
|
||||||
d_q = hp.E // 2
|
d_q = hp.E // 2
|
||||||
d_k = hp.E // hp.num_heads
|
d_k = hp.E // hp.num_heads
|
||||||
# self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
|
# self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
|
||||||
if speaker_embedding_dim:
|
if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
|
||||||
d_q += speaker_embedding_dim
|
d_q += speaker_embedding_dim
|
||||||
self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
|
self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
||||||
|
from synthesizer.hparams import hparams
|
||||||
|
|
||||||
|
|
||||||
class HighwayNetwork(nn.Module):
|
class HighwayNetwork(nn.Module):
|
||||||
|
@ -255,12 +256,14 @@ class Decoder(nn.Module):
|
||||||
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
||||||
dropout=dropout)
|
dropout=dropout)
|
||||||
self.attn_net = LSA(decoder_dims)
|
self.attn_net = LSA(decoder_dims)
|
||||||
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size + gst_hp.E, decoder_dims)
|
if hparams.use_gst:
|
||||||
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size + gst_hp.E, lstm_dims)
|
speaker_embedding_size += gst_hp.E
|
||||||
|
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
||||||
|
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
||||||
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||||
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
||||||
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
||||||
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims + gst_hp.E, 1)
|
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
||||||
|
|
||||||
def zoneout(self, prev, current, p=0.1):
|
def zoneout(self, prev, current, p=0.1):
|
||||||
device = next(self.parameters()).device # Use same device as parameters
|
device = next(self.parameters()).device # Use same device as parameters
|
||||||
|
@ -337,8 +340,11 @@ class Tacotron(nn.Module):
|
||||||
self.speaker_embedding_size = speaker_embedding_size
|
self.speaker_embedding_size = speaker_embedding_size
|
||||||
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
||||||
encoder_K, num_highways, dropout)
|
encoder_K, num_highways, dropout)
|
||||||
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size + gst_hp.E, decoder_dims, bias=False)
|
project_dims = encoder_dims + speaker_embedding_size
|
||||||
self.gst = GlobalStyleToken(speaker_embedding_size)
|
if hparams.use_gst:
|
||||||
|
project_dims += gst_hp.E
|
||||||
|
self.gst = GlobalStyleToken(speaker_embedding_size)
|
||||||
|
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
|
||||||
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
||||||
dropout, speaker_embedding_size)
|
dropout, speaker_embedding_size)
|
||||||
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
||||||
|
@ -387,13 +393,16 @@ class Tacotron(nn.Module):
|
||||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||||
|
|
||||||
# Need an initial context vector
|
# Need an initial context vector
|
||||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
|
size = self.encoder_dims + self.speaker_embedding_size
|
||||||
|
if hparams.use_gst:
|
||||||
|
size += gst_hp.E
|
||||||
|
context_vec = torch.zeros(batch_size, size, device=device)
|
||||||
|
|
||||||
# SV2TTS: Run the encoder with the speaker embedding
|
# SV2TTS: Run the encoder with the speaker embedding
|
||||||
# The projection avoids unnecessary matmuls in the decoder loop
|
# The projection avoids unnecessary matmuls in the decoder loop
|
||||||
encoder_seq = self.encoder(texts, speaker_embedding)
|
encoder_seq = self.encoder(texts, speaker_embedding)
|
||||||
# put after encoder
|
# put after encoder
|
||||||
if self.gst is not None:
|
if hparams.use_gst and self.gst is not None:
|
||||||
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
||||||
# style_embed = style_embed.expand_as(encoder_seq)
|
# style_embed = style_embed.expand_as(encoder_seq)
|
||||||
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||||
|
@ -449,14 +458,17 @@ class Tacotron(nn.Module):
|
||||||
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||||
|
|
||||||
# Need an initial context vector
|
# Need an initial context vector
|
||||||
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size + gst_hp.E, device=device)
|
size = self.encoder_dims + self.speaker_embedding_size
|
||||||
|
if hparams.use_gst:
|
||||||
|
size += gst_hp.E
|
||||||
|
context_vec = torch.zeros(batch_size, size, device=device)
|
||||||
|
|
||||||
# SV2TTS: Run the encoder with the speaker embedding
|
# SV2TTS: Run the encoder with the speaker embedding
|
||||||
# The projection avoids unnecessary matmuls in the decoder loop
|
# The projection avoids unnecessary matmuls in the decoder loop
|
||||||
encoder_seq = self.encoder(x, speaker_embedding)
|
encoder_seq = self.encoder(x, speaker_embedding)
|
||||||
|
|
||||||
# put after encoder
|
# put after encoder
|
||||||
if self.gst is not None:
|
if hparams.use_gst and self.gst is not None:
|
||||||
if style_idx >= 0 and style_idx < 10:
|
if style_idx >= 0 and style_idx < 10:
|
||||||
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
||||||
gst_embed = np.tile(gst_embed, (1, 8))
|
gst_embed = np.tile(gst_embed, (1, 8))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user