diff --git a/synthesizer/models/base.py b/synthesizer/models/base.py
new file mode 100644
index 0000000..13b32a1
--- /dev/null
+++ b/synthesizer/models/base.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+class Base(nn.Module):
+    def __init__(self, stop_threshold):
+        super().__init__()
+
+        self.init_model()
+        self.num_params()
+
+        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
+        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
+
+    @property
+    def r(self):
+        return self.decoder.r.item()
+
+    @r.setter
+    def r(self, value):
+        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
+
+    def init_model(self):
+        for p in self.parameters():
+            if p.dim() > 1: nn.init.xavier_uniform_(p)
+
+    def finetune_partial(self, whitelist_layers):
+        self.zero_grad()
+        for name, child in self.named_children():
+            if name in whitelist_layers:
+                print("Trainable Layer: %s" % name)
+                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
+            else:
+                # freeze everything that is not whitelisted
+                for param in child.parameters():
+                    param.requires_grad = False
+
+    def get_step(self):
+        return self.step.data.item()
+
+    def reset_step(self):
+        # assignment to parameters or buffers is overloaded, updates internal dict entry
+        self.step = self.step.data.new_tensor(1)
+
+    def log(self, path, msg):
+        with open(path, "a") as f:
+            print(msg, file=f)
+
+    def load(self, path, device, optimizer=None):
+        # Use device of model params as location for loaded state
+        checkpoint = torch.load(str(path), map_location=device)
+        self.load_state_dict(checkpoint["model_state"], strict=False)
+
+        if "optimizer_state" in checkpoint and optimizer is not None:
+            optimizer.load_state_dict(checkpoint["optimizer_state"])
+
+    def save(self, path, optimizer=None):
+        if optimizer is not None:
+            torch.save({
+                "model_state": self.state_dict(),
+                "optimizer_state": optimizer.state_dict(),
+            }, str(path))
+        else:
+            torch.save({
+                "model_state": self.state_dict(),
+            }, str(path))
+
+    def num_params(self, print_out=True):
+        parameters = filter(lambda p: p.requires_grad, self.parameters())
+        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+        if print_out:
+            print("Trainable Parameters: %.3fM" % parameters)
+        return parameters
diff --git a/synthesizer/models/sublayer/__init__.py b/synthesizer/models/sublayer/__init__.py
new file mode 100644
index 0000000..4287ca8
--- /dev/null
+++ b/synthesizer/models/sublayer/__init__.py
@@ -0,0 +1 @@
+#
\ No newline at end of file
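Reviewer note: the checkpoint and step helpers above are now shared by every synthesizer model through Base. A minimal usage sketch of that contract (illustrative only; ToyModel, the threshold value and the file name are made up and not part of this patch):

    import torch
    import torch.nn as nn
    from synthesizer.models.base import Base

    class ToyModel(Base):
        def __init__(self):
            super().__init__(stop_threshold=-3.4)  # hypothetical threshold, stored in the buffer
            self.fc = nn.Linear(8, 8)

    model = ToyModel()
    optimizer = torch.optim.Adam(model.parameters())
    model.save("toy_checkpoint.pt", optimizer)              # writes model_state (+ optimizer_state)
    model.load("toy_checkpoint.pt", torch.device("cpu"), optimizer)
    print(model.get_step())                                 # "step" buffer comes back with the state dict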
diff --git a/synthesizer/models/sublayer/cbhg.py b/synthesizer/models/sublayer/cbhg.py
new file mode 100644
index 0000000..10eb6bb
--- /dev/null
+++ b/synthesizer/models/sublayer/cbhg.py
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+from .common.batch_norm_conv import BatchNormConv
+from .common.highway_network import HighwayNetwork
+
+class CBHG(nn.Module):
+    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+        super().__init__()
+
+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
+        self.bank_kernels = [i for i in range(1, K + 1)]
+        self.conv1d_bank = nn.ModuleList()
+        for k in self.bank_kernels:
+            conv = BatchNormConv(in_channels, channels, k)
+            self.conv1d_bank.append(conv)
+
+        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+        # Fix the highway input if necessary
+        if proj_channels[-1] != channels:
+            self.highway_mismatch = True
+            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+        else:
+            self.highway_mismatch = False
+
+        self.highways = nn.ModuleList()
+        for i in range(num_highways):
+            hn = HighwayNetwork(channels)
+            self.highways.append(hn)
+
+        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
+        self._to_flatten.append(self.rnn)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()
+
+    def forward(self, x):
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self.rnn.flatten_parameters()
+
+        # Save these for later
+        residual = x
+        seq_len = x.size(-1)
+        conv_bank = []
+
+        # Convolution Bank
+        for conv in self.conv1d_bank:
+            c = conv(x)  # Convolution
+            conv_bank.append(c[:, :, :seq_len])
+
+        # Stack along the channel axis
+        conv_bank = torch.cat(conv_bank, dim=1)
+
+        # dump the last padding to fit residual
+        x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+        # Conv1d projections
+        x = self.conv_project1(x)
+        x = self.conv_project2(x)
+
+        # Residual Connect
+        x = x + residual
+
+        # Through the highways
+        x = x.transpose(1, 2)
+        if self.highway_mismatch is True:
+            x = self.pre_highway(x)
+        for h in self.highways: x = h(x)
+
+        # And then the RNN
+        x, _ = self.rnn(x)
+        return x
+
+    def _flatten_parameters(self):
+        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
+        to improve efficiency and avoid PyTorch yelling at us."""
+        [m.flatten_parameters() for m in self._to_flatten]
+
diff --git a/synthesizer/models/sublayer/common/batch_norm_conv.py b/synthesizer/models/sublayer/common/batch_norm_conv.py
new file mode 100644
index 0000000..0d07a4a
--- /dev/null
+++ b/synthesizer/models/sublayer/common/batch_norm_conv.py
@@ -0,0 +1,14 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BatchNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel, relu=True):
+        super().__init__()
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+        self.bnorm = nn.BatchNorm1d(out_channels)
+        self.relu = relu
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = F.relu(x) if self.relu is True else x
+        return self.bnorm(x)
\ No newline at end of file
diff --git a/synthesizer/models/sublayer/common/highway_network.py b/synthesizer/models/sublayer/common/highway_network.py
new file mode 100644
index 0000000..d311c69
--- /dev/null
+++ b/synthesizer/models/sublayer/common/highway_network.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class HighwayNetwork(nn.Module):
+    def __init__(self, size):
+        super().__init__()
+        self.W1 = nn.Linear(size, size)
+        self.W2 = nn.Linear(size, size)
+        self.W1.bias.data.fill_(0.)
+
+    def forward(self, x):
+        x1 = self.W1(x)
+        x2 = self.W2(x)
+        g = torch.sigmoid(x2)
+        y = g * F.relu(x1) + (1. - g) * x
+        return y
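Worth keeping in mind while reviewing the moves above: CBHG consumes (batch, channels, time) and, because of the transpose before the highway layers and the batch-first bidirectional GRU, returns (batch, time, channels). A quick shape check along those lines (the sizes are arbitrary, chosen only for illustration):

    import torch
    from synthesizer.models.sublayer.cbhg import CBHG

    cbhg = CBHG(K=8, in_channels=256, channels=256, proj_channels=[256, 256], num_highways=4)
    x = torch.randn(2, 256, 100)   # (batch, channels, time)
    y = cbhg(x)
    print(y.shape)                 # expected: torch.Size([2, 100, 256]), i.e. (batch, time, channels)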
diff --git a/synthesizer/models/global_style_token.py b/synthesizer/models/sublayer/global_style_token.py
similarity index 100%
rename from synthesizer/models/global_style_token.py
rename to synthesizer/models/sublayer/global_style_token.py
diff --git a/synthesizer/models/sublayer/lsa.py b/synthesizer/models/sublayer/lsa.py
new file mode 100644
index 0000000..9a32913
--- /dev/null
+++ b/synthesizer/models/sublayer/lsa.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LSA(nn.Module):
+    def __init__(self, attn_dim, kernel_size=31, filters=32):
+        super().__init__()
+        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
+        self.L = nn.Linear(filters, attn_dim, bias=False)
+        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
+        self.v = nn.Linear(attn_dim, 1, bias=False)
+        self.cumulative = None
+        self.attention = None
+
+    def init_attention(self, encoder_seq_proj):
+        device = encoder_seq_proj.device  # use same device as parameters
+        b, t, c = encoder_seq_proj.size()
+        self.cumulative = torch.zeros(b, t, device=device)
+        self.attention = torch.zeros(b, t, device=device)
+
+    def forward(self, encoder_seq_proj, query, t, chars):
+
+        if t == 0: self.init_attention(encoder_seq_proj)
+
+        processed_query = self.W(query).unsqueeze(1)
+
+        location = self.cumulative.unsqueeze(1)
+        processed_loc = self.L(self.conv(location).transpose(1, 2))
+
+        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
+        u = u.squeeze(-1)
+
+        # Mask zero padding chars
+        u = u * (chars != 0).float()
+
+        # Smooth Attention
+        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
+        scores = F.softmax(u, dim=1)
+        self.attention = scores
+        self.cumulative = self.cumulative + self.attention
+
+        return scores.unsqueeze(-1).transpose(1, 2)
diff --git a/synthesizer/models/sublayer/pre_net.py b/synthesizer/models/sublayer/pre_net.py
new file mode 100644
index 0000000..3c8ebb8
--- /dev/null
+++ b/synthesizer/models/sublayer/pre_net.py
@@ -0,0 +1,18 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+class PreNet(nn.Module):
+    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+        super().__init__()
+        self.fc1 = nn.Linear(in_dims, fc1_dims)
+        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+        self.p = dropout
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = F.dropout(x, self.p, training=True)
+        x = self.fc2(x)
+        x = F.relu(x)
+        x = F.dropout(x, self.p, training=True)
+        return x
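One behavioural detail to note before the tacotron.py hunk: LSA is stateful, self.cumulative and self.attention persist across decoder steps and are only re-initialised when t == 0, so scores at later steps depend on everything attended so far. A small, self-contained shape check for a single step (all dimensions below are made up for illustration):

    import torch
    from synthesizer.models.sublayer.lsa import LSA

    attn = LSA(attn_dim=128)
    encoder_seq_proj = torch.randn(2, 37, 128)   # (batch, num_chars, attn_dim)
    query = torch.randn(2, 128)                  # decoder query for one step
    chars = torch.randint(1, 60, (2, 37))        # non-zero ids, so nothing is masked here
    scores = attn(encoder_seq_proj, query, t=0, chars=chars)
    print(scores.shape)                          # expected: torch.Size([2, 1, 37])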
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 9cfabf7..a76a50f 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -1,29 +1,14 @@
-import os
-from matplotlib.pyplot import step
-import numpy as np
+
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from synthesizer.models.global_style_token import GlobalStyleToken
+from .sublayer.global_style_token import GlobalStyleToken
+from .sublayer.pre_net import PreNet
+from .sublayer.cbhg import CBHG
+from .sublayer.lsa import LSA
+from .base import Base
 from synthesizer.gst_hyperparameters import GSTHyperparameters as gst_hp
 from synthesizer.hparams import hparams
 
-
-class HighwayNetwork(nn.Module):
-    def __init__(self, size):
-        super().__init__()
-        self.W1 = nn.Linear(size, size)
-        self.W2 = nn.Linear(size, size)
-        self.W1.bias.data.fill_(0.)
-
-    def forward(self, x):
-        x1 = self.W1(x)
-        x2 = self.W2(x)
-        g = torch.sigmoid(x2)
-        y = g * F.relu(x1) + (1. - g) * x
-        return y
-
-
 class Encoder(nn.Module):
     def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
         super().__init__()
@@ -36,213 +21,11 @@ class Encoder(nn.Module):
                          proj_channels=[cbhg_channels, cbhg_channels],
                          num_highways=num_highways)
 
-    def forward(self, x, speaker_embedding=None):
+    def forward(self, x):
         x = self.embedding(x)
         x = self.pre_net(x)
         x.transpose_(1, 2)
-        x = self.cbhg(x)
-        if speaker_embedding is not None:
-            x = self.add_speaker_embedding(x, speaker_embedding)
-        return x
-
-    def add_speaker_embedding(self, x, speaker_embedding):
-        # SV2TTS
-        # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
-        # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
-        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
-        # This concats the speaker embedding for each char in the encoder output
-
-        # Save the dimensions as human-readable names
-        batch_size = x.size()[0]
-        num_chars = x.size()[1]
-
-        if speaker_embedding.dim() == 1:
-            idx = 0
-        else:
-            idx = 1
-
-        # Start by making a copy of each speaker embedding to match the input text length
-        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
-        speaker_embedding_size = speaker_embedding.size()[idx]
-        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
-
-        # Reshape it and transpose
-        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
-        e = e.transpose(1, 2)
-
-        # Concatenate the tiled speaker embedding with the encoder output
-        x = torch.cat((x, e), 2)
-        return x
-
-
-class BatchNormConv(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel, relu=True):
-        super().__init__()
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
-        self.bnorm = nn.BatchNorm1d(out_channels)
-        self.relu = relu
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = F.relu(x) if self.relu is True else x
-        return self.bnorm(x)
-
-
-class CBHG(nn.Module):
-    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
-        super().__init__()
-
-        # List of all rnns to call `flatten_parameters()` on
-        self._to_flatten = []
-
-        self.bank_kernels = [i for i in range(1, K + 1)]
-        self.conv1d_bank = nn.ModuleList()
-        for k in self.bank_kernels:
-            conv = BatchNormConv(in_channels, channels, k)
-            self.conv1d_bank.append(conv)
-
-        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
-
-        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
-        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
-
-        # Fix the highway input if necessary
-        if proj_channels[-1] != channels:
-            self.highway_mismatch = True
-            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
-        else:
-            self.highway_mismatch = False
-
-        self.highways = nn.ModuleList()
-        for i in range(num_highways):
-            hn = HighwayNetwork(channels)
-            self.highways.append(hn)
-
-        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
-        self._to_flatten.append(self.rnn)
-
-        # Avoid fragmentation of RNN parameters and associated warning
-        self._flatten_parameters()
-
-    def forward(self, x):
-        # Although we `_flatten_parameters()` on init, when using DataParallel
-        # the model gets replicated, making it no longer guaranteed that the
-        # weights are contiguous in GPU memory. Hence, we must call it again
-        self.rnn.flatten_parameters()
-
-        # Save these for later
-        residual = x
-        seq_len = x.size(-1)
-        conv_bank = []
-
-        # Convolution Bank
-        for conv in self.conv1d_bank:
-            c = conv(x)  # Convolution
-            conv_bank.append(c[:, :, :seq_len])
-
-        # Stack along the channel axis
-        conv_bank = torch.cat(conv_bank, dim=1)
-
-        # dump the last padding to fit residual
-        x = self.maxpool(conv_bank)[:, :, :seq_len]
-
-        # Conv1d projections
-        x = self.conv_project1(x)
-        x = self.conv_project2(x)
-
-        # Residual Connect
-        x = x + residual
-
-        # Through the highways
-        x = x.transpose(1, 2)
-        if self.highway_mismatch is True:
-            x = self.pre_highway(x)
-        for h in self.highways: x = h(x)
-
-        # And then the RNN
-        x, _ = self.rnn(x)
-        return x
-
-    def _flatten_parameters(self):
-        """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
-        to improve efficiency and avoid PyTorch yelling at us."""
-        [m.flatten_parameters() for m in self._to_flatten]
-
-class PreNet(nn.Module):
-    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
-        super().__init__()
-        self.fc1 = nn.Linear(in_dims, fc1_dims)
-        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
-        self.p = dropout
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = F.relu(x)
-        x = F.dropout(x, self.p, training=True)
-        x = self.fc2(x)
-        x = F.relu(x)
-        x = F.dropout(x, self.p, training=True)
-        return x
-
-
-class Attention(nn.Module):
-    def __init__(self, attn_dims):
-        super().__init__()
-        self.W = nn.Linear(attn_dims, attn_dims, bias=False)
-        self.v = nn.Linear(attn_dims, 1, bias=False)
-
-    def forward(self, encoder_seq_proj, query, t):
-
-        # print(encoder_seq_proj.shape)
-        # Transform the query vector
-        query_proj = self.W(query).unsqueeze(1)
-
-        # Compute the scores
-        u = self.v(torch.tanh(encoder_seq_proj + query_proj))
-        scores = F.softmax(u, dim=1)
-
-        return scores.transpose(1, 2)
-
-
-class LSA(nn.Module):
-    def __init__(self, attn_dim, kernel_size=31, filters=32):
-        super().__init__()
-        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
-        self.L = nn.Linear(filters, attn_dim, bias=False)
-        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
-        self.v = nn.Linear(attn_dim, 1, bias=False)
-        self.cumulative = None
-        self.attention = None
-
-    def init_attention(self, encoder_seq_proj):
-        device = encoder_seq_proj.device  # use same device as parameters
-        b, t, c = encoder_seq_proj.size()
-        self.cumulative = torch.zeros(b, t, device=device)
-        self.attention = torch.zeros(b, t, device=device)
-
-    def forward(self, encoder_seq_proj, query, t, chars):
-
-        if t == 0: self.init_attention(encoder_seq_proj)
-
-        processed_query = self.W(query).unsqueeze(1)
-
-        location = self.cumulative.unsqueeze(1)
-        processed_loc = self.L(self.conv(location).transpose(1, 2))
-
-        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
-        u = u.squeeze(-1)
-
-        # Mask zero padding chars
-        u = u * (chars != 0).float()
-
-        # Smooth Attention
-        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
-        scores = F.softmax(u, dim=1)
-        self.attention = scores
-        self.cumulative = self.cumulative + self.attention
-
-        return scores.unsqueeze(-1).transpose(1, 2)
-
+        return self.cbhg(x)
 
 class Decoder(nn.Module):
     # Class variable because its value doesn't change between classes
@@ -327,12 +110,11 @@ class Decoder(nn.Module):
 
         return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
 
-
-class Tacotron(nn.Module):
+class Tacotron(Base):
     def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
                  fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
                  dropout, stop_threshold, speaker_embedding_size):
-        super().__init__()
+        super().__init__(stop_threshold)
         self.n_mels = n_mels
         self.lstm_dims = lstm_dims
         self.encoder_dims = encoder_dims
@@ -352,20 +134,6 @@ class Tacotron(nn.Module):
                               [postnet_dims, fft_bins], num_highways)
         self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
 
-        self.init_model()
-        self.num_params()
-
-        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
-        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
-
-    @property
-    def r(self):
-        return self.decoder.r.item()
-
-    @r.setter
-    def r(self, value):
-        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
-
     @staticmethod
     def _concat_speaker_embedding(outputs, speaker_embeddings):
         speaker_embeddings_ = speaker_embeddings.expand(
@@ -373,6 +141,36 @@ class Tacotron(nn.Module):
         outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs
 
+    @staticmethod
+    def _add_speaker_embedding(x, speaker_embedding):
+        # SV2TTS
+        # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
+        # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
+        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
+        # This concats the speaker embedding for each char in the encoder output
+
+        # Save the dimensions as human-readable names
+        batch_size = x.size()[0]
+        num_chars = x.size()[1]
+
+        if speaker_embedding.dim() == 1:
+            idx = 0
+        else:
+            idx = 1
+
+        # Start by making a copy of each speaker embedding to match the input text length
+        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
+        speaker_embedding_size = speaker_embedding.size()[idx]
+        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
+
+        # Reshape it and transpose
+        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
+        e = e.transpose(1, 2)
+
+        # Concatenate the tiled speaker embedding with the encoder output
+        x = torch.cat((x, e), 2)
+        return x
+
     def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         device = texts.device  # use same device as parameters
 
@@ -405,8 +203,11 @@ class Tacotron(nn.Module):
 
         # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(texts, speaker_embedding)
+        encoder_seq = self.encoder(texts)
+        if speaker_embedding is not None:
+            encoder_seq = self._add_speaker_embedding(encoder_seq, speaker_embedding)
+
 
         if hparams.use_gst and self.gst is not None:
             if self.training:
                 style_embed = self.gst(speaker_embedding, speaker_embedding)  # for training, speaker embedding can represent both style inputs and referenced
@@ -455,7 +256,6 @@ class Tacotron(nn.Module):
         # attn_scores = attn_scores.cpu().data.numpy()
         stop_outputs = torch.cat(stop_outputs, 1)
 
-
         if self.training:
             self.train()
 
@@ -465,131 +265,3 @@ class Tacotron(nn.Module):
         self.eval()
         mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
         return mel_outputs, linear, attn_scores
-        # device = x.device  # use same device as parameters
-
-        # batch_size, _ = x.size()
-
-        # # Need to initialise all hidden states and pack into tuple for tidyness
-        # attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
-        # rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        # rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        # hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
-
-        # # Need to initialise all lstm cell states and pack into tuple for tidyness
-        # rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        # rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        # cell_states = (rnn1_cell, rnn2_cell)
-
-        # # Need a Frame for start of decoder loop
-        # go_frame = torch.zeros(batch_size, self.n_mels, device=device)
-
-        # # Need an initial context vector
-        # size = self.encoder_dims + self.speaker_embedding_size
-        # if hparams.use_gst:
-        #     size += gst_hp.E
-        # context_vec = torch.zeros(batch_size, size, device=device)
-
-        # # SV2TTS: Run the encoder with the speaker embedding
-        # # The projection avoids unnecessary matmuls in the decoder loop
-        # encoder_seq = self.encoder(x, speaker_embedding)
-
-        # # put after encoder
-        # if hparams.use_gst and self.gst is not None:
-        #     if style_idx >= 0 and style_idx < 10:
-        #         query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
-        #         if device.type == 'cuda':
-        #             query = query.cuda()
-        #         gst_embed = torch.tanh(self.gst.stl.embed)
-        #         key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
-        #         style_embed = self.gst.stl.attention(query, key)
-        #     else:
-        #         speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-        #         style_embed = self.gst(speaker_embedding_style, speaker_embedding)
-        #     encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
-        #     # style_embed = style_embed.expand_as(encoder_seq)
-        #     # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
-        # encoder_seq_proj = self.encoder_proj(encoder_seq)
-
-        # # Need a couple of lists for outputs
-        # mel_outputs, attn_scores, stop_outputs = [], [], []
-
-        # # Run the decoder loop
-        # for t in range(0, steps, self.r):
-        #     prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
-        #     mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
-        #         self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-        #                      hidden_states, cell_states, context_vec, t, x)
-        #     mel_outputs.append(mel_frames)
-        #     attn_scores.append(scores)
-        #     stop_outputs.extend([stop_tokens] * self.r)
-        #     # Stop the loop when all stop tokens in batch exceed threshold
-        #     if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
-
-        # # Concat the mel outputs into sequence
-        # mel_outputs = torch.cat(mel_outputs, dim=2)
-
-        # # Post-Process for Linear Spectrograms
-        # postnet_out = self.postnet(mel_outputs)
-        # linear = self.post_proj(postnet_out)
-
-
-        # linear = linear.transpose(1, 2)
-
-        # # For easy visualisation
-        # attn_scores = torch.cat(attn_scores, 1)
-        # stop_outputs = torch.cat(stop_outputs, 1)
-
-        # self.train()
-
-        # return mel_outputs, linear, attn_scores
-
-    def init_model(self):
-        for p in self.parameters():
-            if p.dim() > 1: nn.init.xavier_uniform_(p)
-
-    def finetune_partial(self, whitelist_layers):
-        self.zero_grad()
-        for name, child in self.named_children():
-            if name in whitelist_layers:
-                print("Trainable Layer: %s" % name)
-                print("Trainable Parameters: %.3f" % sum([np.prod(p.size()) for p in child.parameters()]))
-                for param in child.parameters():
-                    param.requires_grad = False
-
-    def get_step(self):
-        return self.step.data.item()
-
-    def reset_step(self):
-        # assignment to parameters or buffers is overloaded, updates internal dict entry
-        self.step = self.step.data.new_tensor(1)
-
-    def log(self, path, msg):
-        with open(path, "a") as f:
-            print(msg, file=f)
-
-    def load(self, path, device, optimizer=None):
-        # Use device of model params as location for loaded state
-        checkpoint = torch.load(str(path), map_location=device)
-        self.load_state_dict(checkpoint["model_state"], strict=False)
-
-        if "optimizer_state" in checkpoint and optimizer is not None:
-            optimizer.load_state_dict(checkpoint["optimizer_state"])
-
-    def save(self, path, optimizer=None):
-        if optimizer is not None:
-            torch.save({
-                "model_state": self.state_dict(),
-                "optimizer_state": optimizer.state_dict(),
-            }, str(path))
-        else:
-            torch.save({
-                "model_state": self.state_dict(),
-            }, str(path))
-
-
-    def num_params(self, print_out=True):
-        parameters = filter(lambda p: p.requires_grad, self.parameters())
-        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
-        if print_out:
-            print("Trainable Parameters: %.3fM" % parameters)
-        return parameters
diff --git a/synthesizer/train.py b/synthesizer/train.py
index 8799e84..bd1f8a0 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -265,7 +265,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                               hparams=hparams, sw=sw)
 
             MAX_SAVED_COUNT = 20
-            if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT:
+            if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
                 # clean up and save last MAX_SAVED_COUNT;
                 plots = next(os.walk(plot_dir), (None, None, []))[2]
                 for plot in plots[-MAX_SAVED_COUNT:]:
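Looking back at the tacotron.py hunk above, _add_speaker_embedding is now a static method applied explicitly in Tacotron.forward rather than hidden inside the Encoder, so its shape contract can be checked in isolation (the tensor sizes below are illustrative, not the project defaults):

    import torch
    from synthesizer.models.tacotron import Tacotron

    encoder_seq = torch.randn(4, 37, 512)        # (batch, num_chars, tts_embed_dims)
    speaker_embedding = torch.randn(4, 256)      # (batch, speaker_embedding_size)
    out = Tacotron._add_speaker_embedding(encoder_seq, speaker_embedding)
    print(out.shape)                             # expected: torch.Size([4, 37, 768]); embedding tiled per char, then concatenated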