import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from .utils.mol_attention import MOLAttention from .utils.basic_layers import Linear from .utils.vc_utils import get_mask_from_lengths class DecoderPrenet(nn.Module): def __init__(self, in_dim, sizes): super().__init__() in_sizes = [in_dim] + sizes[:-1] self.layers = nn.ModuleList( [Linear(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, sizes)]) def forward(self, x): for linear in self.layers: x = F.dropout(F.relu(linear(x)), p=0.5, training=True) return x class Decoder(nn.Module): """Mixture of Logistic (MoL) attention-based RNN Decoder.""" def __init__( self, enc_dim, num_mels, frames_per_step, attention_rnn_dim, decoder_rnn_dim, prenet_dims, num_mixtures, encoder_down_factor=1, num_decoder_rnn_layer=1, use_stop_tokens=False, concat_context_to_last=False, ): super().__init__() self.enc_dim = enc_dim self.encoder_down_factor = encoder_down_factor self.num_mels = num_mels self.frames_per_step = frames_per_step self.attention_rnn_dim = attention_rnn_dim self.decoder_rnn_dim = decoder_rnn_dim self.prenet_dims = prenet_dims self.use_stop_tokens = use_stop_tokens self.num_decoder_rnn_layer = num_decoder_rnn_layer self.concat_context_to_last = concat_context_to_last # Mel prenet self.prenet = DecoderPrenet(num_mels, prenet_dims) self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims) # Attention RNN self.attention_rnn = nn.LSTMCell( prenet_dims[-1] + enc_dim, attention_rnn_dim ) # Attention self.attention_layer = MOLAttention( attention_rnn_dim, r=frames_per_step/encoder_down_factor, M=num_mixtures, ) # Decoder RNN self.decoder_rnn_layers = nn.ModuleList() for i in range(num_decoder_rnn_layer): if i == 0: self.decoder_rnn_layers.append( nn.LSTMCell( enc_dim + attention_rnn_dim, decoder_rnn_dim)) else: self.decoder_rnn_layers.append( nn.LSTMCell( decoder_rnn_dim, decoder_rnn_dim)) # self.decoder_rnn = nn.LSTMCell( # 2 * enc_dim + attention_rnn_dim, # decoder_rnn_dim # ) if concat_context_to_last: self.linear_projection = Linear( enc_dim + decoder_rnn_dim, num_mels * frames_per_step ) else: self.linear_projection = Linear( decoder_rnn_dim, num_mels * frames_per_step ) # Stop-token layer if self.use_stop_tokens: if concat_context_to_last: self.stop_layer = Linear( enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" ) else: self.stop_layer = Linear( decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" ) def get_go_frame(self, memory): B = memory.size(0) go_frame = torch.zeros((B, self.num_mels), dtype=torch.float, device=memory.device) return go_frame def initialize_decoder_states(self, memory, mask): device = next(self.parameters()).device B = memory.size(0) # attention rnn states self.attention_hidden = torch.zeros( (B, self.attention_rnn_dim), device=device) self.attention_cell = torch.zeros( (B, self.attention_rnn_dim), device=device) # decoder rnn states self.decoder_hiddens = [] self.decoder_cells = [] for i in range(self.num_decoder_rnn_layer): self.decoder_hiddens.append( torch.zeros((B, self.decoder_rnn_dim), device=device) ) self.decoder_cells.append( torch.zeros((B, self.decoder_rnn_dim), device=device) ) # self.decoder_hidden = torch.zeros( # (B, self.decoder_rnn_dim), device=device) # self.decoder_cell = torch.zeros( # (B, self.decoder_rnn_dim), device=device) self.attention_context = torch.zeros( (B, self.enc_dim), device=device) self.memory = memory # self.processed_memory = self.attention_layer.memory_layer(memory) self.mask = mask def parse_decoder_inputs(self, decoder_inputs): """Prepare decoder inputs, i.e. gt mel Args: decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training. """ decoder_inputs = decoder_inputs.reshape( decoder_inputs.size(0), int(decoder_inputs.size(1)/self.frames_per_step), -1) # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels) decoder_inputs = decoder_inputs.transpose(0, 1) # (T_out//r, B, num_mels) decoder_inputs = decoder_inputs[:,:,-self.num_mels:] return decoder_inputs def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs): """ Prepares decoder outputs for output Args: mel_outputs: alignments: """ # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc) alignments = torch.stack(alignments).transpose(0, 1) # (T_out//r, B) -> (B, T_out//r) if stop_outputs is not None: if alignments.size(0) == 1: stop_outputs = torch.stack(stop_outputs).unsqueeze(0) else: stop_outputs = torch.stack(stop_outputs).transpose(0, 1) stop_outputs = stop_outputs.contiguous() # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r) mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() # decouple frames per step # (B, T_out, num_mels) mel_outputs = mel_outputs.view( mel_outputs.size(0), -1, self.num_mels) return mel_outputs, alignments, stop_outputs def attend(self, decoder_input): cell_input = torch.cat((decoder_input, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) self.attention_context, attention_weights = self.attention_layer( self.attention_hidden, self.memory, None, self.mask) decoder_rnn_input = torch.cat( (self.attention_hidden, self.attention_context), -1) return decoder_rnn_input, self.attention_context, attention_weights def decode(self, decoder_input): for i in range(self.num_decoder_rnn_layer): if i == 0: self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i])) else: self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i])) return self.decoder_hiddens[-1] def forward(self, memory, mel_inputs, memory_lengths): """ Decoder forward pass for training Args: memory: (B, T_enc, enc_dim) Encoder outputs decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing. memory_lengths: (B, ) Encoder output lengths for attention masking. Returns: mel_outputs: (B, T, num_mels) mel outputs from the decoder alignments: (B, T//r, T_enc) attention weights. """ # [1, B, num_mels] go_frame = self.get_go_frame(memory).unsqueeze(0) # [T//r, B, num_mels] mel_inputs = self.parse_decoder_inputs(mel_inputs) # [T//r + 1, B, num_mels] mel_inputs = torch.cat((go_frame, mel_inputs), dim=0) # [T//r + 1, B, prenet_dim] decoder_inputs = self.prenet(mel_inputs) # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__) self.initialize_decoder_states( memory, mask=~get_mask_from_lengths(memory_lengths), ) self.attention_layer.init_states(memory) # self.attention_layer_pitch.init_states(memory_pitch) mel_outputs, alignments = [], [] if self.use_stop_tokens: stop_outputs = [] else: stop_outputs = None while len(mel_outputs) < decoder_inputs.size(0) - 1: decoder_input = decoder_inputs[len(mel_outputs)] # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)] decoder_rnn_input, context, attention_weights = self.attend(decoder_input) decoder_rnn_output = self.decode(decoder_rnn_input) if self.concat_context_to_last: decoder_rnn_output = torch.cat( (decoder_rnn_output, context), dim=1) mel_output = self.linear_projection(decoder_rnn_output) if self.use_stop_tokens: stop_output = self.stop_layer(decoder_rnn_output) stop_outputs += [stop_output.squeeze()] mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze alignments += [attention_weights] # alignments_pitch += [attention_weights_pitch] mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( mel_outputs, alignments, stop_outputs) if stop_outputs is None: return mel_outputs, alignments else: return mel_outputs, stop_outputs, alignments def inference(self, memory, stop_threshold=0.5): """ Decoder inference Args: memory: (1, T_enc, D_enc) Encoder outputs Returns: mel_outputs: mel outputs from the decoder alignments: sequence of attention weights from the decoder """ # [1, num_mels] decoder_input = self.get_go_frame(memory) self.initialize_decoder_states(memory, mask=None) self.attention_layer.init_states(memory) mel_outputs, alignments = [], [] # NOTE(sx): heuristic max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 while True: decoder_input = self.prenet(decoder_input) decoder_input_final, context, alignment = self.attend(decoder_input) #mel_output, stop_output, alignment = self.decode(decoder_input) decoder_rnn_output = self.decode(decoder_input_final) if self.concat_context_to_last: decoder_rnn_output = torch.cat( (decoder_rnn_output, context), dim=1) mel_output = self.linear_projection(decoder_rnn_output) stop_output = self.stop_layer(decoder_rnn_output) mel_outputs += [mel_output.squeeze(1)] alignments += [alignment] if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step: break if len(mel_outputs) >= max_decoder_step: # print("Warning! Decoding steps reaches max decoder steps.") break decoder_input = mel_output[:,-self.num_mels:] mel_outputs, alignments, _ = self.parse_decoder_outputs( mel_outputs, alignments, None) return mel_outputs, alignments def inference_batched(self, memory, stop_threshold=0.5): """ Decoder inference Args: memory: (B, T_enc, D_enc) Encoder outputs Returns: mel_outputs: mel outputs from the decoder alignments: sequence of attention weights from the decoder """ # [1, num_mels] decoder_input = self.get_go_frame(memory) self.initialize_decoder_states(memory, mask=None) self.attention_layer.init_states(memory) mel_outputs, alignments = [], [] stop_outputs = [] # NOTE(sx): heuristic max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 while True: decoder_input = self.prenet(decoder_input) decoder_input_final, context, alignment = self.attend(decoder_input) #mel_output, stop_output, alignment = self.decode(decoder_input) decoder_rnn_output = self.decode(decoder_input_final) if self.concat_context_to_last: decoder_rnn_output = torch.cat( (decoder_rnn_output, context), dim=1) mel_output = self.linear_projection(decoder_rnn_output) # (B, 1) stop_output = self.stop_layer(decoder_rnn_output) stop_outputs += [stop_output.squeeze()] # stop_outputs.append(stop_output) mel_outputs += [mel_output.squeeze(1)] alignments += [alignment] # print(stop_output.shape) if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \ and len(mel_outputs) >= min_decoder_step: break if len(mel_outputs) >= max_decoder_step: # print("Warning! Decoding steps reaches max decoder steps.") break decoder_input = mel_output[:,-self.num_mels:] mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( mel_outputs, alignments, stop_outputs) mel_outputs_stacked = [] for mel, stop_logit in zip(mel_outputs, stop_outputs): idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item() mel_outputs_stacked.append(mel[:idx,:]) mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0) return mel_outputs, alignments