mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
1k steps to save tmp hifigan model (#240)
This commit is contained in:
parent
b50c7984ab
commit
a4daf42868
|
@ -23,11 +23,11 @@ torch.backends.cudnn.benchmark = True
|
|||
|
||||
def train(rank, a, h):
|
||||
|
||||
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
|
||||
a.checkpoint_path = a.models_dir.joinpath(a.run_id+'_hifigan')
|
||||
a.checkpoint_path.mkdir(exist_ok=True)
|
||||
a.training_epochs = 3100
|
||||
a.stdout_interval = 5
|
||||
a.checkpoint_interval = 25000
|
||||
a.checkpoint_interval = a.backup_every
|
||||
a.summary_interval = 5000
|
||||
a.validation_interval = 1000
|
||||
a.fine_tuning = True
|
||||
|
@ -186,11 +186,9 @@ def train(rank, a, h):
|
|||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1
|
||||
else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1
|
||||
else msd).state_dict(),
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
||||
'epoch': epoch})
|
||||
|
||||
|
@ -198,6 +196,19 @@ def train(rank, a, h):
|
|||
if steps % a.summary_interval == 0:
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
|
||||
|
||||
# save temperate hifigan model
|
||||
if steps % a.save_every == 0:
|
||||
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
||||
'epoch': epoch})
|
||||
|
||||
# Validation
|
||||
if steps % a.validation_interval == 0: # and steps != 0:
|
||||
|
|
Loading…
Reference in New Issue
Block a user