mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
77 lines
3.3 KiB
Python
77 lines
3.3 KiB
Python
|
# Copyright (c) 2019, Adobe Inc. All rights reserved.
|
||
|
#
|
||
|
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
|
||
|
# 4.0 International Public License. To view a copy of this license, visit
|
||
|
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
|
||
|
|
||
|
# DWT code borrowed from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py
|
||
|
|
||
|
|
||
|
import pywt
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
# Public API of this module.
__all__ = ["DWT_1D"]

# Padding modes DWT_1D accepts for its `pad_type` argument; each is forwarded
# verbatim as the `mode` of torch.nn.functional.pad.
Pad_Mode = [
    "constant",
    "reflect",
    "replicate",
    "circular",
]
|
||
|
|
||
|
|
||
|
class DWT_1D(nn.Module):
    """Single-level 1-D discrete wavelet transform built on ``F.conv1d``.

    Low- and high-pass filters are taken from ``pywt.Wavelet(wavename)``
    (its ``rec_lo`` / ``rec_hi`` coefficients), zero-padded around the centre
    to ``kernel_size`` taps, and applied with stride 2 after padding the input
    with ``pad_type``.

    Args:
        pad_type: ``mode`` passed to ``F.pad``; must be one of ``Pad_Mode``.
        wavename: wavelet name understood by ``pywt.Wavelet``.
        stride: convolution stride; only ``2`` is supported.
        in_channels: expected size of dim 1 of the input.
        out_channels: number of output channels; defaults to ``in_channels``.
        groups: ``groups`` argument of ``F.conv1d``; defaults to
            ``in_channels``. Must divide both ``in_channels`` and
            ``out_channels``.
        kernel_size: filter length. Must be ``None`` unless ``trainable``;
            defaults to the wavelet's own filter length.
        trainable: if True the filters are ``nn.Parameter`` s initialised from
            the wavelet; otherwise they are fixed tensors.

    Raises:
        AssertionError: on any invalid constructor argument or input tensor.

    ``forward`` takes a 3-D tensor ``(batch, in_channels, length)`` and returns
    the tuple ``(low, high)`` of ``F.conv1d`` outputs.
    """

    def __init__(self, pad_type='reflect', wavename='haar',
                 stride=2, in_channels=1, out_channels=None, groups=None,
                 kernel_size=None, trainable=False):
        super().__init__()
        self.trainable = trainable
        self.kernel_size = kernel_size
        if not self.trainable:
            # Fixed filters always use the wavelet's native length.
            assert self.kernel_size is None
        self.in_channels = in_channels
        self.out_channels = self.in_channels if out_channels is None else out_channels
        self.groups = self.in_channels if groups is None else groups
        assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
        # conv1d weights have shape (out_channels, in_channels // groups, k),
        # which requires out_channels to be divisible by groups as well; check
        # here instead of failing later inside forward().
        assert self.out_channels % self.groups == 0
        self.stride = stride
        assert self.stride == 2
        self.wavename = wavename
        self.pad_type = pad_type
        assert self.pad_type in Pad_Mode
        self.get_filters()
        self.initialization()

    def get_filters(self):
        """Build 1-D ``self.filt_low`` / ``self.filt_high`` of length ``kernel_size``.

        The wavelet's coefficients are placed in a zero vector, with
        ``(kernel_size - len) // 2`` leading zeros and the remainder trailing.
        Sets ``self.kernel_size`` to the wavelet length if it was ``None``.
        """
        wavelet = pywt.Wavelet(self.wavename)
        # NOTE(review): rec_* filters are later applied with F.conv1d, which is
        # cross-correlation; for orthogonal wavelets this matches convolving
        # with dec_*. Confirm the intent for biorthogonal wavelets.
        band_low = torch.tensor(wavelet.rec_lo)
        band_high = torch.tensor(wavelet.rec_hi)
        length_band = band_low.size()[0]
        self.kernel_size = length_band if self.kernel_size is None else self.kernel_size
        assert self.kernel_size >= length_band
        # Slice [a:b] selects exactly `length_band` centred positions;
        # b == None means "to the end" when there is no trailing padding.
        a = (self.kernel_size - length_band) // 2
        b = -(self.kernel_size - length_band - a)
        b = None if b == 0 else b
        self.filt_low = torch.zeros(self.kernel_size)
        self.filt_high = torch.zeros(self.kernel_size)
        self.filt_low[a:b] = band_low
        self.filt_high[a:b] = band_high

    def initialization(self):
        """Expand the 1-D filters into conv1d weights and compute padding.

        Produces ``self.filter_low`` / ``self.filter_high`` of shape
        ``(out_channels, in_channels // groups, kernel_size)`` (wrapped in
        ``nn.Parameter`` when trainable) and ``self.pad_sizes``.
        """
        self.filter_low = self.filt_low[None, None, :].repeat(
            (self.out_channels, self.in_channels // self.groups, 1))
        self.filter_high = self.filt_high[None, None, :].repeat(
            (self.out_channels, self.in_channels // self.groups, 1))
        if torch.cuda.is_available():
            # Eagerly place filters on the current CUDA device; forward()
            # still moves them to the input's device on every call.
            self.filter_low = self.filter_low.cuda()
            self.filter_high = self.filter_high.cuda()
        if self.trainable:
            self.filter_low = nn.Parameter(self.filter_low)
            self.filter_high = nn.Parameter(self.filter_high)
        # (k - 1) // 2 on each side: k/2 - 1 for even k, k // 2 for odd k.
        pad = (self.kernel_size - 1) // 2
        self.pad_sizes = [pad, pad]

    def forward(self, input):
        """Return ``(low, high)`` wavelet sub-bands of a 3-D ``input``.

        Args:
            input: tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Tuple of two tensors, the stride-2 ``F.conv1d`` responses of the
            padded input to the low- and high-pass filters.
        """
        assert isinstance(input, torch.Tensor)
        assert input.dim() == 3
        assert input.size(1) == self.in_channels
        padded = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
        # Match device AND dtype: conv1d rejects mixed dtypes, so float64 or
        # float16 inputs previously failed against the float32 filters.
        low_weight = self.filter_low.to(device=padded.device, dtype=padded.dtype)
        high_weight = self.filter_high.to(device=padded.device, dtype=padded.dtype)
        return (F.conv1d(padded, low_weight, stride=self.stride, groups=self.groups),
                F.conv1d(padded, high_weight, stride=self.stride, groups=self.groups))
|