#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Johns Hopkins University (Shinji Watanabe) # Northwestern Polytechnical University (Pengcheng Guo) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """ConvolutionModule definition.""" from torch import nn class ConvolutionModule(nn.Module): """ConvolutionModule in Conformer model. :param int channels: channels of cnn :param int kernel_size: kernerl size of cnn """ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) self.norm = nn.BatchNorm1d(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = activation def forward(self, x): """Compute convolution module. :param torch.Tensor x: (batch, time, size) :return torch.Tensor: convoluted `value` (batch, time, d_model) """ # exchange the temporal dimension and the feature dimension x = x.transpose(1, 2) # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channel, dim) x = nn.functional.glu(x, dim=1) # (batch, channel, dim) # 1D Depthwise Conv x = self.depthwise_conv(x) x = self.activation(self.norm(x)) x = self.pointwise_conv2(x) return x.transpose(1, 2)