GAN training now supports DistributedDataParallel (DDP) (#558)

* The new vocoder Fre-GAN is now supported

* Improved some fregan details

* Fixed the problem that the existing model could not be loaded to continue training when training GAN

* Updated reference papers

* GAN training now supports DistributedDataParallel (DDP)

* Added requirements.txt

* GAN training uses single card training by default

* Added note about GAN vocoder training with multiple GPUs
This commit is contained in:
flysmart 2022-05-22 16:24:50 +08:00 committed by GitHub
parent e726c2eb12
commit 05f886162c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 70 deletions

View File

@ -90,6 +90,7 @@
* 训练fregan声码器:
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
* 将GAN声码器的训练切换为多GPU模式修改GAN文件夹下.json文件中的"num_gpus"参数
### 3. 启动程序或工具箱
您可以尝试使用以下命令:

View File

@ -25,3 +25,4 @@ streamlit==1.8.0
PyYAML==5.4.1
torch_complex
espnet
PyWavelets

View File

@ -1,25 +0,0 @@
# Fre-GAN Vocoder
[Fre-GAN: Adversarial Frequency-consistent Audio Synthesis](https://arxiv.org/abs/2106.02297)
## Training:
```
python train.py --config config.json
```
## Citation:
```
@misc{kim2021fregan,
title={Fre-GAN: Adversarial Frequency-consistent Audio Synthesis},
author={Ji-Hoon Kim and Sang-Hoon Lee and Ji-Hyun Lee and Seong-Whan Lee},
year={2021},
eprint={2106.02297},
archivePrefix={arXiv},
primaryClass={eess.AS}
}
```
## Note
* For more complete and end to end Voice cloning or Text to Speech (TTS) toolbox please visit [Deepsync Technologies](https://deepsync.co/).
## References:
* [Hi-Fi-GAN repo](https://github.com/jik876/hifi-gan)
* [WaveSNet repo](https://github.com/LiQiufu/WaveSNet)

View File

@ -7,6 +7,7 @@
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"disc_start_step":0,
"upsample_rates": [5,5,2,2,2],

View File

@ -1 +0,0 @@
PyWavelets

View File

@ -135,8 +135,7 @@ def train(rank, a, h):
h.win_size,
h.fmin, h.fmax_for_loss)
if steps > h.disc_start_step:
optim_d.zero_grad()
# MPD
@ -162,7 +161,7 @@ def train(rank, a, h):
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
if steps > h.disc_start_step:
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
@ -170,7 +169,8 @@ def train(rank, a, h):
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
else:
loss_gen_all = loss_mel
loss_gen_all.backward()
optim_g.step()

View File

@ -7,6 +7,7 @@
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"disc_start_step":0,
"upsample_rates": [5,5,4,2],
"upsample_kernel_sizes": [10,10,8,4],
@ -27,5 +28,11 @@
"fmax": 7600,
"fmax_for_loss": null,
"num_workers": 4
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -137,7 +137,7 @@ def train(rank, a, h):
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
if steps > h.disc_start_step:
optim_d.zero_grad()
# MPD
@ -159,6 +159,7 @@ def train(rank, a, h):
# L1 Mel-Spectrogram Loss
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
if steps > h.disc_start_step:
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
@ -166,6 +167,8 @@ def train(rank, a, h):
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
loss_gen_all = loss_mel
loss_gen_all.backward()
optim_g.step()

View File

@ -6,7 +6,8 @@ from utils.util import AttrDict
from pathlib import Path
import argparse
import json
import torch
import torch.multiprocessing as mp
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@ -69,11 +70,23 @@ if __name__ == "__main__":
with open(args.config) as f:
json_config = json.load(f)
h = AttrDict(json_config)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
else:
train_hifigan(0, args, h)
elif args.vocoder_type == "fregan":
with open('vocoder/fregan/config.json') as f:
json_config = json.load(f)
h = AttrDict(json_config)
if h.num_gpus > 1:
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print('Batch size per GPU :', h.batch_size)
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
else:
train_fregan(0, args, h)