Fix compatibility issue of symbols

This commit is contained in:
babysor00 2021-08-29 00:45:49 +08:00
parent 0bba0a806e
commit 17d47589c1
2 changed files with 12 additions and 1 deletions

View File

@ -67,8 +67,17 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
# Instantiate Tacotron Model
print("\nInitialising Tacotron Model...\n")
num_chars = len(symbols)
if weights_fpath.exists():
# for compatibility purpose, change symbols accordingly:
loaded_shape = torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape
if num_chars != loaded_shape[0]:
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
num_chars != loaded_shape[0]
model = Tacotron(embed_dims=hparams.tts_embed_dims,
num_chars=len(symbols),
num_chars=num_chars,
encoder_dims=hparams.tts_encoder_dims,
decoder_dims=hparams.tts_decoder_dims,
n_mels=hparams.num_mels,

View File

@ -9,6 +9,8 @@ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA
_pad = "_"
_eos = "~"
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? '
#_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
#_arpabet = ["@' + s for s in cmudict.valid_symbols]