mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
86 lines
2.8 KiB
Python
86 lines
2.8 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from .common.batch_norm_conv import BatchNormConv
|
||
|
from .common.highway_network import HighwayNetwork
|
||
|
|
||
|
class CBHG(nn.Module):
    """Convolution bank, 1-D projections, highway layers and a bidirectional GRU.

    The forward pass runs the input through ``K`` parallel ``BatchNormConv``
    layers (kernel sizes ``1..K``), concatenates their outputs along the
    channel axis, max-pools along time, applies two ``BatchNormConv``
    projections, adds the original input back as a residual, and finally
    feeds the result through a stack of ``HighwayNetwork`` layers and a
    bidirectional GRU.

    Args:
        K: Largest kernel size in the convolution bank; one convolution is
            created for every kernel size from 1 to ``K`` inclusive.
        in_channels: Channel count of the input passed to ``forward``.
        channels: Output channels of every bank convolution, width of the
            highway layers, and input width of the GRU.
        proj_channels: Sequence whose first two entries are the output
            channel counts of the two projection convolutions.
        num_highways: Number of ``HighwayNetwork`` layers to stack.

    Note:
        The residual addition in ``forward`` adds the raw input to the output
        of the second projection, so ``proj_channels[1]`` presumably has to
        match ``in_channels`` (assuming ``BatchNormConv`` emits the requested
        number of channels -- confirm against its definition).
    """

    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        # One convolution per kernel size 1..K, all reading the same input.
        self.bank_kernels = list(range(1, K + 1))
        self.conv1d_bank = nn.ModuleList(
            [BatchNormConv(in_channels, channels, k) for k in self.bank_kernels]
        )

        # kernel_size=2 with padding=1 yields one extra time step, which
        # `forward` trims off again.
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary. The width reaching the highways
        # is whatever `conv_project2` was asked to emit, i.e. proj_channels[1];
        # using that index (rather than the last entry) keeps this check
        # consistent with the projection even when `proj_channels` has more
        # than two entries.
        projection_width = proj_channels[1]
        self.highway_mismatch = projection_width != channels
        if self.highway_mismatch:
            self.pre_highway = nn.Linear(projection_width, channels, bias=False)

        self.highways = nn.ModuleList(
            [HighwayNetwork(channels) for _ in range(num_highways)]
        )

        # Bidirectional with hidden size channels // 2, so the two directions
        # concatenated give 2 * (channels // 2) output features.
        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()

    def forward(self, x):
        """Run the full CBHG stack on ``x``.

        Args:
            x: Input tensor; the code indexes it as three-dimensional with
                channels on axis 1 and time on the last axis (presumably
                ``(batch, in_channels, seq_len)`` -- confirm with callers).

        Returns:
            The GRU output sequence with time on axis 1 (``batch_first``) and
            ``2 * (channels // 2)`` features on the last axis. The GRU's final
            hidden state is discarded.
        """
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self.rnn.flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)

        # Convolution bank: trim every branch to the input length so all
        # branches can be stacked along the channel axis.
        conv_bank = torch.cat(
            [conv(x)[:, :, :seq_len] for conv in self.conv1d_bank], dim=1
        )

        # dump the last padding to fit residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual Connect
        x = x + residual

        # Through the highways (features moved to the last axis first)
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for highway in self.highways:
            x = highway(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Call `flatten_parameters` on every rnn registered in `_to_flatten`.

        Used to improve efficiency and avoid PyTorch warning about
        non-contiguous RNN weights.
        """
        for rnn in self._to_flatten:
            rnn.flatten_parameters()
|
||
|
|