Support train hifigan (#83)

* support train hifigan
This commit is contained in:
hertz 2021-09-14 13:31:53 +08:00 committed by GitHub
parent 222e302274
commit 3fbe03f2ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 274 additions and 11 deletions

1
.gitignore vendored
View File

@ -17,4 +17,5 @@
*.sh
synthesizer/saved_models/*
vocoder/saved_models/*
cp_hifigan/*
!vocoder/saved_models/pretrained/*

View File

@ -58,9 +58,12 @@
* 预处理数据:
`python vocoder_preprocess.py <datasets_root>`
* 训练声码器:
* 训练wavernn声码器:
`python vocoder_train.py mandarin <datasets_root>`
* 训练hifigan声码器:
`python vocoder_train.py mandarin <datasets_root> hifigan`
### 3. 启动工具箱
然后您可以尝试使用工具箱:
`python demo_toolbox.py -d <datasets_root>`

View File

@ -61,9 +61,12 @@ Codeaid4
* Preprocess the data:
`python vocoder_preprocess.py <datasets_root>`
* Train the vocoder:
* Train the wavernn vocoder:
`python vocoder_train.py mandarin <datasets_root>`
* Train the hifigan vocoder
`python vocoder_train.py mandarin <datasets_root> hifigan`
### 3. Launch the Toolbox
You can then try the toolbox:

View File

@ -361,9 +361,10 @@ class Toolbox:
# Sekect vocoder based on model name
if model_fpath.name[0] == "g":
vocoder = gan_vocoder
self.ui.log("vocoder is hifigan")
self.ui.log("set hifigan as vocoder")
else:
vocoder = rnn_vocoder
self.ui.log("set wavernn as vocoder")
self.ui.log("Loading the vocoder %s... " % model_fpath)
self.ui.set_loading(1)

View File

@ -84,8 +84,8 @@ def get_dataset_filelist(a):
files = os.listdir(a.input_wavs_dir)
random.shuffle(files)
files = [os.path.join(a.input_wavs_dir, f) for f in files]
training_files = files[: -500]
validation_files = files[-500: ]
training_files = files[: -int(len(files)*0.05)]
validation_files = files[-int(len(files)*0.05): ]
return training_files, validation_files

240
vocoder/hifigan/train.py Normal file
View File

@ -0,0 +1,240 @@
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from vocoder.hifigan.env import AttrDict, build_env
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
discriminator_loss
from vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
torch.backends.cudnn.benchmark = True
def train(rank, a, h):
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
a.checkpoint_path.mkdir(exist_ok=True)
a.training_epochs = 3100
a.stdout_interval = 5
a.checkpoint_interval = 25000
a.summary_interval = 5000
a.validation_interval = 1000
a.fine_tuning = True
a.input_wavs_dir = a.syn_dir.joinpath("audio")
a.input_mels_dir = a.syn_dir.joinpath("mels")
if h.num_gpus > 1:
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
torch.cuda.manual_seed(h.seed)
device = torch.device('cuda:{:d}'.format(rank))
generator = Generator(h).to(device)
mpd = MultiPeriodDiscriminator().to(device)
msd = MultiScaleDiscriminator().to(device)
if rank == 0:
print(generator)
os.makedirs(a.checkpoint_path, exist_ok=True)
print("checkpoints directory : ", a.checkpoint_path)
if os.path.isdir(a.checkpoint_path):
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
steps = 0
if cp_g is None or cp_do is None:
state_dict_do = None
last_epoch = -1
else:
state_dict_g = load_checkpoint(cp_g, device)
state_dict_do = load_checkpoint(cp_do, device)
generator.load_state_dict(state_dict_g['generator'])
mpd.load_state_dict(state_dict_do['mpd'])
msd.load_state_dict(state_dict_do['msd'])
steps = state_dict_do['steps'] + 1
last_epoch = state_dict_do['epoch']
if h.num_gpus > 1:
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
if state_dict_do is not None:
optim_g.load_state_dict(state_dict_do['optim_g'])
optim_d.load_state_dict(state_dict_do['optim_d'])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
training_filelist, validation_filelist = get_dataset_filelist(a)
# print(training_filelist)
# exit()
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
sampler=train_sampler,
batch_size=h.batch_size,
pin_memory=True,
drop_last=True)
if rank == 0:
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir)
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True)
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
generator.train()
mpd.train()
msd.train()
for epoch in range(max(0, last_epoch), a.training_epochs):
if rank == 0:
start = time.time()
print("Epoch: {}".format(epoch+1))
if h.num_gpus > 1:
train_sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
if rank == 0:
start_b = time.time()
x, y, _, y_mel = batch
x = torch.autograd.Variable(x.to(device, non_blocking=True))
y = torch.autograd.Variable(y.to(device, non_blocking=True))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y = y.unsqueeze(1)
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MSD
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
loss_disc_all.backward()
optim_d.step()
# Generator
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
loss_gen_all.backward()
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
with torch.no_grad():
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
format(steps, loss_gen_all, mel_error, time.time() - start_b))
# checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1
else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1
else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
# Tensorboard summary logging
if steps % a.summary_interval == 0:
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
# Validation
if steps % a.validation_interval == 0: # and steps != 0:
generator.eval()
torch.cuda.empty_cache()
val_err_tot = 0
with torch.no_grad():
for j, batch in enumerate(validation_loader):
x, y, _, y_mel = batch
y_g_hat = generator(x.to(device))
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
h.hop_size, h.win_size,
h.fmin, h.fmax_for_loss)
# val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
if j <= 4:
if steps == 0:
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
h.sampling_rate, h.hop_size, h.win_size,
h.fmin, h.fmax)
sw.add_figure('generated/y_hat_spec_{}'.format(j),
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
val_err = val_err_tot / (j+1)
sw.add_scalar("validation/mel_spec_error", val_err, steps)
generator.train()
steps += 1
scheduler_g.step()
scheduler_d.step()
if rank == 0:
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

View File

@ -1,7 +1,7 @@
from torch.utils.data import Dataset
from pathlib import Path
from vocoder import audio
import vocoder.hparams as hp
from vocoder.wavernn import audio
import vocoder.wavernn.hparams as hp
import numpy as np
import torch

View File

@ -1,7 +1,10 @@
from utils.argutils import print_args
from vocoder.wavernn.train import train
from vocoder.hifigan.train import train as train_hifigan
from vocoder.hifigan.env import AttrDict
from pathlib import Path
import argparse
import json
if __name__ == "__main__":
@ -18,6 +21,9 @@ if __name__ == "__main__":
parser.add_argument("datasets_root", type=str, help= \
"Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
"will take priority over this argument.")
parser.add_argument("vocoder_type", type=str, default="wavernn", help= \
"Choose the vocoder type for train. Defaults to wavernn"
"Now, Support <hifigan> and <wavernn> for choose")
parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
"the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
@ -37,9 +43,9 @@ if __name__ == "__main__":
"model.")
parser.add_argument("-f", "--force_restart", action="store_true", help= \
"Do not load any saved model and restart from scratch.")
parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
args = parser.parse_args()
# Process the arguments
if not hasattr(args, "syn_dir"):
args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
args.syn_dir = Path(args.syn_dir)
@ -50,7 +56,16 @@ if __name__ == "__main__":
args.models_dir = Path(args.models_dir)
args.models_dir.mkdir(exist_ok=True)
# Run the training
print_args(args, parser)
train(**vars(args))
# Process the arguments
if args.vocoder_type == "wavernn":
# Run the training wavernn
train(**vars(args))
elif args.vocoder_type == "hifigan":
with open(args.config) as f:
json_config = json.load(f)
h = AttrDict(json_config)
train_hifigan(0, args, h)