mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
GAN training now supports DistributedDataParallel (DDP) (#558)
* Added support for the new Fre-GAN vocoder
* Improved some Fre-GAN details
* Fixed a bug where an existing model could not be loaded to resume GAN training
* Updated reference papers
* GAN training now supports DistributedDataParallel (DDP)
* Added requirements.txt
* GAN training uses a single GPU by default
* Added a note about training the GAN vocoder with multiple GPUs
This commit is contained in:
parent
e726c2eb12
commit
05f886162c
|
@ -90,6 +90,7 @@
|
||||||
* 训练fregan声码器:
|
* 训练fregan声码器:
|
||||||
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
|
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
|
||||||
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
> `<trainid>`替换为你想要的标识,同一标识再次训练时会延续原模型
|
||||||
|
* 将GAN声码器的训练切换为多GPU模式:修改GAN文件夹下.json文件中的"num_gpus"参数
|
||||||
### 3. 启动程序或工具箱
|
### 3. 启动程序或工具箱
|
||||||
您可以尝试使用以下命令:
|
您可以尝试使用以下命令:
|
||||||
|
|
||||||
|
|
|
@ -24,4 +24,5 @@ tensorboard
|
||||||
streamlit==1.8.0
|
streamlit==1.8.0
|
||||||
PyYAML==5.4.1
|
PyYAML==5.4.1
|
||||||
torch_complex
|
torch_complex
|
||||||
espnet
|
espnet
|
||||||
|
PyWavelets
|
|
@ -1,25 +0,0 @@
|
||||||
# Fre-GAN Vocoder
|
|
||||||
[Fre-GAN: Adversarial Frequency-consistent Audio Synthesis](https://arxiv.org/abs/2106.02297)
|
|
||||||
|
|
||||||
## Training:
|
|
||||||
```
|
|
||||||
python train.py --config config.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Citation:
|
|
||||||
```
|
|
||||||
@misc{kim2021fregan,
|
|
||||||
title={Fre-GAN: Adversarial Frequency-consistent Audio Synthesis},
|
|
||||||
author={Ji-Hoon Kim and Sang-Hoon Lee and Ji-Hyun Lee and Seong-Whan Lee},
|
|
||||||
year={2021},
|
|
||||||
eprint={2106.02297},
|
|
||||||
archivePrefix={arXiv},
|
|
||||||
primaryClass={eess.AS}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
## Note
|
|
||||||
* For a more complete, end-to-end voice cloning or text-to-speech (TTS) toolbox, please visit [Deepsync Technologies](https://deepsync.co/).
|
|
||||||
|
|
||||||
## References:
|
|
||||||
* [Hi-Fi-GAN repo](https://github.com/jik876/hifi-gan)
|
|
||||||
* [WaveSNet repo](https://github.com/LiQiufu/WaveSNet)
|
|
|
@ -7,6 +7,7 @@
|
||||||
"adam_b2": 0.99,
|
"adam_b2": 0.99,
|
||||||
"lr_decay": 0.999,
|
"lr_decay": 0.999,
|
||||||
"seed": 1234,
|
"seed": 1234,
|
||||||
|
"disc_start_step":0,
|
||||||
|
|
||||||
|
|
||||||
"upsample_rates": [5,5,2,2,2],
|
"upsample_rates": [5,5,2,2,2],
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
PyWavelets
|
|
|
@ -135,22 +135,21 @@ def train(rank, a, h):
|
||||||
h.win_size,
|
h.win_size,
|
||||||
h.fmin, h.fmax_for_loss)
|
h.fmin, h.fmax_for_loss)
|
||||||
|
|
||||||
|
if steps > h.disc_start_step:
|
||||||
|
optim_d.zero_grad()
|
||||||
|
|
||||||
|
# MPD
|
||||||
|
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||||
|
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||||
|
|
||||||
optim_d.zero_grad()
|
# MSD
|
||||||
|
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||||
|
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||||
|
|
||||||
# MPD
|
loss_disc_all = loss_disc_s + loss_disc_f
|
||||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
|
||||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
|
||||||
|
|
||||||
# MSD
|
loss_disc_all.backward()
|
||||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
optim_d.step()
|
||||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
|
||||||
|
|
||||||
loss_disc_all = loss_disc_s + loss_disc_f
|
|
||||||
|
|
||||||
loss_disc_all.backward()
|
|
||||||
optim_d.step()
|
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
optim_g.zero_grad()
|
optim_g.zero_grad()
|
||||||
|
@ -162,15 +161,16 @@ def train(rank, a, h):
|
||||||
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
|
# sc_loss, mag_loss = stft_loss(y_g_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
|
||||||
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
|
# loss_mel = h.lambda_aux * (sc_loss + mag_loss) # STFT Loss
|
||||||
|
|
||||||
|
if steps > h.disc_start_step:
|
||||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||||
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
|
loss_gen_all = loss_gen_s + loss_gen_f + (2 * (loss_fm_s + loss_fm_f)) + loss_mel
|
||||||
|
else:
|
||||||
|
loss_gen_all = loss_mel
|
||||||
|
|
||||||
loss_gen_all.backward()
|
loss_gen_all.backward()
|
||||||
optim_g.step()
|
optim_g.step()
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
"adam_b2": 0.99,
|
"adam_b2": 0.99,
|
||||||
"lr_decay": 0.999,
|
"lr_decay": 0.999,
|
||||||
"seed": 1234,
|
"seed": 1234,
|
||||||
|
"disc_start_step":0,
|
||||||
|
|
||||||
"upsample_rates": [5,5,4,2],
|
"upsample_rates": [5,5,4,2],
|
||||||
"upsample_kernel_sizes": [10,10,8,4],
|
"upsample_kernel_sizes": [10,10,8,4],
|
||||||
|
@ -27,5 +28,11 @@
|
||||||
"fmax": 7600,
|
"fmax": 7600,
|
||||||
"fmax_for_loss": null,
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
"num_workers": 4
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,21 +137,21 @@ def train(rank, a, h):
|
||||||
y_g_hat = generator(x)
|
y_g_hat = generator(x)
|
||||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
||||||
h.fmin, h.fmax_for_loss)
|
h.fmin, h.fmax_for_loss)
|
||||||
|
if steps > h.disc_start_step:
|
||||||
|
optim_d.zero_grad()
|
||||||
|
|
||||||
optim_d.zero_grad()
|
# MPD
|
||||||
|
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||||
|
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||||
|
|
||||||
# MPD
|
# MSD
|
||||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||||
|
|
||||||
# MSD
|
loss_disc_all = loss_disc_s + loss_disc_f
|
||||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
|
||||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
|
||||||
|
|
||||||
loss_disc_all = loss_disc_s + loss_disc_f
|
loss_disc_all.backward()
|
||||||
|
optim_d.step()
|
||||||
loss_disc_all.backward()
|
|
||||||
optim_d.step()
|
|
||||||
|
|
||||||
# Generator
|
# Generator
|
||||||
optim_g.zero_grad()
|
optim_g.zero_grad()
|
||||||
|
@ -159,13 +159,16 @@ def train(rank, a, h):
|
||||||
# L1 Mel-Spectrogram Loss
|
# L1 Mel-Spectrogram Loss
|
||||||
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
||||||
|
|
||||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
if steps > h.disc_start_step:
|
||||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||||
|
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||||
|
else:
|
||||||
|
loss_gen_all = loss_mel
|
||||||
|
|
||||||
loss_gen_all.backward()
|
loss_gen_all.backward()
|
||||||
optim_g.step()
|
optim_g.step()
|
||||||
|
|
|
@ -6,7 +6,8 @@ from utils.util import AttrDict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -69,11 +70,23 @@ if __name__ == "__main__":
|
||||||
with open(args.config) as f:
|
with open(args.config) as f:
|
||||||
json_config = json.load(f)
|
json_config = json.load(f)
|
||||||
h = AttrDict(json_config)
|
h = AttrDict(json_config)
|
||||||
train_hifigan(0, args, h)
|
if h.num_gpus > 1:
|
||||||
|
h.num_gpus = torch.cuda.device_count()
|
||||||
|
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||||
|
print('Batch size per GPU :', h.batch_size)
|
||||||
|
mp.spawn(train_hifigan, nprocs=h.num_gpus, args=(args, h,))
|
||||||
|
else:
|
||||||
|
train_hifigan(0, args, h)
|
||||||
elif args.vocoder_type == "fregan":
|
elif args.vocoder_type == "fregan":
|
||||||
with open('vocoder/fregan/config.json') as f:
|
with open('vocoder/fregan/config.json') as f:
|
||||||
json_config = json.load(f)
|
json_config = json.load(f)
|
||||||
h = AttrDict(json_config)
|
h = AttrDict(json_config)
|
||||||
train_fregan(0, args, h)
|
if h.num_gpus > 1:
|
||||||
|
h.num_gpus = torch.cuda.device_count()
|
||||||
|
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||||
|
print('Batch size per GPU :', h.batch_size)
|
||||||
|
mp.spawn(train_fregan, nprocs=h.num_gpus, args=(args, h,))
|
||||||
|
else:
|
||||||
|
train_fregan(0, args, h)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user