mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
GAN training now supports DistributedDataParallel (DDP) (#558)
* The new vocoder Fre-GAN is now supported * Improved some fregan details * Fixed the problem that the existing model could not be loaded to continue training when training GAN * Updated reference papers * GAN training now supports DistributedDataParallel (DDP) * Added requirements.txt * GAN training uses single card training by default * Added note about GAN vocoder training with multiple GPUs
This commit is contained in:
parent
e726c2eb12
commit
05f886162c
|
@ -90,6 +90,7 @@
|
|||
* 训练fregan声码器:
|
||||
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
|
||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||
* 将GAN声码器的训练切换为多GPU模式:修改GAN文件夹下.json文件中的"num_gpus"参数
|
||||
### 3. 启动程序或工具箱
|
||||
您可以尝试使用以下命令:
|
||||
|
||||
|
|
|
@ -25,3 +25,4 @@ streamlit==1.8.0
|
|||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
||||
PyWavelets
|
|
@ -1,25 +0,0 @@
|
|||
# Fre-GAN Vocoder
|
||||
[Fre-GAN: Adversarial Frequency-consistent Audio Synthesis](https://arxiv.org/abs/2106.02297)
|
||||
|
||||
## Training:
|
||||
```
|
||||
python train.py --config config.json
|
||||
```
|
||||
|
||||
## Citation:
|
||||
```
|
||||
@misc{kim2021fregan,
|
||||
title={Fre-GAN: Adversarial Frequency-consistent Audio Synthesis},
|
||||
author={Ji-Hoon Kim and Sang-Hoon Lee and Ji-Hyun Lee and Seong-Whan Lee},
|
||||
year={2021},
|
||||
eprint={2106.02297},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={eess.AS}
|
||||
}
|
||||
```
|
||||
## Note
|
||||
* For more complete and end to end Voice cloning or Text to Speech (TTS) toolbox please visit [Deepsync Technologies](https://deepsync.co/).
|
||||
|
||||
## References:
|
||||
* [Hi-Fi-GAN repo](https://github.com/jik876/hifi-gan)
|
||||
* [WaveSNet repo](https://github.com/LiQiufu/WaveSNet)
|
|
@ -7,6 +7,7 @@
|
|||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"disc_start_step":0,
|
||||
|
||||
|
||||
"upsample_rates": [5,5,2,2,2],
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
PyWavelets
|
|
@ -135,8 +135,7 @@ def train(rank, a, h):
|
|||
h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
|
||||
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
|
@ -162,7 +161,7 @@ def train(rank, a, h):
|
|||
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
|
||||
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
|
||||
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
|
@ -170,7 +169,8 @@ def train(rank, a, h):
|
|||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
|
||||
|
||||
else:
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
optim_g.step()
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"disc_start_step":0,
|
||||
|
||||
"upsample_rates": [5,5,4,2],
|
||||
"upsample_kernel_sizes": [10,10,8,4],
|
||||
|
@ -27,5 +28,11 @@
|
|||
"fmax": 7600,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -137,7 +137,7 @@ def train(rank, a, h):
|
|||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
|
@ -159,6 +159,7 @@ def train(rank, a, h):
|
|||
# L1 Mel-Spectrogram Loss
|
||||
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
||||
|
||||
if steps > h.disc_start_step:
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
|
@ -166,6 +167,8 @@ def train(rank, a, h):
|
|||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
optim_g.step()
|
||||
|
|
|
@ -6,7 +6,8 @@ from utils.util import AttrDict
|
|||
from pathlib import Path
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -69,11 +70,23 @@ if __name__ == "__main__":
|
|||
with open(args.config) as f:
|
||||
json_config = json.load(f)
|
||||
h = AttrDict(json_config)
|
||||
if h.num_gpus > 1:
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print('Batch size per GPU :', h.batch_size)
|
||||
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
|
||||
else:
|
||||
train_hifigan(0, args, h)
|
||||
elif args.vocoder_type == "fregan":
|
||||
with open('vocoder/fregan/config.json') as f:
|
||||
json_config = json.load(f)
|
||||
h = AttrDict(json_config)
|
||||
if h.num_gpus > 1:
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print('Batch size per GPU :', h.batch_size)
|
||||
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
|
||||
else:
|
||||
train_fregan(0, args, h)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user