import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as tFunctional

from synthesizer.gst_hyperparameters import GSTHyperparameters as hp
from synthesizer.hparams import hparams


class GlobalStyleToken(nn.Module):
    """
    inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
    speaker_embedding: speaker embedding vectors [batch_size, speaker_embedding_dim]
    outputs: style embedding [batch_size, 1, gst_embedding_dim]
    """

    def __init__(self, speaker_embedding_dim=None):
        super().__init__()
        self.encoder = ReferenceEncoder()
        self.stl = STL(speaker_embedding_dim)

    def forward(self, inputs, speaker_embedding=None):
        enc_out = self.encoder(inputs)
        # concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
        if hparams.use_ser_for_gst and speaker_embedding is not None:
            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
        style_embed = self.stl(enc_out)
        return style_embed
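

# Usage sketch (illustrative only; the 80-band mels and the 256-dim speaker
# embedding are assumptions, not values read from this repo):
#
#     gst = GlobalStyleToken(speaker_embedding_dim=256)
#     mels = torch.randn(4, 120, 80)   # [batch, num_spec_frames, num_mel]
#     spk = torch.randn(4, 256)        # [batch, speaker_embedding_dim]
#     style = gst(mels, spk)           # [batch, 1, hp.E]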


class ReferenceEncoder(nn.Module):
    '''
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    '''

    def __init__(self):
        super().__init__()
        K = len(hp.ref_enc_filters)
        filters = [1] + hp.ref_enc_filters
        # K stride-2 3x3 convolutions, each halving the time and mel axes
        convs = [nn.Conv2d(in_channels=filters[i],
                           out_channels=filters[i + 1],
                           kernel_size=(3, 3),
                           stride=(2, 2),
                           padding=(1, 1)) for i in range(K)]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])

        # GRU input size = final channel count x mel bins remaining after K convs
        out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
        self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
                          hidden_size=hp.E // 2,
                          batch_first=True)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, hp.n_mels)  # [N, 1, Ty, n_mels]
        for conv, bn in zip(self.convs, self.bns):
            out = conv(out)
            out = bn(out)
            out = tFunctional.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        # keep only the final GRU hidden state as the reference embedding
        memory, out = self.gru(out)  # out --- [1, N, E//2]

        return out.squeeze(0)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        # standard conv output-length recurrence, applied once per layer
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
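

# Worked example for calculate_channels (hypothetical numbers: with hp.n_mels = 80
# and K = 6 layers of kernel 3 / stride 2 / pad 1, the recurrence
# L = (L - 3 + 2) // 2 + 1 gives 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2,
# so the GRU input size would be hp.ref_enc_filters[-1] * 2).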


class STL(nn.Module):
    '''
    inputs --- [N, E//2] ([N, E//2 + speaker_embedding_dim] when a speaker
    embedding is concatenated upstream)
    '''

    def __init__(self, speaker_embedding_dim=None):
        super().__init__()
        # learnable style token bank: [token_num, E // num_heads]
        self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
        d_q = hp.E // 2
        d_k = hp.E // hp.num_heads
        if hparams.use_ser_for_gst and speaker_embedding_dim is not None:
            d_q += speaker_embedding_dim
        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)

        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        N = inputs.size(0)
        query = inputs.unsqueeze(1)  # [N, 1, E//2]
        keys = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
        style_embed = self.attention(query, keys)

        return style_embed
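

# Shape sketch for STL (illustrative, assuming typical GST settings such as
# hp.E = 256, hp.token_num = 10, hp.num_heads = 8 -- not confirmed here):
#   inputs [N, 128] -> query [N, 1, 128]; tanh(embed) -> keys [N, 10, 32];
#   the attention output, i.e. the style embedding, is [N, 1, 256].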


class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        # split the projections into num_heads chunks along the feature axis
        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = tFunctional.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out
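

if __name__ == "__main__":
    # Minimal shape smoke test -- a sketch, not part of the original module. It
    # assumes the synthesizer hyperparameter modules imported above resolve and
    # that hp.n_mels / hp.E are set; the 256-dim speaker embedding is an
    # illustrative value, not one read from this repo.
    batch, frames, spk_dim = 4, 120, 256
    gst = GlobalStyleToken(speaker_embedding_dim=spk_dim)
    mels = torch.randn(batch, frames, hp.n_mels)
    spk = torch.randn(batch, spk_dim)
    style = gst(mels, spk)
    print(style.shape)  # expected: [batch, 1, hp.E]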