mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Fix compatibility issue of symbols
This commit is contained in:
parent
0bba0a806e
commit
17d47589c1
|
@ -67,8 +67,17 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||
|
||||
# Instantiate Tacotron Model
|
||||
print("\nInitialising Tacotron Model...\n")
|
||||
num_chars = len(symbols)
|
||||
if weights_fpath.exists():
|
||||
# for compatibility purpose, change symbols accordingly:
|
||||
loaded_shape = torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape
|
||||
if num_chars != loaded_shape[0]:
|
||||
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
|
||||
num_chars != loaded_shape[0]
|
||||
|
||||
|
||||
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
||||
num_chars=len(symbols),
|
||||
num_chars=num_chars,
|
||||
encoder_dims=hparams.tts_encoder_dims,
|
||||
decoder_dims=hparams.tts_decoder_dims,
|
||||
n_mels=hparams.num_mels,
|
||||
|
|
|
@ -9,6 +9,8 @@ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA
|
|||
_pad = "_"
|
||||
_eos = "~"
|
||||
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? '
|
||||
|
||||
#_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model
|
||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||
#_arpabet = ["@' + s for s in cmudict.valid_symbols]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user