mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
parent
222e302274
commit
3fbe03f2ff
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -17,4 +17,5 @@
|
|||
*.sh
|
||||
synthesizer/saved_models/*
|
||||
vocoder/saved_models/*
|
||||
cp_hifigan/*
|
||||
!vocoder/saved_models/pretrained/*
|
|
@ -58,9 +58,12 @@
|
|||
* 预处理数据:
|
||||
`python vocoder_preprocess.py <datasets_root>`
|
||||
|
||||
* 训练声码器:
|
||||
* 训练wavernn声码器:
|
||||
`python vocoder_train.py mandarin <datasets_root>`
|
||||
|
||||
* 训练hifigan声码器:
|
||||
`python vocoder_train.py mandarin <datasets_root> hifigan`
|
||||
|
||||
### 3. 启动工具箱
|
||||
然后您可以尝试使用工具箱:
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
|
|
|
@ -61,9 +61,12 @@ Code:aid4
|
|||
* Preprocess the data:
|
||||
`python vocoder_preprocess.py <datasets_root>`
|
||||
|
||||
* Train the vocoder:
|
||||
* Train the wavernn vocoder:
|
||||
`python vocoder_train.py mandarin <datasets_root>`
|
||||
|
||||
* Train the hifigan vocoder
|
||||
`python vocoder_train.py mandarin <datasets_root> hifigan`
|
||||
|
||||
### 3. Launch the Toolbox
|
||||
You can then try the toolbox:
|
||||
|
||||
|
|
|
@ -361,9 +361,10 @@ class Toolbox:
|
|||
# Sekect vocoder based on model name
|
||||
if model_fpath.name[0] == "g":
|
||||
vocoder = gan_vocoder
|
||||
self.ui.log("vocoder is hifigan")
|
||||
self.ui.log("set hifigan as vocoder")
|
||||
else:
|
||||
vocoder = rnn_vocoder
|
||||
self.ui.log("set wavernn as vocoder")
|
||||
|
||||
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
|
|
|
@ -84,8 +84,8 @@ def get_dataset_filelist(a):
|
|||
files = os.listdir(a.input_wavs_dir)
|
||||
random.shuffle(files)
|
||||
files = [os.path.join(a.input_wavs_dir, f) for f in files]
|
||||
training_files = files[: -500]
|
||||
validation_files = files[-500: ]
|
||||
training_files = files[: -int(len(files)*0.05)]
|
||||
validation_files = files[-int(len(files)*0.05): ]
|
||||
|
||||
return training_files, validation_files
|
||||
|
||||
|
|
240
vocoder/hifigan/train.py
Normal file
240
vocoder/hifigan/train.py
Normal file
|
@ -0,0 +1,240 @@
|
|||
import warnings
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from vocoder.hifigan.env import AttrDict, build_env
|
||||
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
||||
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
|
||||
discriminator_loss
|
||||
from vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def train(rank, a, h):
|
||||
|
||||
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
|
||||
a.checkpoint_path.mkdir(exist_ok=True)
|
||||
a.training_epochs = 3100
|
||||
a.stdout_interval = 5
|
||||
a.checkpoint_interval = 25000
|
||||
a.summary_interval = 5000
|
||||
a.validation_interval = 1000
|
||||
a.fine_tuning = True
|
||||
|
||||
a.input_wavs_dir = a.syn_dir.joinpath("audio")
|
||||
a.input_mels_dir = a.syn_dir.joinpath("mels")
|
||||
|
||||
if h.num_gpus > 1:
|
||||
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
||||
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
||||
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device('cuda:{:d}'.format(rank))
|
||||
|
||||
generator = Generator(h).to(device)
|
||||
mpd = MultiPeriodDiscriminator().to(device)
|
||||
msd = MultiScaleDiscriminator().to(device)
|
||||
|
||||
if rank == 0:
|
||||
print(generator)
|
||||
os.makedirs(a.checkpoint_path, exist_ok=True)
|
||||
print("checkpoints directory : ", a.checkpoint_path)
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
||||
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
||||
|
||||
steps = 0
|
||||
if cp_g is None or cp_do is None:
|
||||
state_dict_do = None
|
||||
last_epoch = -1
|
||||
else:
|
||||
state_dict_g = load_checkpoint(cp_g, device)
|
||||
state_dict_do = load_checkpoint(cp_do, device)
|
||||
generator.load_state_dict(state_dict_g['generator'])
|
||||
mpd.load_state_dict(state_dict_do['mpd'])
|
||||
msd.load_state_dict(state_dict_do['msd'])
|
||||
steps = state_dict_do['steps'] + 1
|
||||
last_epoch = state_dict_do['epoch']
|
||||
|
||||
if h.num_gpus > 1:
|
||||
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
||||
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
|
||||
if state_dict_do is not None:
|
||||
optim_g.load_state_dict(state_dict_do['optim_g'])
|
||||
optim_d.load_state_dict(state_dict_do['optim_d'])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
training_filelist, validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
# print(training_filelist)
|
||||
# exit()
|
||||
|
||||
trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
|
||||
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
||||
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
||||
fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
|
||||
|
||||
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
||||
|
||||
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
||||
sampler=train_sampler,
|
||||
batch_size=h.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
if rank == 0:
|
||||
validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
|
||||
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
||||
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir)
|
||||
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
||||
|
||||
generator.train()
|
||||
mpd.train()
|
||||
msd.train()
|
||||
for epoch in range(max(0, last_epoch), a.training_epochs):
|
||||
if rank == 0:
|
||||
start = time.time()
|
||||
print("Epoch: {}".format(epoch+1))
|
||||
|
||||
if h.num_gpus > 1:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
for i, batch in enumerate(train_loader):
|
||||
if rank == 0:
|
||||
start_b = time.time()
|
||||
x, y, _, y_mel = batch
|
||||
x = torch.autograd.Variable(x.to(device, non_blocking=True))
|
||||
y = torch.autograd.Variable(y.to(device, non_blocking=True))
|
||||
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
||||
y = y.unsqueeze(1)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MSD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
loss_disc_all.backward()
|
||||
optim_d.step()
|
||||
|
||||
# Generator
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
||||
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
with torch.no_grad():
|
||||
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
||||
|
||||
print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
|
||||
format(steps, loss_gen_all, mel_error, time.time() - start_b))
|
||||
|
||||
# checkpointing
|
||||
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1
|
||||
else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1
|
||||
else msd).state_dict(),
|
||||
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
||||
'epoch': epoch})
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
|
||||
# Validation
|
||||
if steps % a.validation_interval == 0: # and steps != 0:
|
||||
generator.eval()
|
||||
torch.cuda.empty_cache()
|
||||
val_err_tot = 0
|
||||
with torch.no_grad():
|
||||
for j, batch in enumerate(validation_loader):
|
||||
x, y, _, y_mel = batch
|
||||
y_g_hat = generator(x.to(device))
|
||||
y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
|
||||
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
||||
h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax_for_loss)
|
||||
# val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
||||
|
||||
if j <= 4:
|
||||
if steps == 0:
|
||||
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
||||
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
||||
|
||||
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
||||
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
|
||||
h.sampling_rate, h.hop_size, h.win_size,
|
||||
h.fmin, h.fmax)
|
||||
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
||||
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
||||
|
||||
val_err = val_err_tot / (j+1)
|
||||
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
||||
|
||||
generator.train()
|
||||
|
||||
steps += 1
|
||||
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
|
@ -1,7 +1,7 @@
|
|||
from torch.utils.data import Dataset
|
||||
from pathlib import Path
|
||||
from vocoder import audio
|
||||
import vocoder.hparams as hp
|
||||
from vocoder.wavernn import audio
|
||||
import vocoder.wavernn.hparams as hp
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from utils.argutils import print_args
|
||||
from vocoder.wavernn.train import train
|
||||
from vocoder.hifigan.train import train as train_hifigan
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -18,6 +21,9 @@ if __name__ == "__main__":
|
|||
parser.add_argument("datasets_root", type=str, help= \
|
||||
"Path to the directory containing your SV2TTS directory. Specifying --syn_dir or --voc_dir "
|
||||
"will take priority over this argument.")
|
||||
parser.add_argument("vocoder_type", type=str, default="wavernn", help= \
|
||||
"Choose the vocoder type for train. Defaults to wavernn"
|
||||
"Now, Support <hifigan> and <wavernn> for choose")
|
||||
parser.add_argument("--syn_dir", type=str, default=argparse.SUPPRESS, help= \
|
||||
"Path to the synthesizer directory that contains the ground truth mel spectrograms, "
|
||||
"the wavs and the embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/.")
|
||||
|
@ -37,9 +43,9 @@ if __name__ == "__main__":
|
|||
"model.")
|
||||
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
||||
"Do not load any saved model and restart from scratch.")
|
||||
parser.add_argument("--config", type=str, default="vocoder/hifigan/config_16k_.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Process the arguments
|
||||
if not hasattr(args, "syn_dir"):
|
||||
args.syn_dir = Path(args.datasets_root, "SV2TTS", "synthesizer")
|
||||
args.syn_dir = Path(args.syn_dir)
|
||||
|
@ -50,7 +56,16 @@ if __name__ == "__main__":
|
|||
args.models_dir = Path(args.models_dir)
|
||||
args.models_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Run the training
|
||||
print_args(args, parser)
|
||||
train(**vars(args))
|
||||
|
||||
|
||||
# Process the arguments
|
||||
if args.vocoder_type == "wavernn":
|
||||
# Run the training wavernn
|
||||
train(**vars(args))
|
||||
elif args.vocoder_type == "hifigan":
|
||||
with open(args.config) as f:
|
||||
json_config = json.load(f)
|
||||
h = AttrDict(json_config)
|
||||
train_hifigan(0, args, h)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user