import torch
import torch.nn as nn

from .common.batch_norm_conv import BatchNormConv
from .common.highway_network import HighwayNetwork


class CBHG(nn.Module):
    """Convolution bank + conv projections + highway layers + bidirectional GRU.

    The module applies, in order: a bank of `BatchNormConv` layers with
    kernel sizes 1..K, a stride-1 max-pool, two `BatchNormConv` projections,
    a residual connection with the input, an optional linear layer that
    adapts the feature width, `num_highways` `HighwayNetwork` layers and a
    bidirectional GRU.

    Args:
        K: Largest kernel size in the convolution bank; one bank layer is
            built for every kernel size from 1 to K inclusive.
        in_channels: Input channel count handed to every bank layer.
        channels: Output channel count of every bank layer. Also the width
            of the highway layers and the input size of the GRU.
        proj_channels: Indexable pair of output channel counts for the two
            projection layers.
        num_highways: Number of highway layers to stack.
    """

    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        # One bank layer per kernel size 1..K.
        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList(
            [BatchNormConv(in_channels, channels, k) for k in self.bank_kernels]
        )

        # kernel_size=2 / stride=1 / padding=1 yields one more time step
        # than it is given; `forward` trims the extra step.
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # The first projection consumes the channel-concatenated bank output.
        self.conv_project1 = BatchNormConv(
            len(self.bank_kernels) * channels, proj_channels[0], 3
        )
        # NOTE(review): `relu=False` presumably disables the activation of
        # the last projection - confirm against BatchNormConv.
        self.conv_project2 = BatchNormConv(
            proj_channels[0], proj_channels[1], 3, relu=False
        )

        # Fix the highway input if necessary: the highways work on
        # `channels` features, so adapt the projection width when it differs.
        self.highway_mismatch = proj_channels[-1] != channels
        if self.highway_mismatch:
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)

        self.highways = nn.ModuleList(
            [HighwayNetwork(channels) for _ in range(num_highways)]
        )

        # Bidirectional with hidden size `channels // 2`, so the output has
        # 2 * (channels // 2) features per time step.
        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the CBHG stack on `x`.

        Args:
            x: Input tensor with time on the last axis and channels on
                axis 1, i.e. (batch, in_channels, seq_len).

        Returns:
            The GRU output sequence, batch first:
            (batch, seq_len, 2 * (channels // 2)).
        """
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self.rnn.flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)

        # Convolution bank: run every kernel size on the same input and trim
        # each result back to the input length.
        # NOTE(review): presumably BatchNormConv pads, so outputs can be
        # longer than the input - confirm.
        bank_outputs = [conv(x)[:, :, :seq_len] for conv in self.conv1d_bank]

        # Stack along the channel axis
        stacked = torch.cat(bank_outputs, dim=1)

        # Dump the last padding to fit residual
        x = self.maxpool(stacked)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual connection.
        # NOTE(review): requires the projection output to be
        # broadcast-compatible with the input, presumably
        # proj_channels[1] == in_channels - confirm against callers.
        x = x + residual

        # Through the highways; they operate on the last axis, so move the
        # channels there first.
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for highway in self.highways:
            x = highway(x)

        # And then the RNN; the final hidden state is discarded.
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Call `flatten_parameters()` on every RNN registered in `_to_flatten`.

        Used to improve efficiency and avoid the PyTorch warning about
        non-contiguous RNN weights.
        """
        for rnn in self._to_flatten:
            rnn.flatten_parameters()