import torch
import torch.nn as nn
import torch.nn.functional as F


class LSA(nn.Module):
    """Attention module that conditions on its own cumulative attention.

    At every step the running sum of previous attention weights is passed
    through a 1-D convolution and a linear projection, added to the projected
    query and the (already projected) encoder sequence, and scored with a
    single-output linear layer. The module is stateful: ``cumulative`` and
    ``attention`` are overwritten on every ``forward`` call and reset when
    ``forward`` is called with ``t == 0``.

    Args:
        attn_dim: Feature size of the query and of ``encoder_seq_proj``.
        kernel_size: Width of the convolution over the cumulative attention.
            NOTE: the padding of ``(kernel_size - 1) // 2`` only preserves the
            sequence length for odd values.
        filters: Number of output channels of the convolution.
    """

    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(
            1,
            filters,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            bias=True,
        )
        self.L = nn.Linear(filters, attn_dim, bias=False)
        # The attention bias is carried by this projection only.
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)
        self.v = nn.Linear(attn_dim, 1, bias=False)
        # Per-sequence state, created by init_attention().
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj: torch.Tensor) -> None:
        """Reset the attention state to zeros.

        Args:
            encoder_seq_proj: 3-D tensor; its first two dimensions give the
                shape of the state tensors, which are created on its device.
        """
        batch, steps, _ = encoder_seq_proj.size()
        device = encoder_seq_proj.device
        self.cumulative = torch.zeros(batch, steps, device=device)
        self.attention = torch.zeros(batch, steps, device=device)

    def forward(
        self,
        encoder_seq_proj: torch.Tensor,
        query: torch.Tensor,
        t: int,
        chars: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention weights for one step and update the state.

        Args:
            encoder_seq_proj: Projected encoder outputs, 3-D with a last
                dimension of ``attn_dim``.
            query: Query for this step; presumably ``(batch, attn_dim)`` --
                it is projected and given a singleton dimension 1 so that it
                broadcasts over the sequence.
            t: Step index. ``0`` resets the attention state.
            chars: Token ids aligned with the score tensor; positions equal
                to ``0`` are treated as padding.

        Returns:
            The attention weights with a singleton dimension inserted at
            position 1, i.e. ``(batch, 1, steps)`` for 2-D scores.
        """
        # Also initialise lazily so a first call with t != 0 does not fail
        # on the ``None`` placeholders set in __init__.
        if t == 0 or self.cumulative is None:
            self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        # Conv1d expects (batch, channels, steps); the result is moved back
        # to (batch, steps, filters) before the linear projection.
        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        energies = torch.tanh(processed_query + encoder_seq_proj + processed_loc)
        u = self.v(energies).squeeze(-1)

        # Mask zero padding chars. Multiplying a logit by zero would still
        # leave exp(0) worth of weight after the softmax, so padded logits
        # are pushed to the most negative finite value instead. A finite
        # fill (rather than -inf) keeps an all-padding row free of NaNs.
        u = u.masked_fill(chars == 0, torch.finfo(u.dtype).min)

        # NOTE: a sigmoid-normalised ("smooth") variant,
        # sigmoid(u) / sum(sigmoid(u)), is a possible alternative here.
        scores = F.softmax(u, dim=1)

        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)