mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Add config file for pretrained
This commit is contained in:
parent
4529479091
commit
0536874dec
|
@ -1,5 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
import pprint
|
import pprint
|
||||||
|
import json
|
||||||
|
|
||||||
class HParams(object):
|
class HParams(object):
|
||||||
def __init__(self, **kwargs): self.__dict__.update(kwargs)
|
def __init__(self, **kwargs): self.__dict__.update(kwargs)
|
||||||
|
@ -18,6 +19,18 @@ class HParams(object):
|
||||||
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
|
self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def loadJson(self, dict):
|
||||||
|
print("\Loading the json with %s\n", dict)
|
||||||
|
for k in dict.keys():
|
||||||
|
self.__dict__[k] = dict[k]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def dumpJson(self, fp):
|
||||||
|
print("\Saving the json with %s\n", fp)
|
||||||
|
with fp.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(self.__dict__, f)
|
||||||
|
return self
|
||||||
|
|
||||||
hparams = HParams(
|
hparams = HParams(
|
||||||
### Signal Processing (used in both synthesizer and vocoder)
|
### Signal Processing (used in both synthesizer and vocoder)
|
||||||
sample_rate = 16000,
|
sample_rate = 16000,
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import Union, List
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import librosa
|
import librosa
|
||||||
from utils import logmmse
|
from utils import logmmse
|
||||||
|
import json
|
||||||
from pypinyin import lazy_pinyin, Style
|
from pypinyin import lazy_pinyin, Style
|
||||||
|
|
||||||
class Synthesizer:
|
class Synthesizer:
|
||||||
|
@ -44,6 +45,11 @@ class Synthesizer:
|
||||||
return self._model is not None
|
return self._model is not None
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
# Try to scan config file
|
||||||
|
model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
|
||||||
|
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
|
||||||
|
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
|
||||||
|
hparams.loadJson(json.load(f))
|
||||||
"""
|
"""
|
||||||
Instantiates and loads the model given the weights file that was passed in the constructor.
|
Instantiates and loads the model given the weights file that was passed in the constructor.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
|
||||||
from synthesizer.utils.text import sequence_to_text
|
from synthesizer.utils.text import sequence_to_text
|
||||||
from vocoder.display import *
|
from vocoder.display import *
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
if num_chars != loaded_shape[0]:
|
if num_chars != loaded_shape[0]:
|
||||||
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
|
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
|
||||||
num_chars != loaded_shape[0]
|
num_chars != loaded_shape[0]
|
||||||
|
# Try to scan config file
|
||||||
|
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
|
||||||
|
if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
|
||||||
|
with model_config_fpaths[0].open("r", encoding="utf-8") as f:
|
||||||
|
hparams.loadJson(json.load(f))
|
||||||
|
else: # save a config
|
||||||
|
hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json"))
|
||||||
|
|
||||||
|
|
||||||
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user