#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder self-attention layer definition."""

import torch

from torch import nn

from .layer_norm import LayerNorm


class EncoderLayer(nn.Module):
    """Encoder layer module.

    The layer applies, in order: an optional macaron-style feed-forward module,
    the self-attention module, an optional convolution module and the
    feed-forward module. Every sub-module is wrapped in a residual connection.
    When a convolution module is present, a final layer norm is applied to the
    output of the block.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.
        MultiHeadedAttention self_attn: self attention module
        RelPositionMultiHeadedAttention self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward_macaron:
        additional feed forward module for macaron style, or None to disable.
        When given, both feed-forward residual branches are scaled by 0.5.
    :param espnet.nets.pytorch_backend.conformer.convolution.
        ConvolutionModule conv_module: convolution module, or None to disable
    :param float dropout_rate: dropout rate
    :param bool normalize_before: if True, layer norm is applied to the input
        of each sub-module; if False, it is applied after each residual sum
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # macaron style: two feed-forward modules, each contributing half
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param x_input: encoded source features, either a tensor
            (batch, max_time_in, size) or a tuple ``(x, pos_emb)`` of such a
            tensor and a positional embedding that is forwarded to
            ``self_attn`` (original note: pos_emb is (1, max_time_in, size))
        :param torch.Tensor mask: mask for x, or None. It is passed to
            ``self_attn`` as is; when ``cache`` is given it is sliced as a
            3-D tensor (``mask[:, -1:, :]``), so it must have three dims then
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size).
            When given, attention is computed only for the last time step of
            x and the result is appended to ``cache`` along the time axis
        :return: ``(x, mask)`` when no positional embedding was given,
            otherwise ``((x, pos_emb), mask)``; ``mask`` is the (possibly
            sliced) mask that was used for attention
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size), (
                "cache shape must be (batch, max_time_in - 1, size)"
            )
            # Only the newest frame is used as query; keys/values stay full length.
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            # Concatenate the attention *query* input with the attention output.
            # x_q is x when no cache is used; with a cache it is the last frame,
            # which keeps the time dimension equal to that of x_att.
            x_concat = torch.cat((x_q, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            # Append the newly computed frame(s) to the cached outputs.
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask