diff --git a/synthesizer/train.py b/synthesizer/train.py
index f327987..f1570aa 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -2,11 +2,12 @@ import torch
 import torch.nn.functional as F
 from torch import optim
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
 from synthesizer import audio
 from synthesizer.models.tacotron import Tacotron
 from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
 from synthesizer.utils import ValueWindow, data_parallel_workaround
-from synthesizer.utils.plot import plot_spectrogram
+from synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace
 from synthesizer.utils.symbols import symbols
 from synthesizer.utils.text import sequence_to_text
 from vocoder.display import *
@@ -23,7 +24,7 @@ def time_string():
     return datetime.now().strftime("%Y-%m-%d %H:%M")
 
 
 def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
-          backup_every: int, force_restart:bool, hparams):
+          backup_every: int, log_every: int, force_restart: bool, hparams):
     syn_dir = Path(syn_dir)
     models_dir = Path(models_dir)
@@ -123,6 +124,9 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                              shuffle=True,
                              pin_memory=True)
 
+    # TensorBoard writer for tracing training metrics and eval figures
+    sw = SummaryWriter(log_dir=model_dir.joinpath("logs"))
+
     for i, session in enumerate(hparams.tts_schedule):
         current_step = model.get_step()
 
@@ -208,9 +212,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
             step = model.get_step()
             k = step // 1000
 
+            msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
             stream(msg)
 
+            if log_every != 0 and step % log_every == 0:
+                sw.add_scalar("training/loss", loss_window.average, step)
+
             # Backup or save model as appropriate
             if backup_every != 0 and step % backup_every == 0 :
                 backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
                 model.save(backup_fpath, optimizer)
@@ -220,6 +228,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                 # Must save latest optimizer state to ensure that resuming training
                 # doesn't produce artifacts
                 model.save(weights_fpath, optimizer)
+
             # Evaluate model to generate samples
             epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
 
@@ -233,7 +242,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                 mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                 target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
                 attention_len = mel_length // model.r
-
+                # eval_loss = F.mse_loss(torch.from_numpy(mel_prediction), torch.from_numpy(target_spectrogram))
+                # sw.add_scalar("validation/loss", eval_loss.item(), step)
                 eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                            mel_prediction=mel_prediction,
                            target_spectrogram=target_spectrogram,
@@ -244,7 +254,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                            wav_dir=wav_dir,
                            sample_num=sample_idx + 1,
                            loss=loss,
-                           hparams=hparams)
+                           hparams=hparams,
+                           sw=sw)
 
             # Break out of loop to update training schedule
             if step >= max_step:
@@ -254,10 +265,11 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
         print("")
 
 def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
-               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
+               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw):
     # Save some results for evaluation
     attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
-    save_attention(attention, attention_path)
+    # save_attention(attention, attention_path)
+    save_and_trace_attention(attention, attention_path, sw, step)
 
     # save predicted mel spectrogram to disk (debug)
     mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
@@ -271,7 +283,15 @@ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
     # save real and predicted mel-spectrogram plot to disk (control purposes)
     spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
     title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
-    plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
-                     target_spectrogram=target_spectrogram,
-                     max_len=target_spectrogram.size // hparams.num_mels)
+    # plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
+    #                  target_spectrogram=target_spectrogram,
+    #                  max_len=target_spectrogram.size // hparams.num_mels)
+    plot_spectrogram_and_trace(
+        mel_prediction,
+        str(spec_fpath),
+        title=title_str,
+        target_spectrogram=target_spectrogram,
+        max_len=target_spectrogram.size // hparams.num_mels,
+        sw=sw,
+        step=step)
     print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
diff --git a/synthesizer/utils/plot.py b/synthesizer/utils/plot.py
index f47d271..efdb567 100644
--- a/synthesizer/utils/plot.py
+++ b/synthesizer/utils/plot.py
@@ -74,3 +74,43 @@ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, targ
     plt.tight_layout()
     plt.savefig(path, format="png")
     plt.close()
+
+
+def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
+    if max_len is not None:
+        target_spectrogram = target_spectrogram[:max_len]
+        pred_spectrogram = pred_spectrogram[:max_len]
+
+    if split_title:
+        title = split_title_line(title)
+
+    fig = plt.figure(figsize=(10, 8))
+    # Set common labels
+    fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
+
+    # target spectrogram subplot
+    if target_spectrogram is not None:
+        ax1 = fig.add_subplot(311)
+        ax2 = fig.add_subplot(312)
+
+        if auto_aspect:
+            im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
+        else:
+            im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
+        ax1.set_title("Target Mel-Spectrogram")
+        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
+        ax2.set_title("Predicted Mel-Spectrogram")
+    else:
+        ax2 = fig.add_subplot(211)
+
+    if auto_aspect:
+        im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
+    else:
+        im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
+    fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
+
+    plt.tight_layout()
+    plt.savefig(path, format="png")
+    if sw is not None:
+        sw.add_figure("spectrogram", fig, step)
+    plt.close()
diff --git a/synthesizer_train.py b/synthesizer_train.py
index 2743d59..0f0b598 100644
--- a/synthesizer_train.py
+++ b/synthesizer_train.py
@@ -21,6 +21,8 @@ if __name__ == "__main__":
     parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
+    parser.add_argument("-l", "--log_every", type=int, default=200, help= \
+        "Number of steps between TensorBoard summaries of the training info. Set to 0 to disable.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model and restart from scratch.")
    parser.add_argument("--hparams", default="",
diff --git a/vocoder/display.py b/vocoder/display.py
index 9568807..fe7dd30 100644
--- a/vocoder/display.py
+++ b/vocoder/display.py
@@ -91,6 +91,15 @@ def save_attention(attn, path) :
     plt.close(fig)
 
 
+def save_and_trace_attention(attn, path, sw, step):
+    # Save the attention plot to disk and log the same figure to TensorBoard
+    fig = plt.figure(figsize=(12, 6))
+    plt.imshow(attn.T, interpolation='nearest', aspect='auto')
+    fig.savefig(f'{path}.png', bbox_inches='tight')
+    sw.add_figure('attention', fig, step)
+    plt.close(fig)
+
+
 def save_spectrogram(M, path, length=None) :
     M = np.flip(M, axis=0)
     if length : M = M[:, :length]
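
For reference, the logging this patch wires in reduces to three SummaryWriter calls. Below is a minimal, self-contained sketch of that pattern (assumes torch is installed with TensorBoard support; the runs/demo directory and the dummy loss/attention values are illustrative only, not part of the patch):

    import numpy as np
    import matplotlib
    matplotlib.use("Agg")  # headless backend, as on a training machine
    import matplotlib.pyplot as plt
    from torch.utils.tensorboard import SummaryWriter

    # the patch creates its writer at <models_dir>/<run_id>/logs
    sw = SummaryWriter(log_dir="runs/demo")

    # scalar curve, as in the training loop's log_every branch
    for step in range(0, 1000, 200):
        sw.add_scalar("training/loss", 1.0 / (step + 1), step)

    # figure logging, as in save_and_trace_attention and plot_spectrogram_and_trace
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(np.random.rand(80, 200), aspect="auto", interpolation="none")
    sw.add_figure("attention", fig, 1000)
    sw.close()

The resulting event files can then be inspected with: tensorboard --logdir runs/demo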