mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Support tensorboard to trace the training of Synthesizer (#98)
* add tensorborad tracing * add log_every params
This commit is contained in:
parent
99269b2046
commit
4acfee2a64
|
@ -2,11 +2,12 @@ import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from synthesizer import audio
|
from synthesizer import audio
|
||||||
from synthesizer.models.tacotron import Tacotron
|
from synthesizer.models.tacotron import Tacotron
|
||||||
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
||||||
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
||||||
from synthesizer.utils.plot import plot_spectrogram
|
from synthesizer.utils.plot import plot_spectrogram, plot_spectrogram_and_trace
|
||||||
from synthesizer.utils.symbols import symbols
|
from synthesizer.utils.symbols import symbols
|
||||||
from synthesizer.utils.text import sequence_to_text
|
from synthesizer.utils.text import sequence_to_text
|
||||||
from vocoder.display import *
|
from vocoder.display import *
|
||||||
|
@ -23,7 +24,7 @@ def time_string():
|
||||||
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
||||||
|
|
||||||
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
backup_every: int, force_restart:bool, hparams):
|
backup_every: int, log_every:int, force_restart:bool, hparams):
|
||||||
|
|
||||||
syn_dir = Path(syn_dir)
|
syn_dir = Path(syn_dir)
|
||||||
models_dir = Path(models_dir)
|
models_dir = Path(models_dir)
|
||||||
|
@ -123,6 +124,9 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
|
# tracing training step
|
||||||
|
sw = SummaryWriter(log_dir=model_dir.joinpath("logs"))
|
||||||
|
|
||||||
for i, session in enumerate(hparams.tts_schedule):
|
for i, session in enumerate(hparams.tts_schedule):
|
||||||
current_step = model.get_step()
|
current_step = model.get_step()
|
||||||
|
|
||||||
|
@ -208,9 +212,13 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
step = model.get_step()
|
step = model.get_step()
|
||||||
k = step // 1000
|
k = step // 1000
|
||||||
|
|
||||||
|
|
||||||
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
|
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
|
||||||
stream(msg)
|
stream(msg)
|
||||||
|
|
||||||
|
if log_every != 0 and step % log_every == 0 :
|
||||||
|
sw.add_scalar("training/loss", loss_window.average, step)
|
||||||
|
|
||||||
# Backup or save model as appropriate
|
# Backup or save model as appropriate
|
||||||
if backup_every != 0 and step % backup_every == 0 :
|
if backup_every != 0 and step % backup_every == 0 :
|
||||||
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
|
backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
|
||||||
|
@ -220,6 +228,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
# Must save latest optimizer state to ensure that resuming training
|
# Must save latest optimizer state to ensure that resuming training
|
||||||
# doesn't produce artifacts
|
# doesn't produce artifacts
|
||||||
model.save(weights_fpath, optimizer)
|
model.save(weights_fpath, optimizer)
|
||||||
|
|
||||||
|
|
||||||
# Evaluate model to generate samples
|
# Evaluate model to generate samples
|
||||||
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
||||||
|
@ -233,7 +242,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
||||||
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
||||||
attention_len = mel_length // model.r
|
attention_len = mel_length // model.r
|
||||||
|
# eval_loss = F.mse_loss(mel_prediction, target_spectrogram)
|
||||||
|
# sw.add_scalar("validing/loss", eval_loss.item(), step)
|
||||||
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
||||||
mel_prediction=mel_prediction,
|
mel_prediction=mel_prediction,
|
||||||
target_spectrogram=target_spectrogram,
|
target_spectrogram=target_spectrogram,
|
||||||
|
@ -244,7 +254,8 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
wav_dir=wav_dir,
|
wav_dir=wav_dir,
|
||||||
sample_num=sample_idx + 1,
|
sample_num=sample_idx + 1,
|
||||||
loss=loss,
|
loss=loss,
|
||||||
hparams=hparams)
|
hparams=hparams,
|
||||||
|
sw=sw)
|
||||||
|
|
||||||
# Break out of loop to update training schedule
|
# Break out of loop to update training schedule
|
||||||
if step >= max_step:
|
if step >= max_step:
|
||||||
|
@ -254,10 +265,11 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
||||||
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
|
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams, sw):
|
||||||
# Save some results for evaluation
|
# Save some results for evaluation
|
||||||
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
||||||
save_attention(attention, attention_path)
|
# save_attention(attention, attention_path)
|
||||||
|
save_and_trace_attention(attention, attention_path, sw, step)
|
||||||
|
|
||||||
# save predicted mel spectrogram to disk (debug)
|
# save predicted mel spectrogram to disk (debug)
|
||||||
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
||||||
|
@ -271,7 +283,15 @@ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
||||||
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
||||||
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
||||||
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
||||||
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
# plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
||||||
target_spectrogram=target_spectrogram,
|
# target_spectrogram=target_spectrogram,
|
||||||
max_len=target_spectrogram.size // hparams.num_mels)
|
# max_len=target_spectrogram.size // hparams.num_mels)
|
||||||
|
plot_spectrogram_and_trace(
|
||||||
|
mel_prediction,
|
||||||
|
str(spec_fpath),
|
||||||
|
title=title_str,
|
||||||
|
target_spectrogram=target_spectrogram,
|
||||||
|
max_len=target_spectrogram.size // hparams.num_mels,
|
||||||
|
sw=sw,
|
||||||
|
step=step)
|
||||||
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
||||||
|
|
|
@ -74,3 +74,42 @@ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, targ
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(path, format="png")
|
plt.savefig(path, format="png")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
|
||||||
|
if max_len is not None:
|
||||||
|
target_spectrogram = target_spectrogram[:max_len]
|
||||||
|
pred_spectrogram = pred_spectrogram[:max_len]
|
||||||
|
|
||||||
|
if split_title:
|
||||||
|
title = split_title_line(title)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(10, 8))
|
||||||
|
# Set common labels
|
||||||
|
fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
|
||||||
|
|
||||||
|
#target spectrogram subplot
|
||||||
|
if target_spectrogram is not None:
|
||||||
|
ax1 = fig.add_subplot(311)
|
||||||
|
ax2 = fig.add_subplot(312)
|
||||||
|
|
||||||
|
if auto_aspect:
|
||||||
|
im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
|
||||||
|
else:
|
||||||
|
im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
|
||||||
|
ax1.set_title("Target Mel-Spectrogram")
|
||||||
|
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
|
||||||
|
ax2.set_title("Predicted Mel-Spectrogram")
|
||||||
|
else:
|
||||||
|
ax2 = fig.add_subplot(211)
|
||||||
|
|
||||||
|
if auto_aspect:
|
||||||
|
im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
|
||||||
|
else:
|
||||||
|
im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
|
||||||
|
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(path, format="png")
|
||||||
|
sw.add_figure("spectrogram", fig, step)
|
||||||
|
plt.close()
|
|
@ -21,6 +21,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
|
parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
|
||||||
"Number of steps between backups of the model. Set to 0 to never make backups of the "
|
"Number of steps between backups of the model. Set to 0 to never make backups of the "
|
||||||
"model.")
|
"model.")
|
||||||
|
parser.add_argument("-l", "--log_every", type=int, default=200, help= \
|
||||||
|
"Number of steps between summary the training info in tensorboard")
|
||||||
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
parser.add_argument("-f", "--force_restart", action="store_true", help= \
|
||||||
"Do not load any saved model and restart from scratch.")
|
"Do not load any saved model and restart from scratch.")
|
||||||
parser.add_argument("--hparams", default="",
|
parser.add_argument("--hparams", default="",
|
||||||
|
|
|
@ -91,6 +91,14 @@ def save_attention(attn, path) :
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def save_and_trace_attention(attn, path, sw, step):
|
||||||
|
fig = plt.figure(figsize=(12, 6))
|
||||||
|
plt.imshow(attn.T, interpolation='nearest', aspect='auto')
|
||||||
|
fig.savefig(f'{path}.png', bbox_inches='tight')
|
||||||
|
sw.add_figure('attention', fig, step)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def save_spectrogram(M, path, length=None) :
|
def save_spectrogram(M, path, length=None) :
|
||||||
M = np.flip(M, axis=0)
|
M = np.flip(M, axis=0)
|
||||||
if length : M = M[:, :length]
|
if length : M = M[:, :length]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user