Persist synthesizer hparams as a JSON config stored next to the model weights.

HParams gains loadJson()/dumpJson(); Synthesizer.load() and train() look for a
*.json file under the weights directory and override hparams from the first
one found. train() writes a config named after the run id when none is found.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Indentation and blank context lines were
reconstructed from the hunk line counts; confirm against the target tree
before applying.
NOTE(review): in synthesizer/inference.py the config scan is inserted above
the docstring of load(), so that string is no longer the function docstring.
Moving the scan below it needs trailing context lines this patch does not
contain, so the hunk is left as submitted.

diff --git a/requirements.txt b/requirements.txt
index 7fc9e5f..1091207 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 umap-learn
 visdom
-librosa>=0.8.0
+librosa==0.8.1
 matplotlib>=3.3.0
 numpy==1.19.3; platform_system == "Windows"
 numpy==1.19.4; platform_system != "Windows"
diff --git a/synthesizer/hparams.py b/synthesizer/hparams.py
index 672634f..84cec9d 100644
--- a/synthesizer/hparams.py
+++ b/synthesizer/hparams.py
@@ -1,5 +1,6 @@
 import ast
 import pprint
+import json
 
 class HParams(object):
     def __init__(self, **kwargs): self.__dict__.update(kwargs)
@@ -18,6 +19,19 @@ class HParams(object):
                 self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
         return self
 
+    def loadJson(self, dict):
+        # Overrides hparams from a mapping (e.g. the result of json.load); returns self
+        print("\nLoading the json with {}\n".format(dict))
+        self.__dict__.update(dict)
+        return self
+
+    def dumpJson(self, fp):
+        # Writes the current hparams as JSON to fp (a pathlib.Path); returns self
+        print("\nSaving the json with {}\n".format(fp))
+        with fp.open("w", encoding="utf-8") as f:
+            json.dump(self.__dict__, f)
+        return self
+
 hparams = HParams(
         ### Signal Processing (used in both synthesizer and vocoder)
         sample_rate = 16000,
diff --git a/synthesizer/inference.py b/synthesizer/inference.py
index 3a6dc6c..2b4d15b 100644
--- a/synthesizer/inference.py
+++ b/synthesizer/inference.py
@@ -10,6 +10,7 @@ from typing import Union, List
 import numpy as np
 import librosa
 from utils import logmmse
+import json
 from pypinyin import lazy_pinyin, Style
 
 class Synthesizer:
@@ -44,6 +45,11 @@ class Synthesizer:
         return self._model is not None
 
     def load(self):
+        # Try to scan config file
+        model_config_fpaths = list(self.model_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
+            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
+                hparams.loadJson(json.load(f))
         """
         Instantiates and loads the model given the weights file that was passed in the constructor.
         """
diff --git a/synthesizer/train.py b/synthesizer/train.py
index 4ade026..2532348 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols
 from synthesizer.utils.text import sequence_to_text
 from vocoder.display import *
 from datetime import datetime
+import json
 import numpy as np
 from pathlib import Path
 import sys
@@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
         if num_chars != loaded_shape[0]:
             print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
             num_chars != loaded_shape[0]
+        # Try to scan config file
+        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths)>0 and model_config_fpaths[0].exists():
+            with model_config_fpaths[0].open("r", encoding="utf-8") as f:
+                hparams.loadJson(json.load(f))
+        else: # save a config
+            hparams.dumpJson(weights_fpath.parent.joinpath(run_id + ".json"))
 
 
     model = Tacotron(embed_dims=hparams.tts_embed_dims,