mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
parent
400a7207e3
commit
6abdd0ebf0
73
synthesizer/models/base.py
Normal file
73
synthesizer/models/base.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import imp
|
||||
import numpy as np
|
||||
|
||||
class Base(nn.Module):
|
||||
def __init__(self, stop_threshold):
|
||||
super().__init__()
|
||||
|
||||
self.init_model()
|
||||
self.num_params()
|
||||
|
||||
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
||||
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
||||
|
||||
@property
|
||||
def r(self):
|
||||
return self.decoder.r.item()
|
||||
|
||||
@r.setter
|
||||
def r(self, value):
|
||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||
|
||||
def init_model(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
def reset_step(self):
|
||||
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
||||
self.step = self.step.data.new_tensor(1)
|
||||
|
||||
def log(self, path, msg):
|
||||
with open(path, "a") as f:
|
||||
print(msg, file=f)
|
||||
|
||||
def load(self, path, device, optimizer=None):
|
||||
# Use device of model params as location for loaded state
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
||||
def save(self, path, optimizer=None):
|
||||
if optimizer is not None:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, str(path))
|
||||
else:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
}, str(path))
|
||||
|
||||
|
||||
def num_params(self, print_out=True):
|
||||
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
||||
if print_out:
|
||||
print("Trainable Parameters: %.3fM" % parameters)
|
||||
return parameters
|
1
synthesizer/models/sublayer/__init__.py
Normal file
1
synthesizer/models/sublayer/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
#
|
85
synthesizer/models/sublayer/cbhg.py
Normal file
85
synthesizer/models/sublayer/cbhg.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from .common.batch_norm_conv import BatchNormConv
|
||||
from .common.highway_network import HighwayNetwork
|
||||
|
||||
class CBHG(nn.Module):
|
||||
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
||||
super().__init__()
|
||||
|
||||
# List of all rnns to call `flatten_parameters()` on
|
||||
self._to_flatten = []
|
||||
|
||||
self.bank_kernels = [i for i in range(1, K + 1)]
|
||||
self.conv1d_bank = nn.ModuleList()
|
||||
for k in self.bank_kernels:
|
||||
conv = BatchNormConv(in_channels, channels, k)
|
||||
self.conv1d_bank.append(conv)
|
||||
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
||||
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
||||
|
||||
# Fix the highway input if necessary
|
||||
if proj_channels[-1] != channels:
|
||||
self.highway_mismatch = True
|
||||
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
||||
else:
|
||||
self.highway_mismatch = False
|
||||
|
||||
self.highways = nn.ModuleList()
|
||||
for i in range(num_highways):
|
||||
hn = HighwayNetwork(channels)
|
||||
self.highways.append(hn)
|
||||
|
||||
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
||||
self._to_flatten.append(self.rnn)
|
||||
|
||||
# Avoid fragmentation of RNN parameters and associated warning
|
||||
self._flatten_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||
# the model gets replicated, making it no longer guaranteed that the
|
||||
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||
self.rnn.flatten_parameters()
|
||||
|
||||
# Save these for later
|
||||
residual = x
|
||||
seq_len = x.size(-1)
|
||||
conv_bank = []
|
||||
|
||||
# Convolution Bank
|
||||
for conv in self.conv1d_bank:
|
||||
c = conv(x) # Convolution
|
||||
conv_bank.append(c[:, :, :seq_len])
|
||||
|
||||
# Stack along the channel axis
|
||||
conv_bank = torch.cat(conv_bank, dim=1)
|
||||
|
||||
# dump the last padding to fit residual
|
||||
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
||||
|
||||
# Conv1d projections
|
||||
x = self.conv_project1(x)
|
||||
x = self.conv_project2(x)
|
||||
|
||||
# Residual Connect
|
||||
x = x + residual
|
||||
|
||||
# Through the highways
|
||||
x = x.transpose(1, 2)
|
||||
if self.highway_mismatch is True:
|
||||
x = self.pre_highway(x)
|
||||
for h in self.highways: x = h(x)
|
||||
|
||||
# And then the RNN
|
||||
x, _ = self.rnn(x)
|
||||
return x
|
||||
|
||||
def _flatten_parameters(self):
|
||||
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
||||
to improve efficiency and avoid PyTorch yelling at us."""
|
||||
[m.flatten_parameters() for m in self._to_flatten]
|
||||
|
14
synthesizer/models/sublayer/common/batch_norm_conv.py
Normal file
14
synthesizer/models/sublayer/common/batch_norm_conv.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class BatchNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
||||
self.bnorm = nn.BatchNorm1d(out_channels)
|
||||
self.relu = relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(x) if self.relu is True else x
|
||||
return self.bnorm(x)
|
17
synthesizer/models/sublayer/common/highway_network.py
Normal file
17
synthesizer/models/sublayer/common/highway_network.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
self.W1 = nn.Linear(size, size)
|
||||
self.W2 = nn.Linear(size, size)
|
||||
self.W1.bias.data.fill_(0.)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.W1(x)
|
||||
x2 = self.W2(x)
|
||||
g = torch.sigmoid(x2)
|
||||
y = g * F.relu(x1) + (1. - g) * x
|
||||
return y
|
42
synthesizer/models/sublayer/lsa.py
Normal file
42
synthesizer/models/sublayer/lsa.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
||||
self.L = nn.Linear(filters, attn_dim, bias=False)
|
||||
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
self.cumulative = None
|
||||
self.attention = None
|
||||
|
||||
def init_attention(self, encoder_seq_proj):
|
||||
device = encoder_seq_proj.device # use same device as parameters
|
||||
b, t, c = encoder_seq_proj.size()
|
||||
self.cumulative = torch.zeros(b, t, device=device)
|
||||
self.attention = torch.zeros(b, t, device=device)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, t, chars):
|
||||
|
||||
if t == 0: self.init_attention(encoder_seq_proj)
|
||||
|
||||
processed_query = self.W(query).unsqueeze(1)
|
||||
|
||||
location = self.cumulative.unsqueeze(1)
|
||||
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
||||
|
||||
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
||||
u = u.squeeze(-1)
|
||||
|
||||
# Mask zero padding chars
|
||||
u = u * (chars != 0).float()
|
||||
|
||||
# Smooth Attention
|
||||
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
||||
scores = F.softmax(u, dim=1)
|
||||
self.attention = scores
|
||||
self.cumulative = self.cumulative + self.attention
|
||||
|
||||
return scores.unsqueeze(-1).transpose(1, 2)
|
18
synthesizer/models/sublayer/pre_net.py
Normal file
18
synthesizer/models/sublayer/pre_net.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class PreNet(nn.Module):
|
||||
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
||||
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
||||
self.p = dropout
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
x = self.fc2(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
return x
|
|
@ -1,29 +1,14 @@
|
|||
import os
|
||||
from matplotlib.pyplot import step
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||
from .sublayer.global_style_token import GlobalStyleToken
|
||||
from .sublayer.pre_net import PreNet
|
||||
from .sublayer.cbhg import CBHG
|
||||
from .sublayer.lsa import LSA
|
||||
from .base import Base
|
||||
from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
|
||||
from synthesizer.hparams import hparams
|
||||
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
def __init__(self, size):
|
||||
super().__init__()
|
||||
self.W1 = nn.Linear(size, size)
|
||||
self.W2 = nn.Linear(size, size)
|
||||
self.W1.bias.data.fill_(0.)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.W1(x)
|
||||
x2 = self.W2(x)
|
||||
g = torch.sigmoid(x2)
|
||||
y = g * F.relu(x1) + (1. - g) * x
|
||||
return y
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
||||
super().__init__()
|
||||
|
@ -36,213 +21,11 @@ class Encoder(nn.Module):
|
|||
proj_channels=[cbhg_channels, cbhg_channels],
|
||||
num_highways=num_highways)
|
||||
|
||||
def forward(self, x, speaker_embedding=None):
|
||||
def forward(self, x):
|
||||
x = self.embedding(x)
|
||||
x = self.pre_net(x)
|
||||
x.transpose_(1, 2)
|
||||
x = self.cbhg(x)
|
||||
if speaker_embedding is not None:
|
||||
x = self.add_speaker_embedding(x, speaker_embedding)
|
||||
return x
|
||||
|
||||
def add_speaker_embedding(self, x, speaker_embedding):
|
||||
# SV2TTS
|
||||
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
||||
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
||||
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
||||
# This concats the speaker embedding for each char in the encoder output
|
||||
|
||||
# Save the dimensions as human-readable names
|
||||
batch_size = x.size()[0]
|
||||
num_chars = x.size()[1]
|
||||
|
||||
if speaker_embedding.dim() == 1:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# Start by making a copy of each speaker embedding to match the input text length
|
||||
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
|
||||
speaker_embedding_size = speaker_embedding.size()[idx]
|
||||
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
||||
|
||||
# Reshape it and transpose
|
||||
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
||||
e = e.transpose(1, 2)
|
||||
|
||||
# Concatenate the tiled speaker embedding with the encoder output
|
||||
x = torch.cat((x, e), 2)
|
||||
return x
|
||||
|
||||
|
||||
class BatchNormConv(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
||||
self.bnorm = nn.BatchNorm1d(out_channels)
|
||||
self.relu = relu
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(x) if self.relu is True else x
|
||||
return self.bnorm(x)
|
||||
|
||||
|
||||
class CBHG(nn.Module):
|
||||
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
||||
super().__init__()
|
||||
|
||||
# List of all rnns to call `flatten_parameters()` on
|
||||
self._to_flatten = []
|
||||
|
||||
self.bank_kernels = [i for i in range(1, K + 1)]
|
||||
self.conv1d_bank = nn.ModuleList()
|
||||
for k in self.bank_kernels:
|
||||
conv = BatchNormConv(in_channels, channels, k)
|
||||
self.conv1d_bank.append(conv)
|
||||
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
||||
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
||||
|
||||
# Fix the highway input if necessary
|
||||
if proj_channels[-1] != channels:
|
||||
self.highway_mismatch = True
|
||||
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
||||
else:
|
||||
self.highway_mismatch = False
|
||||
|
||||
self.highways = nn.ModuleList()
|
||||
for i in range(num_highways):
|
||||
hn = HighwayNetwork(channels)
|
||||
self.highways.append(hn)
|
||||
|
||||
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
||||
self._to_flatten.append(self.rnn)
|
||||
|
||||
# Avoid fragmentation of RNN parameters and associated warning
|
||||
self._flatten_parameters()
|
||||
|
||||
def forward(self, x):
|
||||
# Although we `_flatten_parameters()` on init, when using DataParallel
|
||||
# the model gets replicated, making it no longer guaranteed that the
|
||||
# weights are contiguous in GPU memory. Hence, we must call it again
|
||||
self.rnn.flatten_parameters()
|
||||
|
||||
# Save these for later
|
||||
residual = x
|
||||
seq_len = x.size(-1)
|
||||
conv_bank = []
|
||||
|
||||
# Convolution Bank
|
||||
for conv in self.conv1d_bank:
|
||||
c = conv(x) # Convolution
|
||||
conv_bank.append(c[:, :, :seq_len])
|
||||
|
||||
# Stack along the channel axis
|
||||
conv_bank = torch.cat(conv_bank, dim=1)
|
||||
|
||||
# dump the last padding to fit residual
|
||||
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
||||
|
||||
# Conv1d projections
|
||||
x = self.conv_project1(x)
|
||||
x = self.conv_project2(x)
|
||||
|
||||
# Residual Connect
|
||||
x = x + residual
|
||||
|
||||
# Through the highways
|
||||
x = x.transpose(1, 2)
|
||||
if self.highway_mismatch is True:
|
||||
x = self.pre_highway(x)
|
||||
for h in self.highways: x = h(x)
|
||||
|
||||
# And then the RNN
|
||||
x, _ = self.rnn(x)
|
||||
return x
|
||||
|
||||
def _flatten_parameters(self):
|
||||
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
||||
to improve efficiency and avoid PyTorch yelling at us."""
|
||||
[m.flatten_parameters() for m in self._to_flatten]
|
||||
|
||||
class PreNet(nn.Module):
|
||||
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
||||
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
||||
self.p = dropout
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
x = self.fc2(x)
|
||||
x = F.relu(x)
|
||||
x = F.dropout(x, self.p, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, attn_dims):
|
||||
super().__init__()
|
||||
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
||||
self.v = nn.Linear(attn_dims, 1, bias=False)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, t):
|
||||
|
||||
# print(encoder_seq_proj.shape)
|
||||
# Transform the query vector
|
||||
query_proj = self.W(query).unsqueeze(1)
|
||||
|
||||
# Compute the scores
|
||||
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
||||
scores = F.softmax(u, dim=1)
|
||||
|
||||
return scores.transpose(1, 2)
|
||||
|
||||
|
||||
class LSA(nn.Module):
|
||||
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
||||
self.L = nn.Linear(filters, attn_dim, bias=False)
|
||||
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
self.cumulative = None
|
||||
self.attention = None
|
||||
|
||||
def init_attention(self, encoder_seq_proj):
|
||||
device = encoder_seq_proj.device # use same device as parameters
|
||||
b, t, c = encoder_seq_proj.size()
|
||||
self.cumulative = torch.zeros(b, t, device=device)
|
||||
self.attention = torch.zeros(b, t, device=device)
|
||||
|
||||
def forward(self, encoder_seq_proj, query, t, chars):
|
||||
|
||||
if t == 0: self.init_attention(encoder_seq_proj)
|
||||
|
||||
processed_query = self.W(query).unsqueeze(1)
|
||||
|
||||
location = self.cumulative.unsqueeze(1)
|
||||
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
||||
|
||||
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
||||
u = u.squeeze(-1)
|
||||
|
||||
# Mask zero padding chars
|
||||
u = u * (chars != 0).float()
|
||||
|
||||
# Smooth Attention
|
||||
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
||||
scores = F.softmax(u, dim=1)
|
||||
self.attention = scores
|
||||
self.cumulative = self.cumulative + self.attention
|
||||
|
||||
return scores.unsqueeze(-1).transpose(1, 2)
|
||||
|
||||
return self.cbhg(x)
|
||||
|
||||
class Decoder(nn.Module):
|
||||
# Class variable because its value doesn't change between classes
|
||||
|
@ -327,12 +110,11 @@ class Decoder(nn.Module):
|
|||
|
||||
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
||||
|
||||
|
||||
class Tacotron(nn.Module):
|
||||
class Tacotron(Base):
|
||||
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
|
||||
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
|
||||
dropout, stop_threshold, speaker_embedding_size):
|
||||
super().__init__()
|
||||
super().__init__(stop_threshold)
|
||||
self.n_mels = n_mels
|
||||
self.lstm_dims = lstm_dims
|
||||
self.encoder_dims = encoder_dims
|
||||
|
@ -352,20 +134,6 @@ class Tacotron(nn.Module):
|
|||
[postnet_dims, fft_bins], num_highways)
|
||||
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
||||
|
||||
self.init_model()
|
||||
self.num_params()
|
||||
|
||||
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
||||
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
||||
|
||||
@property
|
||||
def r(self):
|
||||
return self.decoder.r.item()
|
||||
|
||||
@r.setter
|
||||
def r(self, value):
|
||||
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
|
@ -373,6 +141,36 @@ class Tacotron(nn.Module):
|
|||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(x, speaker_embedding):
|
||||
# SV2TTS
|
||||
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
||||
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
||||
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
||||
# This concats the speaker embedding for each char in the encoder output
|
||||
|
||||
# Save the dimensions as human-readable names
|
||||
batch_size = x.size()[0]
|
||||
num_chars = x.size()[1]
|
||||
|
||||
if speaker_embedding.dim() == 1:
|
||||
idx = 0
|
||||
else:
|
||||
idx = 1
|
||||
|
||||
# Start by making a copy of each speaker embedding to match the input text length
|
||||
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
|
||||
speaker_embedding_size = speaker_embedding.size()[idx]
|
||||
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
||||
|
||||
# Reshape it and transpose
|
||||
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
||||
e = e.transpose(1, 2)
|
||||
|
||||
# Concatenate the tiled speaker embedding with the encoder output
|
||||
x = torch.cat((x, e), 2)
|
||||
return x
|
||||
|
||||
def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
|
||||
|
||||
device = texts.device # use same device as parameters
|
||||
|
@ -405,8 +203,11 @@ class Tacotron(nn.Module):
|
|||
|
||||
# SV2TTS: Run the encoder with the speaker embedding
|
||||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(texts, speaker_embedding)
|
||||
encoder_seq = self.encoder(texts)
|
||||
|
||||
if speaker_embedding is not None:
|
||||
encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)
|
||||
|
||||
if hparams.use_gst and self.gst is not None:
|
||||
if self.training:
|
||||
style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
|
||||
|
@ -455,7 +256,6 @@ class Tacotron(nn.Module):
|
|||
# attn_scores = attn_scores.cpu().data.numpy()
|
||||
stop_outputs = torch.cat(stop_outputs, 1)
|
||||
|
||||
|
||||
if self.training:
|
||||
self.train()
|
||||
|
||||
|
@ -465,131 +265,3 @@ class Tacotron(nn.Module):
|
|||
self.eval()
|
||||
mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
|
||||
return mel_outputs, linear, attn_scores
|
||||
# device = x.device # use same device as parameters
|
||||
|
||||
# batch_size, _ = x.size()
|
||||
|
||||
# # Need to initialise all hidden states and pack into tuple for tidyness
|
||||
# attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
||||
# rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
# rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
# hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
||||
|
||||
# # Need to initialise all lstm cell states and pack into tuple for tidyness
|
||||
# rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
# rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
||||
# cell_states = (rnn1_cell, rnn2_cell)
|
||||
|
||||
# # Need a <GO> Frame for start of decoder loop
|
||||
# go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
||||
|
||||
# # Need an initial context vector
|
||||
# size = self.encoder_dims + self.speaker_embedding_size
|
||||
# if hparams.use_gst:
|
||||
# size += gst_hp.E
|
||||
# context_vec = torch.zeros(batch_size, size, device=device)
|
||||
|
||||
# # SV2TTS: Run the encoder with the speaker embedding
|
||||
# # The projection avoids unnecessary matmuls in the decoder loop
|
||||
# encoder_seq = self.encoder(x, speaker_embedding)
|
||||
|
||||
# # put after encoder
|
||||
# if hparams.use_gst and self.gst is not None:
|
||||
# if style_idx >= 0 and style_idx < 10:
|
||||
# query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
||||
# if device.type == 'cuda':
|
||||
# query = query.cuda()
|
||||
# gst_embed = torch.tanh(self.gst.stl.embed)
|
||||
# key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||
# style_embed = self.gst.stl.attention(query, key)
|
||||
# else:
|
||||
# speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
|
||||
# style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
||||
# encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
||||
# # style_embed = style_embed.expand_as(encoder_seq)
|
||||
# # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
||||
# encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# # Need a couple of lists for outputs
|
||||
# mel_outputs, attn_scores, stop_outputs = [], [], []
|
||||
|
||||
# # Run the decoder loop
|
||||
# for t in range(0, steps, self.r):
|
||||
# prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
||||
# mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
|
||||
# self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
||||
# hidden_states, cell_states, context_vec, t, x)
|
||||
# mel_outputs.append(mel_frames)
|
||||
# attn_scores.append(scores)
|
||||
# stop_outputs.extend([stop_tokens] * self.r)
|
||||
# # Stop the loop when all stop tokens in batch exceed threshold
|
||||
# if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
|
||||
|
||||
# # Concat the mel outputs into sequence
|
||||
# mel_outputs = torch.cat(mel_outputs, dim=2)
|
||||
|
||||
# # Post-Process for Linear Spectrograms
|
||||
# postnet_out = self.postnet(mel_outputs)
|
||||
# linear = self.post_proj(postnet_out)
|
||||
|
||||
|
||||
# linear = linear.transpose(1, 2)
|
||||
|
||||
# # For easy visualisation
|
||||
# attn_scores = torch.cat(attn_scores, 1)
|
||||
# stop_outputs = torch.cat(stop_outputs, 1)
|
||||
|
||||
# self.train()
|
||||
|
||||
# return mel_outputs, linear, attn_scores
|
||||
|
||||
def init_model(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
||||
|
||||
def finetune_partial(self, whitelist_layers):
|
||||
self.zero_grad()
|
||||
for name, child in self.named_children():
|
||||
if name in whitelist_layers:
|
||||
print("Trainable Layer: %s" % name)
|
||||
print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
|
||||
for param in child.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_step(self):
|
||||
return self.step.data.item()
|
||||
|
||||
def reset_step(self):
|
||||
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
||||
self.step = self.step.data.new_tensor(1)
|
||||
|
||||
def log(self, path, msg):
|
||||
with open(path, "a") as f:
|
||||
print(msg, file=f)
|
||||
|
||||
def load(self, path, device, optimizer=None):
|
||||
# Use device of model params as location for loaded state
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
||||
def save(self, path, optimizer=None):
|
||||
if optimizer is not None:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, str(path))
|
||||
else:
|
||||
torch.save({
|
||||
"model_state": self.state_dict(),
|
||||
}, str(path))
|
||||
|
||||
|
||||
def num_params(self, print_out=True):
|
||||
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
||||
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
||||
if print_out:
|
||||
print("Trainable Parameters: %.3fM" % parameters)
|
||||
return parameters
|
||||
|
|
|
@ -265,7 +265,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||
hparams=hparams,
|
||||
sw=sw)
|
||||
MAX_SAVED_COUNT = 20
|
||||
if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT:
|
||||
if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
|
||||
# clean up and save last MAX_SAVED_COUNT;
|
||||
plots = next(os.walk(plot_dir), (None, None, []))[2]
|
||||
for plot in plots[-MAX_SAVED_COUNT:]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user