Hifigan Support train from existed checkpoint. (#389)

* 1k steps to save tmp hifigan model

* hifigan support train from existed ckpt
This commit is contained in:
hertz 2022-02-27 11:01:47 +08:00 committed by GitHub
parent b79e9d68e4
commit 9e072c2619
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -185,7 +185,7 @@ def train(rank, a, h):
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
@ -203,7 +203,7 @@ def train(rank, a, h):
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path)
save_checkpoint(checkpoint_path,
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),

View File

@ -50,7 +50,7 @@ def save_checkpoint(filepath, obj):
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '????????')
pattern = os.path.join(cp_dir, prefix + 'hifigan.pt')
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None