Add config file for pretrained

This commit is contained in:
babysor00 2022-02-23 09:37:39 +08:00
parent 4529479091
commit 0536874dec
3 changed files with 27 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import ast
import pprint
import json
class HParams(object):
def __init__(self, **kwargs): self.__dict__.update(kwargs)
@ -18,6 +19,18 @@ class HParams(object):
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
return self
def loadJson(self, dict):
print("\Loading the json with %s\n", dict)
for k in dict.keys():
self.__dict__[k] = dict[k]
return self
def dumpJson(self, fp):
print("\Saving the json with %s\n", fp)
with fp.open("w", encoding="utf-8") as f:
json.dump(self.__dict__, f)
return self
hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000,

View File

@ -10,6 +10,7 @@ from typing import Union, List
import numpy as np
import librosa
from utils import logmmse
import json
from pypinyin import lazy_pinyin, Style
class Synthesizer:
@ -44,6 +45,11 @@ class Synthesizer:
return self._model is not None
def load(self):
# Try to scan config file
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
"""
Instantiates and loads the model given the weights file that was passed in the constructor.
"""

View File

@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
from datetime import datetime
import json
import numpy as np
from pathlib import Path
import sys
@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
if num_chars != loaded_shape[0]:
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
num_chars != loaded_shape[0]
# Try to scan config file
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
else: # save a config
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))
model = Tacotron(embed_dims=hparams.tts_embed_dims,