From a4daf428688316d291abb06836ccb5de84b28b05 Mon Sep 17 00:00:00 2001 From: hertz Date: Mon, 29 Nov 2021 21:09:54 +0800 Subject: [PATCH] 1k steps to save tmp hifigan model (#240) --- vocoder/hifigan/train.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/vocoder/hifigan/train.py b/vocoder/hifigan/train.py index 1914b27..f27daba 100644 --- a/vocoder/hifigan/train.py +++ b/vocoder/hifigan/train.py @@ -23,11 +23,11 @@ torch.backends.cudnn.benchmark = True def train(rank, a, h): - a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan') + a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan') a.checkpoint_path.mkdir(exist_ok=True) a.training_epochs = 3100 a.stdout_interval = 5 - a.checkpoint_interval = 25000 + a.checkpoint_interval = a.backup_every a.summary_interval = 5000 a.validation_interval = 1000 a.fine_tuning = True @@ -186,11 +186,9 @@ def train(rank, a, h): save_checkpoint(checkpoint_path, {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps) - save_checkpoint(checkpoint_path, - {'mpd': (mpd.module if h.num_gpus > 1 - else mpd).state_dict(), - 'msd': (msd.module if h.num_gpus > 1 - else msd).state_dict(), + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, 'epoch': epoch}) @@ -198,6 +196,19 @@ def train(rank, a, h): if steps % a.summary_interval == 0: sw.add_scalar("training/gen_loss_total", loss_gen_all, steps) sw.add_scalar("training/mel_spec_error", mel_error, steps) + + + # save temperate hifigan model + if steps % a.save_every == 0: + checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path) + save_checkpoint(checkpoint_path, + {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()}) + checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path) + save_checkpoint(checkpoint_path, + {'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + 'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(), + 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps, + 'epoch': epoch}) # Validation if steps % a.validation_interval == 0: # and steps != 0: