mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
74a3fc97d0
Need readme
75 lines
2.1 KiB
Python
75 lines
2.1 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
"""ConvolutionModule definition."""
|
|
|
|
from torch import nn
|
|
|
|
|
|
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Applies, in order: a pointwise conv (channels -> 2*channels) followed by
    GLU, a depthwise conv, batch norm, the activation, and a second pointwise
    conv. The time length is preserved.

    :param int channels: channels of cnn
    :param int kernel_size: kernel size of the depthwise cnn; must be odd so
        that symmetric padding keeps the time length unchanged ('SAME' padding)
    :param torch.nn.Module activation: activation applied after batch norm;
        when None (the default) a fresh ``nn.ReLU()`` is created per instance
    :param bool bias: whether the convolution layers use a bias term

    """

    def __init__(self, channels, kernel_size, activation=None, bias=True):
        """Construct a ConvolutionModule object.

        :raises ValueError: if ``kernel_size`` is not odd
        """
        super().__init__()
        # kernel_size must be odd for 'SAME' padding. Raise instead of assert
        # so the check is not stripped when Python runs with -O.
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(
                f"kernel_size must be an odd number, got {kernel_size}"
            )
        if activation is None:
            # A default like `activation=nn.ReLU()` is evaluated once at
            # definition time, so every instance would share (and register)
            # the same module object; build a new one per instance instead.
            activation = nn.ReLU()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # groups=channels makes this a depthwise conv (one filter per channel);
        # padding of (kernel_size - 1) // 2 with stride 1 keeps the length.
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        :param torch.Tensor x: (batch, time, channels)
        :return torch.Tensor: convoluted `value` (batch, time, channels)

        """
        # exchange the temporal dimension and the feature dimension:
        # Conv1d expects (batch, channels, time)
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)
|