# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.

# DWT code borrowed from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py

import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['DWT_1D']

# Padding modes accepted for ``pad_type``; each is forwarded to ``F.pad(mode=...)``.
Pad_Mode = ['constant', 'reflect', 'replicate', 'circular']


class DWT_1D(nn.Module):
    """Single-level 1-D discrete wavelet transform built from two strided ``conv1d`` calls.

    The input is padded and then cross-correlated (stride 2) with a low-pass and a
    high-pass filter taken from a PyWavelets wavelet, giving a ``(low, high)`` pair.

    Args:
        pad_type: padding mode passed to ``F.pad``; must be one of ``Pad_Mode``.
        wavename: wavelet name passed to ``pywt.Wavelet`` (e.g. ``'haar'``).
        stride: downsampling stride; only ``2`` is supported.
        in_channels: required size of dimension 1 of the input.
        out_channels: output channels of each band; defaults to ``in_channels``.
        groups: ``conv1d`` groups; defaults to ``in_channels`` (one filter per channel).
            Must divide both ``in_channels`` and ``out_channels``.
        kernel_size: filter length. Must be ``None`` unless ``trainable``. When given,
            it must be at least the wavelet's filter length; the wavelet taps are then
            placed in the middle of a zero-filled kernel of this size.
        trainable: if True the filters are ``nn.Parameter`` objects initialised from the
            wavelet; otherwise they are fixed (non-persistent) buffers.
    """

    def __init__(self, pad_type='reflect', wavename='haar', stride=2, in_channels=1,
                 out_channels=None, groups=None, kernel_size=None, trainable=False):
        super().__init__()
        self.trainable = trainable
        self.kernel_size = kernel_size
        if not self.trainable:
            # Fixed filters always use the wavelet's native length.
            assert self.kernel_size is None
        self.in_channels = in_channels
        self.out_channels = self.in_channels if out_channels is None else out_channels
        self.groups = self.in_channels if groups is None else groups
        assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
        # conv1d also requires out_channels to be divisible by groups; fail here at
        # construction instead of on the first forward call.
        assert self.out_channels % self.groups == 0
        self.stride = stride
        assert self.stride == 2
        self.wavename = wavename
        self.pad_type = pad_type
        assert self.pad_type in Pad_Mode
        self.get_filters()
        self.initialization()

    def get_filters(self):
        """Build the 1-D filters ``self.filt_low`` / ``self.filt_high`` of length ``kernel_size``.

        The taps come from the wavelet's ``rec_lo`` / ``rec_hi``. ``conv1d`` computes a
        cross-correlation, and for orthogonal wavelets PyWavelets' reconstruction filters
        are the time-reversed decomposition filters, so this amounts to convolving with
        the analysis filters.
        NOTE(review): for biorthogonal wavelets ``rec_*`` are not the reversed analysis
        filters -- confirm this choice is intended for those families.

        Sets ``self.kernel_size`` to the wavelet filter length when it is ``None``.
        """
        wavelet = pywt.Wavelet(self.wavename)
        band_low = torch.tensor(wavelet.rec_lo)
        band_high = torch.tensor(wavelet.rec_hi)
        length_band = band_low.size(0)
        if self.kernel_size is None:
            self.kernel_size = length_band
        assert self.kernel_size >= length_band
        # Centre the taps in the zero kernel; with an odd surplus the extra zero goes right.
        start = (self.kernel_size - length_band) // 2
        stop = start + length_band
        self.filt_low = torch.zeros(self.kernel_size)
        self.filt_high = torch.zeros(self.kernel_size)
        self.filt_low[start:stop] = band_low
        self.filt_high[start:stop] = band_high

    def initialization(self):
        """Expand the 1-D filters into ``conv1d`` weights and register them on the module.

        The weights have shape ``(out_channels, in_channels // groups, kernel_size)``.
        When trainable they are registered as ``nn.Parameter`` objects. Otherwise they
        are registered as non-persistent buffers, so ``module.to()`` / ``.cuda()`` moves
        them and ``state_dict`` is unaffected. Also computes the symmetric padding used
        by ``forward``.
        """
        weight_shape = (self.out_channels, self.in_channels // self.groups, 1)
        filter_low = self.filt_low[None, None, :].repeat(weight_shape)
        filter_high = self.filt_high[None, None, :].repeat(weight_shape)
        if self.trainable:
            self.filter_low = nn.Parameter(filter_low)
            self.filter_high = nn.Parameter(filter_high)
        else:
            # The device follows the module (no forced .cuda() at construction);
            # persistent=False keeps these recomputable filters out of state_dict.
            self.register_buffer('filter_low', filter_low, persistent=False)
            self.register_buffer('filter_high', filter_high, persistent=False)
        # Equals kernel_size // 2 - 1 for even kernels and kernel_size // 2 for odd ones.
        pad = (self.kernel_size - 1) // 2
        self.pad_sizes = [pad, pad]

    def forward(self, input):
        """Apply the wavelet transform.

        Args:
            input: tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Tuple ``(low, high)``: the stride-2 ``conv1d`` outputs with the low-pass and
            the high-pass filter, respectively.
        """
        assert isinstance(input, torch.Tensor)
        assert len(input.size()) == 3
        assert input.size()[1] == self.in_channels
        input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
        # Match the input's device and dtype so that half/double inputs, and inputs on
        # a device the module was not moved to, still work.
        filter_low = self.filter_low.to(device=input.device, dtype=input.dtype)
        filter_high = self.filter_high.to(device=input.device, dtype=input.dtype)
        low = F.conv1d(input, filter_low, stride=self.stride, groups=self.groups)
        high = F.conv1d(input, filter_high, stride=self.stride, groups=self.groups)
        return low, high