1k steps to save tmp hifigan model (#240)

This commit is contained in:
hertz 2021-11-29 21:09:54 +08:00 committed by GitHub
parent b50c7984ab
commit a4daf42868
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,11 +23,11 @@ torch.backends.cudnn.benchmark = True
def train(rank, a, h):
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
a.checkpoint_path.mkdir(exist_ok=True)
a.training_epochs = 3100
a.stdout_interval = 5
a.checkpoint_interval = 25000
a.checkpoint_interval = a.backup_every
a.summary_interval = 5000
a.validation_interval = 1000
a.fine_tuning = True
@ -186,11 +186,9 @@ def train(rank, a, h):
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1
else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1
else msd).state_dict(),
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
@ -198,6 +196,19 @@ def train(rank, a, h):
if steps % a.summary_interval == 0:
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
# save temperate hifigan model
if steps % a.save_every == 0:
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
'epoch': epoch})
# Validation
if steps % a.validation_interval == 0: # and steps != 0: