From 4529479091c256ff9340c39b40207206a584e54c Mon Sep 17 00:00:00 2001
From: 李子 <54951765+kslz@users.noreply.github.com>
Date: Thu, 10 Feb 2022 20:47:26 +0800
Subject: [PATCH 1/2] Pin the librosa version (#378)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Support the data_aishell (SLR33) dataset

* Update readme

* Pin the librosa version
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 9010c30..02a3c5e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 umap-learn
 visdom
-librosa>=0.8.0
+librosa==0.8.1
 matplotlib>=3.3.0
 numpy==1.19.3; platform_system == "Windows"
 numpy==1.19.4; platform_system != "Windows"
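
A quick way to confirm that an environment actually resolved to the release pinned above (a minimal sketch; it assumes only that librosa is importable, and the version string simply mirrors the requirement line in the diff):

import librosa

# Fail loudly if pip resolved something other than the pinned release
assert librosa.__version__ == "0.8.1", (
    "expected librosa 0.8.1, got %s" % librosa.__version__)
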
""" diff --git a/synthesizer/train.py b/synthesizer/train.py index 4ade026..2532348 100644 --- a/synthesizer/train.py +++ b/synthesizer/train.py @@ -12,6 +12,7 @@ from synthesizer.utils.symbols import symbols from synthesizer.utils.text import sequence_to_text from vocoder.display import * from datetime import datetime +import json import numpy as np from pathlib import Path import sys @@ -75,6 +76,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int, if num_chars != loaded_shape[0]: print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`") num_chars != loaded_shape[0] + # Try to scan config file + model_config_fpaths = list(weights_fpath.parent.rglob("*.json")) + if len(model_config_fpaths)>0 and model_config_fpaths[0].exists(): + with model_config_fpaths[0].open("r", encoding="utf-8") as f: + hparams.loadJson(json.load(f)) + else: # save a config + hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json")) model = Tacotron(embed_dims=hparams.tts_embed_dims,