Mirror of https://github.com/babysor/MockingBird.git (last synced 2024-03-22 13:11:31 +08:00)
Commit 6abdd0ebf0: Refactor model; refactor and fix a bug when saving plots.
43 lines · 1.6 KiB · Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class LSA(nn.Module):
    """Location-sensitive attention over an encoder sequence.

    The attention energy for each encoder position combines three terms:
    a projection of the decoder query, the already-projected encoder
    outputs, and location features. The location features come from
    convolving the running sum of all previous attention weights. That
    running sum lives on the module between decoder steps and is reset
    whenever ``forward`` is called with ``t == 0``.

    Args:
        attn_dim: Size of the attention space. ``query`` and the last
            dimension of ``encoder_seq_proj`` must both have this size.
        kernel_size: Width of the location convolution. It should be odd.
            The padding is ``(kernel_size - 1) // 2``, so an even width
            changes the sequence length and the terms no longer broadcast.
        filters: Number of location-convolution output channels.
    """

    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        # Running sum of attention weights, shape (batch, enc_len); set by init_attention.
        self.cumulative = None
        # Attention weights from the most recent step, shape (batch, enc_len).
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        """Reset the attention state to zeros sized for ``encoder_seq_proj``.

        The state buffers are allocated on the same device and with the same
        dtype as ``encoder_seq_proj``. This lets them feed the location
        convolution without a dtype mismatch, e.g. in a half-precision model.
        """
        b, t, _ = encoder_seq_proj.size()
        device, dtype = encoder_seq_proj.device, encoder_seq_proj.dtype
        self.cumulative = torch.zeros(b, t, device=device, dtype=dtype)
        self.attention = torch.zeros(b, t, device=device, dtype=dtype)

    def forward(self, encoder_seq_proj, query, t, chars):
        """Compute attention weights for one decoder step.

        Args:
            encoder_seq_proj: Encoder outputs projected to the attention
                space, shape (batch, enc_len, attn_dim).
            query: Decoder query for this step, shape (batch, attn_dim).
            t: Decoder step index. ``0`` resets the accumulated attention.
            chars: Input symbol ids, shape (batch, enc_len). Id 0 is treated
                as padding and receives zero attention weight.

        Returns:
            Attention weights of shape (batch, 1, enc_len) that sum to 1 over
            ``enc_len``. The weights are also stored in ``self.attention`` and
            added to ``self.cumulative``.
        """
        if t == 0:
            self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)  # (b, 1, attn_dim)

        # Location features: convolve the running attention sum along time.
        location = self.cumulative.unsqueeze(1)  # (b, 1, enc_len)
        processed_loc = self.L(self.conv(location).transpose(1, 2))  # (b, enc_len, attn_dim)

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)  # (b, enc_len)

        # Mask zero padding chars. Multiplying the energies by a 0/1 mask
        # would not mask anything: softmax still gives exp(0) weight to the
        # padded positions. Filling with the most negative finite value drives
        # their weight to 0. A fully padded row then becomes uniform rather
        # than NaN, which is what -inf would produce.
        u = u.masked_fill(chars == 0, torch.finfo(u.dtype).min)

        # Normalize with softmax. A sigmoid-normalized ("smooth") variant
        # is a possible alternative but is not used here.
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(1)  # (b, 1, enc_len)
|