mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
375 lines
14 KiB
Python
375 lines
14 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from .utils.mol_attention import MOLAttention
|
|
from .utils.basic_layers import Linear
|
|
from .utils.vc_utils import get_mask_from_lengths
|
|
|
|
|
|
class DecoderPrenet(nn.Module):
    """Stack of bias-free Linear -> ReLU -> Dropout(p=0.5) layers.

    Dropout is always applied (``training=True`` is passed explicitly), so the
    prenet stays stochastic even when the module is in ``eval()`` mode.
    """

    def __init__(self, in_dim, sizes):
        """
        Args:
            in_dim: feature size of the prenet input.
            sizes: list of output sizes, one per layer; layer k maps
                ``dims[k] -> dims[k + 1]`` where ``dims = [in_dim] + sizes``.
        """
        super().__init__()
        dims = [in_dim] + sizes
        self.layers = nn.ModuleList(
            [Linear(dims[k], dims[k + 1], bias=False) for k in range(len(sizes))]
        )

    def forward(self, x):
        """Apply every layer in order and return the final activations."""
        out = x
        for layer in self.layers:
            # Dropout is forced on regardless of self.training (deliberate).
            out = F.dropout(F.relu(layer(out)), p=0.5, training=True)
        return out
|
|
|
|
|
|
class Decoder(nn.Module):
    """Mixture of Logistic (MoL) attention-based RNN Decoder.

    Each decoder step feeds the previous mel frame through a prenet, updates an
    attention LSTM cell, attends over the encoder outputs (``memory``) with MoL
    attention, runs a stack of decoder LSTM cells and projects the result to
    ``frames_per_step`` mel frames (plus, optionally, a stop-token logit).
    """

    def __init__(
        self,
        enc_dim,
        num_mels,
        frames_per_step,
        attention_rnn_dim,
        decoder_rnn_dim,
        prenet_dims,
        num_mixtures,
        encoder_down_factor=1,
        num_decoder_rnn_layer=1,
        use_stop_tokens=False,
        concat_context_to_last=False,
    ):
        """
        Args:
            enc_dim: feature size of the encoder outputs (``memory``).
            num_mels: number of mel channels per frame.
            frames_per_step: mel frames predicted per decoder step.
            attention_rnn_dim: hidden size of the attention LSTM cell.
            decoder_rnn_dim: hidden size of each decoder LSTM cell.
            prenet_dims: list of prenet layer sizes.
            num_mixtures: number of mixtures of the MoL attention.
            encoder_down_factor: used in the attention ratio ``r`` and in the
                inference step-count heuristic.
            num_decoder_rnn_layer: number of stacked decoder LSTM cells.
            use_stop_tokens: if True, a stop-token layer is created and used.
            concat_context_to_last: if True, the attention context is
                concatenated to the decoder output before the projections.
        """
        super().__init__()
        self.enc_dim = enc_dim
        self.encoder_down_factor = encoder_down_factor
        self.num_mels = num_mels
        self.frames_per_step = frames_per_step
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dims = prenet_dims
        self.use_stop_tokens = use_stop_tokens
        self.num_decoder_rnn_layer = num_decoder_rnn_layer
        self.concat_context_to_last = concat_context_to_last

        # Mel prenet
        self.prenet = DecoderPrenet(num_mels, prenet_dims)
        # Not used by any method of this class; presumably kept so existing
        # checkpoints (state dicts) keep loading -- confirm before removing.
        self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)

        # Attention RNN: input is [prenet output, previous attention context].
        self.attention_rnn = nn.LSTMCell(
            prenet_dims[-1] + enc_dim,
            attention_rnn_dim,
        )

        # Attention; r = frames_per_step / encoder_down_factor (may be fractional).
        self.attention_layer = MOLAttention(
            attention_rnn_dim,
            r=frames_per_step / encoder_down_factor,
            M=num_mixtures,
        )

        # Decoder RNN stack: the first cell consumes
        # [attention hidden, attention context]; later cells consume the
        # hidden state of the cell below.
        self.decoder_rnn_layers = nn.ModuleList([
            nn.LSTMCell(
                enc_dim + attention_rnn_dim if i == 0 else decoder_rnn_dim,
                decoder_rnn_dim,
            )
            for i in range(num_decoder_rnn_layer)
        ])

        # Input size of the output projections (context appended if requested).
        proj_dim = (enc_dim + decoder_rnn_dim) if concat_context_to_last else decoder_rnn_dim
        self.linear_projection = Linear(proj_dim, num_mels * frames_per_step)

        # Stop-token layer (only exists when use_stop_tokens is True).
        if self.use_stop_tokens:
            self.stop_layer = Linear(proj_dim, 1, bias=True, w_init_gain="sigmoid")

    def get_go_frame(self, memory):
        """Return an all-zero (B, num_mels) float frame on ``memory``'s device."""
        B = memory.size(0)
        return torch.zeros((B, self.num_mels), dtype=torch.float, device=memory.device)

    def initialize_decoder_states(self, memory, mask):
        """Reset all recurrent states to zeros and store ``memory``/``mask``.

        Args:
            memory: encoder outputs; dim 0 is the batch size.
            mask: attention mask forwarded to the attention layer, or None.
        """
        device = next(self.parameters()).device
        B = memory.size(0)

        # attention rnn states
        self.attention_hidden = torch.zeros((B, self.attention_rnn_dim), device=device)
        self.attention_cell = torch.zeros((B, self.attention_rnn_dim), device=device)

        # decoder rnn states, one (hidden, cell) pair per stacked layer
        self.decoder_hiddens = [
            torch.zeros((B, self.decoder_rnn_dim), device=device)
            for _ in range(self.num_decoder_rnn_layer)
        ]
        self.decoder_cells = [
            torch.zeros((B, self.decoder_rnn_dim), device=device)
            for _ in range(self.num_decoder_rnn_layer)
        ]

        self.attention_context = torch.zeros((B, self.enc_dim), device=device)

        self.memory = memory
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepare decoder inputs, i.e. gt mel, for teacher forcing.

        Groups the frames into steps of ``frames_per_step`` and keeps only the
        last frame of every step.

        Args:
            decoder_inputs: (B, T_out, num_mels) ground-truth mel frames;
                T_out must be divisible by ``frames_per_step``.
        Returns:
            (T_out//r, B, num_mels) tensor.
        """
        # (B, T_out, num_mels) -> (B, T_out//r, r*num_mels)
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            decoder_inputs.size(1) // self.frames_per_step,
            -1,
        )
        # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        # (T_out//r, B, num_mels): last frame of each step
        return decoder_inputs[:, :, -self.num_mels:]

    def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
        """Stack per-step outputs into batch-first tensors.

        Args:
            mel_outputs: list (len T_out//r) of (B, num_mels*r) tensors.
            alignments: list (len T_out//r) of (B, T_enc) attention weights.
            stop_outputs: list of squeezed stop logits per step, or None.
        Returns:
            mel_outputs (B, T_out, num_mels), alignments (B, T_out//r, T_enc),
            stop_outputs (B, T_out//r) or None.
        """
        # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out//r, B) -> (B, T_out//r)
        if stop_outputs is not None:
            if alignments.size(0) == 1:
                # B == 1: each per-step logit was squeezed to a 0-d tensor.
                stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
            else:
                stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
            stop_outputs = stop_outputs.contiguous()
        # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step: (B, T_out, num_mels)
        mel_outputs = mel_outputs.view(mel_outputs.size(0), -1, self.num_mels)
        return mel_outputs, alignments, stop_outputs

    def attend(self, decoder_input):
        """Advance the attention RNN and compute a new attention context.

        Returns:
            decoder_rnn_input: [attention hidden, attention context] concat,
            the new attention context, and the attention weights.
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_context, attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, None, self.mask)

        decoder_rnn_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        return decoder_rnn_input, self.attention_context, attention_weights

    def decode(self, decoder_input):
        """Run the stacked decoder LSTM cells; return the top hidden state."""
        layer_input = decoder_input
        for i, cell in enumerate(self.decoder_rnn_layers):
            self.decoder_hiddens[i], self.decoder_cells[i] = cell(
                layer_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
            layer_input = self.decoder_hiddens[i]
        return self.decoder_hiddens[-1]

    def _decode_step(self, prenet_output):
        """One decoder step from a prenet output.

        Returns (decoder_rnn_output, mel_output, attention_weights), where
        ``decoder_rnn_output`` is the input fed to the output projections.
        """
        decoder_rnn_input, context, attention_weights = self.attend(prenet_output)
        decoder_rnn_output = self.decode(decoder_rnn_input)
        if self.concat_context_to_last:
            decoder_rnn_output = torch.cat((decoder_rnn_output, context), dim=1)
        mel_output = self.linear_projection(decoder_rnn_output)
        return decoder_rnn_output, mel_output, attention_weights

    def forward(self, memory, mel_inputs, memory_lengths):
        """ Decoder forward pass for training (teacher forcing).

        Args:
            memory: (B, T_enc, enc_dim) Encoder outputs
            mel_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
            memory_lengths: (B, ) Encoder output lengths for attention masking.
        Returns:
            mel_outputs: (B, T, num_mels) mel outputs from the decoder
            stop_outputs: (B, T//r) stop logits -- only returned (in the middle
                position) when ``use_stop_tokens`` is True.
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [1, B, num_mels]
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        # [T//r, B, num_mels]
        mel_inputs = self.parse_decoder_inputs(mel_inputs)
        # [T//r + 1, B, num_mels]
        mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
        # [T//r + 1, B, prenet_dim]
        decoder_inputs = self.prenet(mel_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths),
        )
        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        stop_outputs = [] if self.use_stop_tokens else None
        # The last prenet input is never consumed: step t predicts step t+1.
        for step in range(decoder_inputs.size(0) - 1):
            decoder_rnn_output, mel_output, attention_weights = self._decode_step(
                decoder_inputs[step])
            if self.use_stop_tokens:
                stop_outputs.append(self.stop_layer(decoder_rnn_output).squeeze())
            mel_outputs.append(mel_output.squeeze(1))
            alignments.append(attention_weights)

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        if stop_outputs is None:
            return mel_outputs, alignments
        return mel_outputs, stop_outputs, alignments

    def _autoregress(self, memory, stop_threshold):
        """Free-running decoding loop shared by the inference methods.

        Stops once every batch item's stop probability exceeds
        ``stop_threshold`` at the same step (after ``min_decoder_step`` steps),
        or after ``max_decoder_step`` steps. Without a stop layer only the
        step limit applies.

        Returns:
            Per-step lists of mel outputs and alignments, and a per-step list
            of squeezed stop logits (None when stop tokens are disabled).
        """
        # [B, num_mels]
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)
        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        stop_outputs = [] if self.use_stop_tokens else None
        # NOTE(sx): heuristic
        max_decoder_step = memory.size(1) * self.encoder_down_factor // self.frames_per_step
        min_decoder_step = max_decoder_step - 5
        while True:
            decoder_rnn_output, mel_output, alignment = self._decode_step(
                self.prenet(decoder_input))
            mel_outputs.append(mel_output.squeeze(1))
            alignments.append(alignment)

            if self.use_stop_tokens:
                # (B, 1)
                stop_output = self.stop_layer(decoder_rnn_output)
                stop_outputs.append(stop_output.squeeze())
                if torch.all(torch.sigmoid(stop_output.data) > stop_threshold) \
                        and len(mel_outputs) >= min_decoder_step:
                    break
            if len(mel_outputs) >= max_decoder_step:
                break

            # Feed back the last predicted frame of this step.
            decoder_input = mel_output[:, -self.num_mels:]
        return mel_outputs, alignments, stop_outputs

    def _trim_at_stop(self, mel_outputs, stop_outputs, stop_threshold):
        """Cut each item before its first stop step and join along time.

        Args:
            mel_outputs: (B, T_out, num_mels) with T_out = steps * frames_per_step.
            stop_outputs: (B, steps) stop logits, or None (no trimming).
            stop_threshold: sigmoid probability above which a step stops.
        Returns:
            (1, sum of kept frames, num_mels) tensor. An item whose stop token
            never fires is kept at full length.
        """
        trimmed = []
        for b, mel in enumerate(mel_outputs):
            if stop_outputs is not None:
                fired = torch.nonzero(torch.sigmoid(stop_outputs[b].cpu()) > stop_threshold)
                if fired.numel() > 0:
                    # Stop indices count decoder steps; each step emits
                    # frames_per_step mel frames.
                    mel = mel[:int(fired[0, 0]) * self.frames_per_step]
            trimmed.append(mel)
        return torch.cat(trimmed, dim=0).unsqueeze(0)

    def inference(self, memory, stop_threshold=0.5):
        """ Decoder inference

        Args:
            memory: (1, T_enc, D_enc) Encoder outputs
            stop_threshold: stop probability above which decoding ends (only
                used when ``use_stop_tokens`` is True).
        Returns:
            mel_outputs: (1, T_out, num_mels) mel outputs from the decoder
            alignments: (1, T_out//r, T_enc) attention weights
        """
        mel_outputs, alignments, _ = self._autoregress(memory, stop_threshold)
        mel_outputs, alignments, _ = self.parse_decoder_outputs(
            mel_outputs, alignments, None)
        return mel_outputs, alignments

    def inference_batched(self, memory, stop_threshold=0.5):
        """ Batched decoder inference

        Args:
            memory: (B, T_enc, D_enc) Encoder outputs
            stop_threshold: stop probability above which decoding ends.
        Returns:
            mel_outputs: (1, T_total, num_mels); each item is cut before its
                first stop step and the items are concatenated along time.
            alignments: (B, T_out//r, T_enc) attention weights (untrimmed)
        """
        mel_outputs, alignments, stop_outputs = self._autoregress(memory, stop_threshold)
        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        mel_outputs = self._trim_at_stop(mel_outputs, stop_outputs, stop_threshold)
        return mel_outputs, alignments
|