mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Hifigan Support train from existed checkpoint. (#389)
* 1k steps to save tmp hifigan model * hifigan support train from existed ckpt
This commit is contained in:
parent
b79e9d68e4
commit
9e072c2619
|
@ -185,7 +185,7 @@ def train(rank, a, h):
|
|||
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
|
||||
checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
|
@ -203,7 +203,7 @@ def train(rank, a, h):
|
|||
checkpoint_path = "{}/g_hifigan.pt".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
||||
checkpoint_path = "{}/do_hifigan".format(a.checkpoint_path)
|
||||
checkpoint_path = "{}/do_hifigan.pt".format(a.checkpoint_path)
|
||||
save_checkpoint(checkpoint_path,
|
||||
{'mpd': (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
'msd': (msd.module if h.num_gpus > 1 else msd).state_dict(),
|
||||
|
|
|
@ -50,7 +50,7 @@ def save_checkpoint(filepath, obj):
|
|||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????')
|
||||
pattern = os.path.join(cp_dir, prefix + 'hifigan.pt')
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue
Block a user