diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py index 5c3fce6..534b0fa 100644 --- a/synthesizer/models/tacotron.py +++ b/synthesizer/models/tacotron.py @@ -127,7 +127,7 @@ class CBHG(nn.Module): # Although we `_flatten_parameters()` on init, when using DataParallel # the model gets replicated, making it no longer guaranteed that the # weights are contiguous in GPU memory. Hence, we must call it again - self._flatten_parameters() + self.rnn.flatten_parameters() # Save these for later residual = x @@ -214,7 +214,7 @@ class LSA(nn.Module): self.attention = None def init_attention(self, encoder_seq_proj): - device = next(self.parameters()).device # use same device as parameters + device = encoder_seq_proj.device # use same device as parameters b, t, c = encoder_seq_proj.size() self.cumulative = torch.zeros(b, t, device=device) self.attention = torch.zeros(b, t, device=device) @@ -265,9 +265,8 @@ class Decoder(nn.Module): self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1) - def zoneout(self, prev, current, p=0.1): - device = next(self.parameters()).device # Use same device as parameters - mask = torch.zeros(prev.size(), device=device).bernoulli_(p) + def zoneout(self, prev, current, device, p=0.1): + mask = torch.zeros(prev.size(),device=device).bernoulli_(p) return prev * mask + current * (1 - mask) def forward(self, encoder_seq, encoder_seq_proj, prenet_in, @@ -275,7 +274,7 @@ class Decoder(nn.Module): # Need this for reshaping mels batch_size = encoder_seq.size(0) - + device = encoder_seq.device # Unpack the hidden and cell states attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states rnn1_cell, rnn2_cell = cell_states @@ -301,7 +300,7 @@ class Decoder(nn.Module): # Compute first Residual RNN rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) if self.training: - rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next) + rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device) else: rnn1_hidden = rnn1_hidden_next x = x + rnn1_hidden @@ -309,7 +308,7 @@ class Decoder(nn.Module): # Compute second Residual RNN rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) if self.training: - rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next) + rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device) else: rnn2_hidden = rnn2_hidden_next x = x + rnn2_hidden @@ -374,7 +373,7 @@ class Tacotron(nn.Module): return outputs def forward(self, texts, mels, speaker_embedding): - device = next(self.parameters()).device # use same device as parameters + device = texts.device # use same device as parameters self.step += 1 batch_size, _, steps = mels.size() @@ -440,7 +439,7 @@ class Tacotron(nn.Module): def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5): self.eval() - device = next(self.parameters()).device # use same device as parameters + device = x.device # use same device as parameters batch_size, _ = x.size() @@ -542,8 +541,7 @@ class Tacotron(nn.Module): def load(self, path, optimizer=None): # Use device of model params as location for loaded state - device = next(self.parameters()).device - checkpoint = torch.load(str(path), map_location=device) + checkpoint = torch.load(str(path)) self.load_state_dict(checkpoint["model_state"], strict=False) if "optimizer_state" in checkpoint and optimizer is not None: