The new vocoder Fre-GAN is now supported (#546)

* The new vocoder Fre-GAN is now supported

* Improved some fregan details
pull/548/head^2
flysmart 2022-05-12 12:27:17 +08:00 committed by GitHub
parent c5d03fb3cb
commit 0caed984e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1624 additions and 1 deletions

View File

@ -87,7 +87,9 @@
* 训练hifigan声码器:
`python vocoder_train.py <trainid> <datasets_root> hifigan`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
* 训练fregan声码器:
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
### 3. 启动程序或工具箱
您可以尝试使用以下命令:

View File

@ -3,6 +3,7 @@ from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder.wavernn import inference as rnn_vocoder
from vocoder.hifigan import inference as gan_vocoder
from vocoder.fregan import inference as fgan_vocoder
from pathlib import Path
from time import perf_counter as timer
from toolbox.utterance import Utterance
@ -451,6 +452,15 @@ class Toolbox:
return
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
elif model_fpath.name is not None and model_fpath.name.find("fregan") > -1:
vocoder = fgan_vocoder
self.ui.log("set fregan as vocoder")
# search a config file
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
if self.vc_mode and self.ui.current_extractor_fpath is None:
return
if len(model_config_fpaths) > 0:
model_config_fpath = model_config_fpaths[0]
else:
vocoder = rnn_vocoder
self.ui.log("set wavernn as vocoder")

129
vocoder/fregan/.gitignore vendored Normal file
View File

@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

21
vocoder/fregan/LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021 Rishikesh (ऋषिकेश)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

25
vocoder/fregan/README.md Normal file
View File

@ -0,0 +1,25 @@
# Fre-GAN Vocoder
[Fre-GAN: Adversarial Frequency-consistent Audio Synthesis](https://arxiv.org/abs/2106.02297)
## Training:
```
python train.py --config config.json
```
## Citation:
```
@misc{kim2021fregan,
title={Fre-GAN: Adversarial Frequency-consistent Audio Synthesis},
author={Ji-Hoon Kim and Sang-Hoon Lee and Ji-Hyun Lee and Seong-Whan Lee},
year={2021},
eprint={2106.02297},
archivePrefix={arXiv},
primaryClass={eess.AS}
}
```
## Note
* For more complete and end to end Voice cloning or Text to Speech (TTS) toolbox please visit [Deepsync Technologies](https://deepsync.co/).
## References:
* [Hi-Fi-GAN repo](https://github.com/jik876/hifi-gan)
* [WaveSNet repo](https://github.com/LiQiufu/WaveSNet)

View File

@ -0,0 +1,41 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 16,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [5,5,2,2,2],
"upsample_kernel_sizes": [10,10,4,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1, 3, 5, 7], [1,3,5,7], [1,3,5,7]],
"segment_size": 6400,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 200,
"win_size": 800,
"sampling_rate": 16000,
"fmin": 0,
"fmax": 7600,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -0,0 +1,303 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from vocoder.fregan.utils import get_padding
from vocoder.fregan.stft_loss import stft
from vocoder.fregan.dwt import DWT_1D
LRELU_SLOPE = 0.1
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
super(SpecDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.discriminators = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1,2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1,1), padding=(1, 1))),
])
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
with torch.no_grad():
y = y.squeeze(1)
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.get_device()))
y = y.unsqueeze(1)
for i, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
super(MultiResSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
self.dwt_proj1 = norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
self.dwt_proj2 = norm_f(Conv2d(1, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.dwt_conv3 = norm_f(Conv1d(8, 1, 1))
self.dwt_proj3 = norm_f(Conv2d(1, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# DWT 1
x_d1_high1, x_d1_low1 = self.dwt1d(x)
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# 1d to 2d
b, c, t = x_d1.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d1 = F.pad(x_d1, (0, n_pad), "reflect")
t = t + n_pad
x_d1 = x_d1.view(b, c, t // self.period, self.period)
x_d1 = self.dwt_proj1(x_d1)
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
# 1d to 2d
b, c, t = x_d2.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d2 = F.pad(x_d2, (0, n_pad), "reflect")
t = t + n_pad
x_d2 = x_d2.view(b, c, t // self.period, self.period)
x_d2 = self.dwt_proj2(x_d2)
# DWT 3
x_d3_high1, x_d3_low1 = self.dwt1d(x_d2_high1)
x_d3_high2, x_d3_low2 = self.dwt1d(x_d2_low1)
x_d3_high3, x_d3_low3 = self.dwt1d(x_d2_high2)
x_d3_high4, x_d3_low4 = self.dwt1d(x_d2_low2)
x_d3 = self.dwt_conv3(
torch.cat([x_d3_high1, x_d3_low1, x_d3_high2, x_d3_low2, x_d3_high3, x_d3_low3, x_d3_high4, x_d3_low4],
dim=1))
# 1d to 2d
b, c, t = x_d3.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x_d3 = F.pad(x_d3, (0, n_pad), "reflect")
t = t + n_pad
x_d3 = x_d3.view(b, c, t // self.period, self.period)
x_d3 = self.dwt_proj3(x_d3)
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
i = 0
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if i == 0:
x = torch.cat([x, x_d1], dim=2)
elif i == 1:
x = torch.cat([x, x_d2], dim=2)
elif i == 2:
x = torch.cat([x, x_d3], dim=2)
else:
x = x
i = i + 1
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class ResWiseMultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super(ResWiseMultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
DiscriminatorP(2),
DiscriminatorP(3),
DiscriminatorP(5),
DiscriminatorP(7),
DiscriminatorP(11),
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 128, 15, 1, padding=7))
self.dwt_conv2 = norm_f(Conv1d(4, 128, 41, 2, padding=20))
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
# DWT 1
x_d1_high1, x_d1_low1 = self.dwt1d(x)
x_d1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
x_d2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
i = 0
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
if i == 0:
x = torch.cat([x, x_d1], dim=2)
if i == 1:
x = torch.cat([x, x_d2], dim=2)
i = i + 1
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class ResWiseMultiScaleDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(ResWiseMultiScaleDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.dwt1d = DWT_1D()
self.dwt_conv1 = norm_f(Conv1d(2, 1, 1))
self.dwt_conv2 = norm_f(Conv1d(4, 1, 1))
self.discriminators = nn.ModuleList([
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
# DWT 1
y_hi, y_lo = self.dwt1d(y)
y_1 = self.dwt_conv1(torch.cat([y_hi, y_lo], dim=1))
x_d1_high1, x_d1_low1 = self.dwt1d(y_hat)
y_hat_1 = self.dwt_conv1(torch.cat([x_d1_high1, x_d1_low1], dim=1))
# DWT 2
x_d2_high1, x_d2_low1 = self.dwt1d(y_hi)
x_d2_high2, x_d2_low2 = self.dwt1d(y_lo)
y_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
x_d2_high1, x_d2_low1 = self.dwt1d(x_d1_high1)
x_d2_high2, x_d2_low2 = self.dwt1d(x_d1_low1)
y_hat_2 = self.dwt_conv2(torch.cat([x_d2_high1, x_d2_low1, x_d2_high2, x_d2_low2], dim=1))
for i, d in enumerate(self.discriminators):
if i == 1:
y = y_1
y_hat = y_hat_1
if i == 2:
y = y_2
y_hat = y_hat_2
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs

76
vocoder/fregan/dwt.py Normal file
View File

@ -0,0 +1,76 @@
# Copyright (c) 2019, Adobe Inc. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
# 4.0 International Public License. To view a copy of this license, visit
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
# DWT code borrow from https://github.com/LiQiufu/WaveSNet/blob/12cb9d24208c3d26917bf953618c30f0c6b0f03d/DWT_IDWT/DWT_IDWT_layer.py
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['DWT_1D']
Pad_Mode = ['constant', 'reflect', 'replicate', 'circular']
class DWT_1D(nn.Module):
def __init__(self, pad_type='reflect', wavename='haar',
stride=2, in_channels=1, out_channels=None, groups=None,
kernel_size=None, trainable=False):
super(DWT_1D, self).__init__()
self.trainable = trainable
self.kernel_size = kernel_size
if not self.trainable:
assert self.kernel_size == None
self.in_channels = in_channels
self.out_channels = self.in_channels if out_channels == None else out_channels
self.groups = self.in_channels if groups == None else groups
assert isinstance(self.groups, int) and self.in_channels % self.groups == 0
self.stride = stride
assert self.stride == 2
self.wavename = wavename
self.pad_type = pad_type
assert self.pad_type in Pad_Mode
self.get_filters()
self.initialization()
def get_filters(self):
wavelet = pywt.Wavelet(self.wavename)
band_low = torch.tensor(wavelet.rec_lo)
band_high = torch.tensor(wavelet.rec_hi)
length_band = band_low.size()[0]
self.kernel_size = length_band if self.kernel_size == None else self.kernel_size
assert self.kernel_size >= length_band
a = (self.kernel_size - length_band) // 2
b = - (self.kernel_size - length_band - a)
b = None if b == 0 else b
self.filt_low = torch.zeros(self.kernel_size)
self.filt_high = torch.zeros(self.kernel_size)
self.filt_low[a:b] = band_low
self.filt_high[a:b] = band_high
def initialization(self):
self.filter_low = self.filt_low[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
self.filter_high = self.filt_high[None, None, :].repeat((self.out_channels, self.in_channels // self.groups, 1))
if torch.cuda.is_available():
self.filter_low = self.filter_low.cuda()
self.filter_high = self.filter_high.cuda()
if self.trainable:
self.filter_low = nn.Parameter(self.filter_low)
self.filter_high = nn.Parameter(self.filter_high)
if self.kernel_size % 2 == 0:
self.pad_sizes = [self.kernel_size // 2 - 1, self.kernel_size // 2 - 1]
else:
self.pad_sizes = [self.kernel_size // 2, self.kernel_size // 2]
def forward(self, input):
assert isinstance(input, torch.Tensor)
assert len(input.size()) == 3
assert input.size()[1] == self.in_channels
input = F.pad(input, pad=self.pad_sizes, mode=self.pad_type)
return F.conv1d(input, self.filter_low.to(input.device), stride=self.stride, groups=self.groups), \
F.conv1d(input, self.filter_high.to(input.device), stride=self.stride, groups=self.groups)

210
vocoder/fregan/generator.py Normal file
View File

@ -0,0 +1,210 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from vocoder.fregan.utils import init_weights, get_padding
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5, 7)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[3],
padding=get_padding(kernel_size, dilation[3])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class FreGAN(torch.nn.Module):
def __init__(self, h, top_k=4):
super(FreGAN, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.upsample_rates = h.upsample_rates
self.up_kernels = h.upsample_kernel_sizes
self.cond_level = self.num_upsamples - top_k
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
self.cond_up = nn.ModuleList()
self.res_output = nn.ModuleList()
upsample_ = 1
kr = 80
for i, (u, k) in enumerate(zip(self.upsample_rates, self.up_kernels)):
# self.ups.append(weight_norm(
# ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
# k, u, padding=(k - u) // 2)))
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
h.upsample_initial_channel//(2**(i+1)),
k, u, padding=(u//2 + u%2), output_padding=u%2)))
if i > (self.num_upsamples - top_k):
self.res_output.append(
nn.Sequential(
nn.Upsample(scale_factor=u, mode='nearest'),
weight_norm(nn.Conv1d(h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2 ** (i + 1)), 1))
)
)
if i >= (self.num_upsamples - top_k):
self.cond_up.append(
weight_norm(
ConvTranspose1d(kr, h.upsample_initial_channel // (2 ** i),
self.up_kernels[i - 1], self.upsample_rates[i - 1],
padding=(self.upsample_rates[i-1]//2+self.upsample_rates[i-1]%2), output_padding=self.upsample_rates[i-1]%2))
)
kr = h.upsample_initial_channel // (2 ** i)
upsample_ *= u
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.cond_up.apply(init_weights)
self.res_output.apply(init_weights)
def forward(self, x):
mel = x
x = self.conv_pre(x)
output = None
for i in range(self.num_upsamples):
if i >= self.cond_level:
mel = self.cond_up[i - self.cond_level](mel)
x += mel
if i > self.cond_level:
if output is None:
output = self.res_output[i - self.cond_level - 1](x)
else:
output = self.res_output[i - self.cond_level - 1](output)
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
if output is not None:
output = output + x
x = F.leaky_relu(output)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
for l in self.cond_up:
remove_weight_norm(l)
for l in self.res_output:
remove_weight_norm(l[1])
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
'''
to run this, fix
from . import ResStack
into
from res_stack import ResStack
'''
if __name__ == '__main__':
'''
torch.Size([3, 80, 10])
torch.Size([3, 1, 2000])
4527362
'''
with open('config.json') as f:
data = f.read()
from utils import AttrDict
import json
json_config = json.loads(data)
h = AttrDict(json_config)
model = FreGAN(h)
c = torch.randn(3, 80, 10) # (B, channels, T).
print(c.shape)
y = model(c) # (B, 1, T ** prod(upsample_scales)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560]) # For normal melgan torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

View File

@ -0,0 +1,74 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import json
import torch
from utils.util import AttrDict
from vocoder.fregan.generator import FreGAN
generator = None # type: FreGAN
output_sample_rate = None
_device = None
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print("Loading '{}'".format(filepath))
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def load_model(weights_fpath, config_fpath=None, verbose=True):
global generator, _device, output_sample_rate
if verbose:
print("Building fregan")
if config_fpath == None:
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths) > 0:
config_fpath = model_config_fpaths[0]
else:
config_fpath = "./vocoder/fregan/config.json"
with open(config_fpath) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
output_sample_rate = h.sampling_rate
torch.manual_seed(h.seed)
if torch.cuda.is_available():
# _model = _model.cuda()
_device = torch.device('cuda')
else:
_device = torch.device('cpu')
generator = FreGAN(h).to(_device)
state_dict_g = load_checkpoint(
weights_fpath, _device
)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()
def is_loaded():
return generator is not None
def infer_waveform(mel, progress_callback=None):
if generator is None:
raise Exception("Please load fre-gan in memory before using it")
mel = torch.FloatTensor(mel).to(_device)
mel = mel.unsqueeze(0)
with torch.no_grad():
y_g_hat = generator(mel)
audio = y_g_hat.squeeze()
audio = audio.cpu().numpy()
return audio, output_sample_rate

35
vocoder/fregan/loss.py Normal file
View File

@ -0,0 +1,35 @@
import torch
def feature_loss(fmap_r, fmap_g):
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss*2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1-dr)**2)
g_loss = torch.mean(dg**2)
loss += (r_loss + g_loss)
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(disc_outputs):
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1-dg)**2)
gen_losses.append(l)
loss += l
return loss, gen_losses

View File

@ -0,0 +1,176 @@
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
mel_basis = {}
hann_window = {}
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global mel_basis, hann_window
if fmax not in mel_basis:
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
center=center, pad_mode='reflect', normalized=False, onesided=True)
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
def get_dataset_filelist(a):
#with open(a.input_training_file, 'r', encoding='utf-8') as fi:
# training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
# for x in fi.read().split('\n') if len(x) > 0]
#with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
# validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
# for x in fi.read().split('\n') if len(x) > 0]
files = os.listdir(a.input_wavs_dir)
random.shuffle(files)
files = [os.path.join(a.input_wavs_dir, f) for f in files]
training_files = files[: -int(len(files) * 0.05)]
validation_files = files[-int(len(files) * 0.05):]
return training_files, validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(self, training_files, segment_size, n_fft, num_mels,
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.cached_wav = None
self.n_cache_reuse = n_cache_reuse
self._cache_ref_count = 0
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
def __getitem__(self, index):
filename = self.audio_files[index]
if self._cache_ref_count == 0:
#audio, sampling_rate = load_wav(filename)
#audio = audio / MAX_WAV_VALUE
audio = np.load(filename)
if not self.fine_tuning:
audio = normalize(audio) * 0.95
self.cached_wav = audio
#if sampling_rate != self.sampling_rate:
# raise ValueError("{} SR doesn't match target {} SR".format(
# sampling_rate, self.sampling_rate))
self._cache_ref_count = self.n_cache_reuse
else:
audio = self.cached_wav
self._cache_ref_count -= 1
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0)
if not self.fine_tuning:
if self.split:
if audio.size(1) >= self.segment_size:
max_audio_start = audio.size(1) - self.segment_size
audio_start = random.randint(0, max_audio_start)
audio = audio[:, audio_start:audio_start+self.segment_size]
else:
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
center=False)
else:
mel_path = os.path.join(self.base_mels_path, "mel" + "-" + filename.split("/")[-1].split("-")[-1])
mel = np.load(mel_path).T
#mel = np.load(
# os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0)
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start:mel_start + frames_per_seg]
audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
else:
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
center=False)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
def __len__(self):
return len(self.audio_files)

201
vocoder/fregan/modules.py Normal file
View File

@ -0,0 +1,201 @@
import torch
import torch.nn.functional as F
class KernelPredictor(torch.nn.Module):
''' Kernel predictor for the location-variable convolutions
'''
def __init__(self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1}
):
'''
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int):
kpnet_
'''
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
l_b = conv_out_channels * conv_layers
padding = (kpnet_conv_size - 1) // 2
self.input_conv = torch.nn.Sequential(
torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_conv = torch.nn.Sequential(
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Dropout(kpnet_dropout),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size,
padding=padding, bias=True)
self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding,
bias=True)
def forward(self, c):
'''
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
'''
batch, cond_channels, cond_length = c.shape
c = self.input_conv(c)
c = c + self.residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length)
bias = b.contiguous().view(batch,
self.conv_layers,
self.conv_out_channels,
cond_length)
return kernels, bias
class LVCBlock(torch.nn.Module):
''' the location-variable convolutions
'''
def __init__(self,
in_channels,
cond_channels,
upsample_ratio,
conv_layers=4,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = conv_layers
self.conv_kernel_size = conv_kernel_size
self.convs = torch.nn.ModuleList()
self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels,
kernel_size=upsample_ratio*2, stride=upsample_ratio,
padding=upsample_ratio // 2 + upsample_ratio % 2,
output_padding=upsample_ratio % 2)
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=conv_layers,
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout
)
for i in range(conv_layers):
padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i)
self.convs.append(conv)
def forward(self, x, c):
''' forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
'''
batch, in_channels, in_length = x.shape
kernels, bias = self.kernel_predictor(c)
x = F.leaky_relu(x, 0.2)
x = self.upsample(x)
for i in range(self.conv_layers):
y = F.leaky_relu(x, 0.2)
y = self.convs[i](y)
y = F.leaky_relu(y, 0.2)
k = kernels[:, i, :, :, :, :]
b = bias[:, i, :, :]
y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
return x
def location_variable_convolution(self, x, kernel, bias, dilation, hop_size):
''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
'''
batch, in_channels, in_length = x.shape
batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), 'constant', 0)
x = x.unfold(3, dilation,
dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
o = o + bias.unsqueeze(-1).unsqueeze(-1)
o = o.contiguous().view(batch, out_channels, -1)
return o

View File

@ -0,0 +1 @@
PyWavelets

246
vocoder/fregan/train.py Normal file
View File

@ -0,0 +1,246 @@
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from vocoder.fregan.generator import FreGAN
from vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator
from vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss
from vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
def train(rank, a, h):
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_fregan')
a.checkpoint_path.mkdir(exist_ok=True)
a.training_epochs = 3100
a.stdout_interval = 5
a.checkpoint_interval = a.backup_every
a.summary_interval = 5000
a.validation_interval = 1000
a.fine_tuning = True
a.input_wavs_dir = a.syn_dir.joinpath("audio")
a.input_mels_dir = a.syn_dir.joinpath("mels")
if h.num_gpus > 1:
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
torch.cuda.manual_seed(h.seed)
device = torch.device('cuda:{:d}'.format(rank))
generator = FreGAN(h).to(device)
mpd = ResWiseMultiPeriodDiscriminator().to(device)
msd = ResWiseMultiScaleDiscriminator().to(device)
if rank == 0:
print(generator)
os.makedirs(a.checkpoint_path, exist_ok=True)
print("checkpoints directory : ", a.checkpoint_path)
if os.path.isdir(a.checkpoint_path):
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
steps = 0
if cp_g is None or cp_do is None:
state_dict_do = None
last_epoch = -1
else:
state_dict_g = load_checkpoint(cp_g, device)
state_dict_do = load_checkpoint(cp_do, device)
generator.load_state_dict(state_dict_g['generator'])
mpd.load_state_dict(state_dict_do['mpd'])
msd.load_state_dict(state_dict_do['msd'])
steps = state_dict_do['steps'] + 1
last_epoch = state_dict_do['epoch']
if h.num_gpus > 1:
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
if state_dict_do is not None:
optim_g.load_state_dict(state_dict_do['optim_g'])
optim_d.load_state_dict(state_dict_do['optim_d'])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
training_filelist, validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
sampler=train_sampler,
batch_size=h.batch_size,
pin_memory=True,
drop_last=True)
if rank == 0:
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir)
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True)
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
generator.train()
mpd.train()
msd.train()
for epoch in range(max(0, last_epoch), a.training_epochs):
if rank == 0:
start = time.time()
print("Epoch: {}".format(epoch + 1))
if h.num_gpus > 1:
train_sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
if rank == 0:
start_b = time.time()
x, y, _, y_mel = batch
x = torch.autograd.Variable(x.to(device, non_blocking=True))
y = torch.autograd.Variable(y.to(device, non_blocking=True))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y = y.unsqueeze(1)
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
h.win_size,
h.fmin, h.fmax_for_loss)
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MSD
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all.backward()
optim_d.step()
# Generator
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
loss_gen_all.backward()
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
with torch.no_grad():
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
format(steps, loss_gen_all, mel_error, time.time() - start_b))
# checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = "{}/g_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_fregan_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1
else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1
else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
# Tensorboard summary logging
if steps % a.summary_interval == 0:
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
# Validation
if steps % a.validation_interval == 0: # and steps != 0:
generator.eval()
torch.cuda.empty_cache()
val_err_tot = 0
with torch.no_grad():
for j, batch in enumerate(validation_loader):
x, y, _, y_mel = batch
y_g_hat = generator(x.to(device))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
#val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
if j <= 4:
if steps == 0:
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax)
sw.add_figure('generated/y_hat_spec_{}'.format(j),
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
val_err = val_err_tot / (j + 1)
sw.add_scalar("validation/mel_spec_error", val_err, steps)
generator.train()
steps += 1
scheduler_g.step()
scheduler_d.step()
if rank == 0:
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

65
vocoder/fregan/utils.py Normal file
View File

@ -0,0 +1,65 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
import shutil
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def apply_weight_norm(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
weight_norm(m)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print("Loading '{}'".format(filepath))
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def save_checkpoint(filepath, obj):
print("Saving checkpoint to {}".format(filepath))
torch.save(obj, filepath)
print("Complete.")
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '????????')
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]

View File

@ -1,6 +1,7 @@
from utils.argutils import print_args
from vocoder.wavernn.train import train
from vocoder.hifigan.train import train as train_hifigan
from vocoder.fregan.train import train as train_fregan
from utils.util import AttrDict
from pathlib import Path
import argparse
@ -61,11 +62,18 @@ if __name__ == "__main__":
# Process the arguments
if args.vocoder_type == "wavernn":
# Run the training wavernn
delattr(args, 'vocoder_type')
delattr(args, 'config')
train(**vars(args))
elif args.vocoder_type == "hifigan":
with open(args.config) as f:
json_config = json.load(f)
h = AttrDict(json_config)
train_hifigan(0, args, h)
elif args.vocoder_type == "fregan":
with open('vocoder/fregan/config.json') as f:
json_config = json.load(f)
h = AttrDict(json_config)
train_fregan(0, args, h)