Support tensorboard to trace the training of Synthesizer (#98)

* add tensorboard tracing

* add log_every params
This commit is contained in:
hertz 2021-09-25 17:06:51 +08:00 committed by GitHub
parent 99269b2046
commit 4acfee2a64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 9 deletions

View File

@ -2,11 +2,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import optim from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from synthesizer import audio from synthesizer import audio
from synthesizer.models.tacotron import Tacotron from synthesizer.models.tacotron import Tacotron
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from synthesizer.utils import ValueWindow, data_parallel_workaround from synthesizer.utils import ValueWindow, data_parallel_workaround
from synthesizer.utils.plot import plot_spectrogram from synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace
from synthesizer.utils.symbols import symbols from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text from synthesizer.utils.text import sequence_to_text
from vocoder.display import * from vocoder.display import *
@ -23,7 +24,7 @@ def time_string():
return datetime.now().strftime("%Y-%m-%d %H:%M") return datetime.now().strftime("%Y-%m-%d %H:%M")
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int, def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
backup_every: int, force_restart:bool, hparams): backup_every: int, log_every:int, force_restart:bool, hparams):
syn_dir = Path(syn_dir) syn_dir = Path(syn_dir)
models_dir = Path(models_dir) models_dir = Path(models_dir)
@ -123,6 +124,9 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
shuffle=True, shuffle=True,
pin_memory=True) pin_memory=True)
# tracing training step
sw = SummaryWriter(log_dir=model_dir.joinpath("logs"))
for i, session in enumerate(hparams.tts_schedule): for i, session in enumerate(hparams.tts_schedule):
current_step = model.get_step() current_step = model.get_step()
@ -208,9 +212,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
step = model.get_step() step = model.get_step()
k = step // 1000 k = step // 1000
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | " msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
stream(msg) stream(msg)
if log_every != 0 and step % log_every == 0 :
sw.add_scalar("training/loss", loss_window.average, step)
# Backup or save model as appropriate # Backup or save model as appropriate
if backup_every != 0 and step % backup_every == 0 : if backup_every != 0 and step % backup_every == 0 :
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k)) backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
@ -220,6 +228,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
# Must save latest optimizer state to ensure that resuming training # Must save latest optimizer state to ensure that resuming training
# doesn't produce artifacts # doesn't produce artifacts
model.save(weights_fpath, optimizer) model.save(weights_fpath, optimizer)
# Evaluate model to generate samples # Evaluate model to generate samples
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
@ -233,7 +242,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length] mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length] target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
attention_len = mel_length // model.r attention_len = mel_length // model.r
# eval_loss = F.mse_loss(mel_prediction, target_spectrogram)
# sw.add_scalar("validing/loss", eval_loss.item(), step)
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]), eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
mel_prediction=mel_prediction, mel_prediction=mel_prediction,
target_spectrogram=target_spectrogram, target_spectrogram=target_spectrogram,
@ -244,7 +254,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
wav_dir=wav_dir, wav_dir=wav_dir,
sample_num=sample_idx + 1, sample_num=sample_idx + 1,
loss=loss, loss=loss,
hparams=hparams) hparams=hparams,
sw=sw)
# Break out of loop to update training schedule # Break out of loop to update training schedule
if step >= max_step: if step >= max_step:
@ -254,10 +265,11 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
print("") print("")
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step, def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams): plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw):
# Save some results for evaluation # Save some results for evaluation
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num))) attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
save_attention(attention, attention_path) # save_attention(attention, attention_path)
save_and_trace_attention(attention, attention_path, sw, step)
# save predicted mel spectrogram to disk (debug) # save predicted mel spectrogram to disk (debug)
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num)) mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
@ -271,7 +283,15 @@ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
# save real and predicted mel-spectrogram plot to disk (control purposes) # save real and predicted mel-spectrogram plot to disk (control purposes)
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num)) spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss) title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str, # plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
target_spectrogram=target_spectrogram, # target_spectrogram=target_spectrogram,
max_len=target_spectrogram.size // hparams.num_mels) # max_len=target_spectrogram.size // hparams.num_mels)
plot_spectrogram_and_trace(
mel_prediction,
str(spec_fpath),
title=title_str,
target_spectrogram=target_spectrogram,
max_len=target_spectrogram.size // hparams.num_mels,
sw=sw,
step=step)
print("Input at step {}: {}".format(step, sequence_to_text(input_seq))) print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))

View File

@ -74,3 +74,42 @@ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, targ
plt.tight_layout() plt.tight_layout()
plt.savefig(path, format="png") plt.savefig(path, format="png")
plt.close() plt.close()
def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
    """Plot the predicted (and optionally target) mel spectrogram, save it as a
    PNG and, when a summary writer is supplied, log the figure to it.

    Args:
        pred_spectrogram: Predicted spectrogram array; it is rotated with
            ``np.rot90`` before being displayed.
        path: Destination of the PNG file.
        title: Caption drawn below the plots.
        split_title: When true, ``title`` is first passed through
            ``split_title_line``.
        target_spectrogram: Optional reference spectrogram. When given it is
            drawn in its own subplot above the prediction.
        max_len: Optional limit; the spectrograms are truncated to their first
            ``max_len`` entries along the first axis.
        auto_aspect: When true, images are drawn with ``aspect="auto"``.
        sw: Optional writer exposing ``add_figure(tag, figure, step)`` (e.g. a
            TensorBoard ``SummaryWriter``). ``None`` skips the logging step.
        step: Step value recorded together with the figure.
    """
    if max_len is not None:
        # The target is optional, so only slice it when one was supplied.
        if target_spectrogram is not None:
            target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title:
        title = split_title_line(title)

    # Interpolation is always disabled; the aspect ratio is only overridden on
    # request so matplotlib's default applies otherwise.
    imshow_kwargs = {"interpolation": "none"}
    if auto_aspect:
        imshow_kwargs["aspect"] = "auto"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Set common labels
        fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

        # Target spectrogram subplot
        if target_spectrogram is not None:
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)
            im = ax1.imshow(np.rot90(target_spectrogram), **imshow_kwargs)
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-Spectrogram")
        else:
            ax2 = fig.add_subplot(211)

        # Predicted spectrogram subplot
        im = ax2.imshow(np.rot90(pred_spectrogram), **imshow_kwargs)
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

        fig.tight_layout()
        fig.savefig(path, format="png")

        # ``sw`` defaults to None, so tracing must be optional.
        if sw is not None:
            sw.add_figure("spectrogram", fig, step)
    finally:
        # Close this specific figure: SummaryWriter.add_figure closes the
        # figure by default, after which a bare plt.close() would act on
        # whatever figure happens to be current instead. Doing it in
        # ``finally`` also releases the figure when plotting or saving fails.
        plt.close(fig)

View File

@ -21,6 +21,8 @@ if __name__ == "__main__":
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \ parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
"Number of steps between backups of the model. Set to 0 to never make backups of the " "Number of steps between backups of the model. Set to 0 to never make backups of the "
"model.") "model.")
parser.add_argument("-l", "--log_every", type=int, default=200, help= \
"Number of steps between summary the training info in tensorboard")
parser.add_argument("-f", "--force_restart", action="store_true", help= \ parser.add_argument("-f", "--force_restart", action="store_true", help= \
"Do not load any saved model and restart from scratch.") "Do not load any saved model and restart from scratch.")
parser.add_argument("--hparams", default="", parser.add_argument("--hparams", default="",

View File

@ -91,6 +91,14 @@ def save_attention(attn, path) :
plt.close(fig) plt.close(fig)
def save_and_trace_attention(attn, path, sw, step):
    """Draw an attention matrix, save it to ``<path>.png`` and log the figure
    to the given summary writer.

    Args:
        attn: Attention array; its transpose is what gets displayed.
        path: Output path without extension (``.png`` is appended).
        sw: Writer exposing ``add_figure(tag, figure, step)``.
        step: Step value recorded together with the figure.
    """
    figure = plt.figure(figsize=(12, 6))
    # Drawn on the figure created just above, which is the current one.
    plt.imshow(attn.T, interpolation='nearest', aspect='auto')

    png_path = f'{path}.png'
    figure.savefig(png_path, bbox_inches='tight')

    sw.add_figure('attention', figure, step)
    plt.close(figure)
def save_spectrogram(M, path, length=None) : def save_spectrogram(M, path, length=None) :
M = np.flip(M, axis=0) M = np.flip(M, axis=0)
if length : M = M[:, :length] if length : M = M[:, :length]