mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
4acfee2a64
* add tensorborad tracing * add log_every params
129 lines
3.1 KiB
Python
129 lines
3.1 KiB
Python
import matplotlib.pyplot as plt
|
|
import time
|
|
import numpy as np
|
|
import sys
|
|
|
|
|
|
def progbar(i, n, size=16):
|
|
done = (i * size) // n
|
|
bar = ''
|
|
for i in range(size):
|
|
bar += '█' if i <= done else '░'
|
|
return bar
|
|
|
|
|
|
def stream(message) :
|
|
try:
|
|
sys.stdout.write("\r{%s}" % message)
|
|
except:
|
|
#Remove non-ASCII characters from message
|
|
message = ''.join(i for i in message if ord(i)<128)
|
|
sys.stdout.write("\r{%s}" % message)
|
|
|
|
|
|
def simple_table(item_tuples) :
|
|
|
|
border_pattern = '+---------------------------------------'
|
|
whitespace = ' '
|
|
|
|
headings, cells, = [], []
|
|
|
|
for item in item_tuples :
|
|
|
|
heading, cell = str(item[0]), str(item[1])
|
|
|
|
pad_head = True if len(heading) < len(cell) else False
|
|
|
|
pad = abs(len(heading) - len(cell))
|
|
pad = whitespace[:pad]
|
|
|
|
pad_left = pad[:len(pad)//2]
|
|
pad_right = pad[len(pad)//2:]
|
|
|
|
if pad_head :
|
|
heading = pad_left + heading + pad_right
|
|
else :
|
|
cell = pad_left + cell + pad_right
|
|
|
|
headings += [heading]
|
|
cells += [cell]
|
|
|
|
border, head, body = '', '', ''
|
|
|
|
for i in range(len(item_tuples)) :
|
|
|
|
temp_head = f'| {headings[i]} '
|
|
temp_body = f'| {cells[i]} '
|
|
|
|
border += border_pattern[:len(temp_head)]
|
|
head += temp_head
|
|
body += temp_body
|
|
|
|
if i == len(item_tuples) - 1 :
|
|
head += '|'
|
|
body += '|'
|
|
border += '+'
|
|
|
|
print(border)
|
|
print(head)
|
|
print(border)
|
|
print(body)
|
|
print(border)
|
|
print(' ')
|
|
|
|
|
|
def time_since(started) :
|
|
elapsed = time.time() - started
|
|
m = int(elapsed // 60)
|
|
s = int(elapsed % 60)
|
|
if m >= 60 :
|
|
h = int(m // 60)
|
|
m = m % 60
|
|
return f'{h}h {m}m {s}s'
|
|
else :
|
|
return f'{m}m {s}s'
|
|
|
|
|
|
def save_attention(attn, path) :
|
|
fig = plt.figure(figsize=(12, 6))
|
|
plt.imshow(attn.T, interpolation='nearest', aspect='auto')
|
|
fig.savefig(f'{path}.png', bbox_inches='tight')
|
|
plt.close(fig)
|
|
|
|
|
|
def save_and_trace_attention(attn, path, sw, step):
|
|
fig = plt.figure(figsize=(12, 6))
|
|
plt.imshow(attn.T, interpolation='nearest', aspect='auto')
|
|
fig.savefig(f'{path}.png', bbox_inches='tight')
|
|
sw.add_figure('attention', fig, step)
|
|
plt.close(fig)
|
|
|
|
|
|
def save_spectrogram(M, path, length=None) :
|
|
M = np.flip(M, axis=0)
|
|
if length : M = M[:, :length]
|
|
fig = plt.figure(figsize=(12, 6))
|
|
plt.imshow(M, interpolation='nearest', aspect='auto')
|
|
fig.savefig(f'{path}.png', bbox_inches='tight')
|
|
plt.close(fig)
|
|
|
|
|
|
def plot(array) :
|
|
fig = plt.figure(figsize=(30, 5))
|
|
ax = fig.add_subplot(111)
|
|
ax.xaxis.label.set_color('grey')
|
|
ax.yaxis.label.set_color('grey')
|
|
ax.xaxis.label.set_fontsize(23)
|
|
ax.yaxis.label.set_fontsize(23)
|
|
ax.tick_params(axis='x', colors='grey', labelsize=23)
|
|
ax.tick_params(axis='y', colors='grey', labelsize=23)
|
|
plt.plot(array)
|
|
|
|
|
|
def plot_spec(M) :
|
|
M = np.flip(M, axis=0)
|
|
plt.figure(figsize=(18,4))
|
|
plt.imshow(M, interpolation='nearest', aspect='auto')
|
|
plt.show()
|
|
|