Add config file for pretrained

This commit is contained in:
babysor00 2022-02-23 09:37:39 +08:00
parent 4529479091
commit 0536874dec
3 changed files with 27 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import ast import ast
import pprint import pprint
import json
class HParams(object): class HParams(object):
def __init__(self, **kwargs): self.__dict__.update(kwargs) def __init__(self, **kwargs): self.__dict__.update(kwargs)
@ -18,6 +19,18 @@ class HParams(object):
self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
return self return self
def loadJson(self, dict):
print("\Loading the json with %s\n", dict)
for k in dict.keys():
self.__dict__[k] = dict[k]
return self
def dumpJson(self, fp):
print("\Saving the json with %s\n", fp)
with fp.open("w", encoding="utf-8") as f:
json.dump(self.__dict__, f)
return self
hparams = HParams( hparams = HParams(
### Signal Processing (used in both synthesizer and vocoder) ### Signal Processing (used in both synthesizer and vocoder)
sample_rate = 16000, sample_rate = 16000,

View File

@ -10,6 +10,7 @@ from typing import Union, List
import numpy as np import numpy as np
import librosa import librosa
from utils import logmmse from utils import logmmse
import json
from pypinyin import lazy_pinyin, Style from pypinyin import lazy_pinyin, Style
class Synthesizer: class Synthesizer:
@ -44,6 +45,11 @@ class Synthesizer:
return self._model is not None return self._model is not None
def load(self): def load(self):
# Try to scan config file
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
""" """
Instantiates and loads the model given the weights file that was passed in the constructor. Instantiates and loads the model given the weights file that was passed in the constructor.
""" """

View File

@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text from synthesizer.utils.text import sequence_to_text
from vocoder.display import * from vocoder.display import *
from datetime import datetime from datetime import datetime
import json
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import sys import sys
@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
if num_chars != loaded_shape[0]: if num_chars != loaded_shape[0]:
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`") print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
num_chars != loaded_shape[0] num_chars != loaded_shape[0]
# Try to scan config file
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
hparams.loadJson(json.load(f))
else: # save a config
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))
model = Tacotron(embed_dims=hparams.tts_embed_dims, model = Tacotron(embed_dims=hparams.tts_embed_dims,