2021-08-07 11:56:00 +08:00
|
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
|
2023-02-04 14:13:38 +08:00
|
|
|
# Set to True by the plot_*_to_numpy helpers once they have called
# matplotlib.use("Agg"), so the backend selection is only performed once.
MATPLOTLIB_FLAG = False
|
2021-08-07 11:56:00 +08:00
|
|
|
|
|
|
|
def split_title_line(title_text, max_words=5):
    """
    Wrap a title over several lines with at most ``max_words`` words per line.

    The text is split on runs of whitespace, the words are regrouped into
    chunks of ``max_words`` and the chunks are joined with newline characters.

    Args:
        title_text: Title to wrap. ``None`` or an empty string is returned
            unchanged so callers can forward an optional title safely.
        max_words: Maximum number of words per line; must be at least 1.

    Returns:
        The wrapped title (or ``title_text`` itself when it is empty/``None``).

    Raises:
        ValueError: If ``max_words`` is smaller than 1.
    """
    if max_words < 1:
        # A zero step already made range() raise; a negative one silently
        # produced an empty title. Reject both the same way.
        raise ValueError("max_words must be a positive integer")
    if not title_text:
        return title_text

    words = title_text.split()
    return "\n".join(
        " ".join(words[start:start + max_words])
        for start in range(0, len(words), max_words)
    )
|
|
|
|
|
|
|
|
def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
    """
    Draw an alignment matrix as a heat map and save it to ``path`` as a PNG.

    The matrix is shown with ``imshow`` (origin at the bottom-left); the x axis
    is labelled "Decoder timestep" and the y axis "Encoder timestep".

    Args:
        alignment: 2-D array to display. It is indexed as ``alignment[:, :max_len]``,
            so it must support two-axis slicing.
        path: Destination handed to ``savefig``.
        title: Optional plot title.
        split_title: When True (and a title is given), wrap the title with
            ``split_title_line`` before drawing it.
        max_len: If not None, only the first ``max_len`` columns are plotted.
    """
    if max_len is not None:
        alignment = alignment[:, :max_len]

    # title defaults to None, which cannot be split.
    if split_title and title is not None:
        title = split_title_line(title)

    fig = plt.figure(figsize=(8, 6))
    try:
        ax = fig.add_subplot(111)
        im = ax.imshow(
            alignment,
            aspect="auto",
            origin="lower",
            interpolation="none",
        )
        fig.colorbar(im, ax=ax)

        ax.set_xlabel("Decoder timestep")
        if title is not None:
            ax.set_title(title)
        ax.set_ylabel("Encoder timestep")

        fig.tight_layout()
        fig.savefig(path, format="png")
    finally:
        # Close this specific figure even when drawing/saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
    """
    Plot a predicted spectrogram (optionally under a target one) and save it as PNG.

    Each spectrogram is rotated with ``np.rot90`` before being displayed and
    gets a horizontal colour bar. When a target is supplied it is drawn in the
    first of three rows ("Target Mel-Spectrogram") and the prediction in the
    second ("Predicted Mel-Spectrogram"); otherwise the prediction is drawn
    alone in the first of two rows, without a subplot title.

    Args:
        pred_spectrogram: Array to plot; sliced as ``pred_spectrogram[:max_len]``.
            Presumably shaped (frames, channels) - confirm against callers.
        path: Destination handed to ``savefig``.
        title: Optional caption written near the bottom of the figure.
        split_title: When True (and a title is given), wrap the title with
            ``split_title_line``.
        target_spectrogram: Optional reference array plotted above the prediction.
        max_len: If not None, both arrays are truncated to their first
            ``max_len`` entries along the first axis.
        auto_aspect: When True, pass ``aspect="auto"`` to ``imshow``; otherwise
            matplotlib's default aspect is used.
    """
    if max_len is not None:
        # The target is optional: slicing None used to raise TypeError.
        if target_spectrogram is not None:
            target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title and title is not None:
        title = split_title_line(title)

    imshow_kwargs = {"interpolation": "none"}
    if auto_aspect:
        imshow_kwargs["aspect"] = "auto"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Common caption for the whole figure.
        if title is not None:
            fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

        # Target spectrogram subplot (only when a target is available).
        if target_spectrogram is not None:
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)

            im = ax1.imshow(np.rot90(target_spectrogram), **imshow_kwargs)
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-Spectrogram")
        else:
            ax2 = fig.add_subplot(211)

        im = ax2.imshow(np.rot90(pred_spectrogram), **imshow_kwargs)
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

        fig.tight_layout()
        fig.savefig(path, format="png")
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)
|
2021-09-25 17:06:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False, sw=None, step=0):
    """
    Plot spectrogram(s), save the figure as PNG and optionally log it to ``sw``.

    The drawing is the same as in ``plot_spectrogram``: each array is rotated
    with ``np.rot90`` and shown with a horizontal colour bar; a target, when
    given, is drawn above the prediction. After saving to ``path`` the figure
    is passed to ``sw.add_figure("spectrogram", fig, step)``.

    Args:
        pred_spectrogram: Array to plot; sliced as ``pred_spectrogram[:max_len]``.
            Presumably shaped (frames, channels) - confirm against callers.
        path: Destination handed to ``savefig``.
        title: Optional caption written near the bottom of the figure.
        split_title: When True (and a title is given), wrap the title with
            ``split_title_line``.
        target_spectrogram: Optional reference array plotted above the prediction.
        max_len: If not None, both arrays are truncated to their first
            ``max_len`` entries along the first axis.
        auto_aspect: When True, pass ``aspect="auto"`` to ``imshow``.
        sw: Optional object exposing ``add_figure(tag, figure, step)``
            (presumably a TensorBoard SummaryWriter - confirm against callers).
            When None, the logging step is skipped.
        step: Step value forwarded to ``sw.add_figure``.
    """
    if max_len is not None:
        # The target is optional: slicing None used to raise TypeError.
        if target_spectrogram is not None:
            target_spectrogram = target_spectrogram[:max_len]
        pred_spectrogram = pred_spectrogram[:max_len]

    if split_title and title is not None:
        title = split_title_line(title)

    imshow_kwargs = {"interpolation": "none"}
    if auto_aspect:
        imshow_kwargs["aspect"] = "auto"

    fig = plt.figure(figsize=(10, 8))
    try:
        # Common caption for the whole figure.
        if title is not None:
            fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)

        # Target spectrogram subplot (only when a target is available).
        if target_spectrogram is not None:
            ax1 = fig.add_subplot(311)
            ax2 = fig.add_subplot(312)

            im = ax1.imshow(np.rot90(target_spectrogram), **imshow_kwargs)
            ax1.set_title("Target Mel-Spectrogram")
            fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
            ax2.set_title("Predicted Mel-Spectrogram")
        else:
            ax2 = fig.add_subplot(211)

        im = ax2.imshow(np.rot90(pred_spectrogram), **imshow_kwargs)
        fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)

        fig.tight_layout()
        fig.savefig(path, format="png")

        # sw defaults to None; only log when a writer was actually supplied.
        if sw is not None:
            sw.add_figure("spectrogram", fig, step)
    finally:
        # Close this specific figure regardless of what the writer did with it.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
def plot_spectrogram_to_numpy(spectrogram):
    """
    Render a spectrogram to an RGB image array.

    The array is drawn with ``imshow`` (origin at the bottom-left, automatic
    aspect) on a 10x2 inch figure with a colour bar, x label "Frames" and
    y label "Channels". The rendered canvas is then copied into a NumPy array.

    Args:
        spectrogram: 2-D array accepted by ``imshow``.

    Returns:
        A writable ``uint8`` array of shape (height, width, 3) holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # Select the non-interactive Agg backend once per process.
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                       interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        # np.fromstring(binary) and canvas.tostring_rgb() are deprecated or
        # removed in current NumPy/Matplotlib; read the RGBA buffer instead
        # and drop the alpha channel. copy() keeps the result writable and
        # independent of the canvas memory.
        rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        data = rgba.reshape(height, width, 4)[:, :, :3].copy()
    finally:
        # Release the figure even if rendering failed.
        plt.close(fig)
    return data
|
|
|
|
|
|
|
|
|
|
|
|
def plot_alignment_to_numpy(alignment, info=None):
    """
    Render an alignment matrix to an RGB image array.

    The transposed matrix is drawn with ``imshow`` (origin at the bottom-left,
    automatic aspect) on a 6x4 inch figure with a colour bar. The x axis is
    labelled "Decoder timestep" (followed by ``info`` when given) and the
    y axis "Encoder timestep".

    Args:
        alignment: Array exposing ``transpose()``; its transpose is displayed.
        info: Optional string appended to the x label after a blank line.

    Returns:
        A writable ``uint8`` array of shape (height, width, 3) holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        # Select the non-interactive Agg backend once per process.
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)

        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        # np.fromstring(binary) and canvas.tostring_rgb() are deprecated or
        # removed in current NumPy/Matplotlib; read the RGBA buffer instead
        # and drop the alpha channel. copy() keeps the result writable and
        # independent of the canvas memory.
        rgba = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        data = rgba.reshape(height, width, 4)[:, :, :3].copy()
    finally:
        # Release the figure even if rendering failed.
        plt.close(fig)
    return data
|