import torch
import torch.nn as nn
import torch.nn.functional as F


class MOLAttention(nn.Module):
    """Discretized Mixture of Logistic (MOL) attention.

    C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency
    Domain" and the GMMv2b model in "Location-relative attention mechanisms
    for robust long-form speech synthesis".

    The module is stateful: ``init_states`` must be called once per batch
    before the first ``forward`` call, and every ``forward`` call advances the
    stored mixture means (``mu_prev``) by a non-negative step.
    """

    def __init__(
        self,
        query_dim,
        r=1,
        M=5,
    ):
        """
        Args:
            query_dim: attention_rnn_dim, i.e. the size of the last dimension
                of the query passed to ``forward``.
            r: selects the initial bias of the Delta (step size) outputs in
                ``initialize_bias``; stored as float if < 1, otherwise as int.
                Presumably the decoder reduction factor -- TODO confirm
                against caller.
            M: number of mixtures.
        """
        super().__init__()
        if r < 1:
            self.r = float(r)
        else:
            self.r = int(r)
        self.M = M
        self.score_mask_value = 0.0  # -float("inf")
        self.eps = 1e-5
        # Position array for encoder time steps (built by init_states).
        self.J = None
        # Query layer: outputs [w_hat, sigma_hat, Delta_hat], M values each.
        self.query_layer = torch.nn.Sequential(
            nn.Linear(query_dim, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, 3 * M, bias=True),
        )
        self.mu_prev = None
        self.initialize_bias()

    def initialize_bias(self):
        """Initialize the output biases of the sigma and Delta heads.

        The sigma_hat bias is set to 1.0. The Delta_hat bias is set so that
        softplus(bias) is about r for r in {1, 2, 4}, and about 0.5 for any
        other r.
        """
        bias = self.query_layer[2].bias
        M = self.M
        # sigma
        torch.nn.init.constant_(bias[M:2 * M], 1.0)
        # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0;
        # softplus(0.5413) = 1.0; softplus(-0.432) = 0.5003
        delta_bias = {1: 0.5413, 2: 1.8545, 4: 3.9815}.get(self.r, -0.432)
        torch.nn.init.constant_(bias[2 * M:3 * M], delta_bias)

    def init_states(self, memory):
        """Initialize mu_prev and J.

        This function should be called by the decoder before decoding one
        batch.

        Args:
            memory: (B, T_enc, D_enc) encoder output.
        """
        B, T_enc, _ = memory.size()
        device = memory.device
        # Bin edges 0.5, 1.5, ..., T_enc + 1.5. forward() uses the first
        # T_enc + 1 of them to produce T_enc attention weights.
        self.J = torch.arange(0, T_enc + 2.0, device=device) + 0.5
        # Mixture means start at position 0 for every batch item.
        self.mu_prev = torch.zeros(B, self.M, device=device)

    def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
        """Compute one attention step.

        Args:
            att_rnn_h: attention RNN hidden state, (B, query_dim).
            memory: encoder outputs (B, T_enc, D).
            memory_pitch: optional second memory (B, T_enc, D_pitch) that is
                attended with the same weights as ``memory``.
            mask: boolean mask for padded data (B, T_enc). Positions where it
                is True get attention weight ``score_mask_value``.

        Returns:
            ``(context, alpha_t)``, or ``(context, context_pitch, alpha_t)``
            when ``memory_pitch`` is given. ``context`` is (B, D) and
            ``alpha_t`` is (B, T_enc).

        Raises:
            RuntimeError: if ``init_states`` has not been called.
        """
        if self.J is None or self.mu_prev is None:
            raise RuntimeError(
                "MOLAttention.init_states(memory) must be called before forward()."
            )

        # [B, 3M]
        mixture_params = self.query_layer(att_rnn_h)
        # [B, M] each
        w_hat = mixture_params[:, :self.M]
        sigma_hat = mixture_params[:, self.M:2 * self.M]
        Delta_hat = mixture_params[:, 2 * self.M:3 * self.M]

        # Dropout to de-correlate attention heads
        w_hat = F.dropout(w_hat, p=0.5, training=self.training)  # NOTE(sx): needed?

        # Mixture parameters
        w = torch.softmax(w_hat, dim=-1) + self.eps
        sigma = F.softplus(sigma_hat) + self.eps
        # softplus keeps the step non-negative, so the means never move back.
        Delta = F.softplus(Delta_hat)
        mu_cur = self.mu_prev + Delta

        # (T_enc + 1,) bin edges for the current encoder length.
        j = self.J[:memory.size(1) + 1]

        # Attention weights, (B, M, T_enc + 1).
        # NOTE(review): the standard logistic CDF is sigmoid((j - mu) / sigma),
        # i.e. 1 / (1 + exp((mu - j) / sigma)). This expression uses
        # torch.sigmoid where that formula has exp, so each component rises
        # from 0.5 to 1 instead of from 0 to 1. It is still non-decreasing in
        # j, so the bin differences below stay non-negative. Left unchanged
        # because changing it would presumably alter already-trained models.
        phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
            (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))

        # Discretize attention weights: sum over mixtures, then difference
        # neighbouring bin edges. (B, T_enc + 1) -> (B, T_enc)
        alpha_t = torch.sum(phi_t, dim=1)
        alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
        alpha_t[alpha_t == 0] = self.eps

        # Mask out-of-place (not via .data) so autograd sees the masking.
        # Padded positions then send no gradient back into the mixture params.
        if mask is not None:
            alpha_t = alpha_t.masked_fill(mask, self.score_mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
        self.mu_prev = mu_cur

        if memory_pitch is not None:
            context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)
            return context, context_pitch, alpha_t
        return context, alpha_t