import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function


def tile(x, count, dim=0):
    """Repeat every slice of ``x`` along ``dim`` ``count`` times in a row.

    The copies of one slice are adjacent in the result, so tiling
    ``[a, b]`` twice yields ``[a, a, b, b]`` (not ``[a, b, a, b]``).

    Args:
        x: Tensor to tile.
        count: Number of consecutive copies of each slice.
        dim: Dimension along which to tile.

    Returns:
        A contiguous tensor whose size along ``dim`` is ``count`` times
        that of ``x``; all other sizes are unchanged.
    """
    perm = list(range(x.dim()))
    if dim != 0:
        # Swap the target dimension to the front so the tiling below can
        # always work on dimension 0. The swap is its own inverse, so the
        # same permutation restores the layout afterwards.
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()

    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)

    # Flatten to (batch, features), repeat the feature axis `count` times,
    # then reinterpret each row as `count` consecutive output rows.
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )

    if dim != 0:
        x = x.permute(perm).contiguous()
    return x


class Linear(torch.nn.Module):
    """``torch.nn.Linear`` wrapper whose weight is Xavier-uniform initialised.

    Args:
        in_dim: Size of each input sample.
        out_dim: Size of each output sample.
        bias: Whether the wrapped layer learns an additive bias.
        w_init_gain: Nonlinearity name passed to
            ``torch.nn.init.calculate_gain`` to scale the initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.linear_layer(x)


class Conv1d(torch.nn.Module):
    """``torch.nn.Conv1d`` wrapper with Xavier-uniform weight initialisation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Width of the convolution kernel.
        stride: Convolution stride.
        padding: Padding added to both sides of the input. When ``None``
            it is derived as ``dilation * (kernel_size - 1) // 2``, which
            requires an odd ``kernel_size``.
        dilation: Spacing between kernel elements.
        bias: Whether the wrapped layer learns an additive bias.
        w_init_gain: Nonlinearity name passed to
            ``torch.nn.init.calculate_gain``.
        param: Optional parameter for the nonlinearity, forwarded to
            ``torch.nn.init.calculate_gain``.

    Raises:
        AssertionError: If ``padding`` is ``None`` and ``kernel_size`` is
            even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear',
                 param=None):
        super(Conv1d, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1, \
                "kernel_size must be odd when padding is not given"
            # Integer arithmetic: (kernel_size - 1) is even here, so the
            # floor division is exact and avoids a float round-trip.
            padding = dilation * (kernel_size - 1) // 2

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        # x: BxDxT (layout noted by the original author)
        return self.conv(x)