Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)

tacotron.py-Multi GPU with DataParallel (#231)
Commit: b50c7984ab
Parent: 26fe4a047d
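The patch below prepares the synthesizer for multi-GPU training with torch.nn.DataParallel: every place that previously derived a device from next(self.parameters()) now takes it from an input tensor, the CBHG module calls flatten_parameters() directly on its wrapped GRU, and zoneout receives the device explicitly. The training-script side is not part of this commit; the following is only a minimal sketch of how DataParallel wrapping typically looks (TinyNet is a stand-in, not the repository's Tacotron class).

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Stand-in for the synthesizer; not the repository's Tacotron class."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        # Under DataParallel each replica receives a slice of the batch that
        # already lives on that replica's GPU, so x.device is always correct.
        return self.proj(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyNet()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # splits the batch along dim 0 across GPUs
model = model.to(device)

out = model(torch.randn(16, 8, device=device))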
@@ -127,7 +127,7 @@ class CBHG(nn.Module):
         # Although we `_flatten_parameters()` on init, when using DataParallel
         # the model gets replicated, making it no longer guaranteed that the
         # weights are contiguous in GPU memory. Hence, we must call it again
-        self._flatten_parameters()
+        self.rnn.flatten_parameters()
 
         # Save these for later
         residual = x
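The first hunk replaces the module-level _flatten_parameters() helper with a direct call to flatten_parameters() on the GRU stored in self.rnn. flatten_parameters() is a standard method of PyTorch's RNN modules; a minimal sketch of the pattern outside the repository (BiGRUBlock is illustrative, only the self.rnn attribute name mirrors the diff):

import torch
import torch.nn as nn

class BiGRUBlock(nn.Module):
    """Illustrative module, not the repository's CBHG."""
    def __init__(self, dims):
        super().__init__()
        self.rnn = nn.GRU(dims, dims, batch_first=True, bidirectional=True)

    def forward(self, x):
        # DataParallel re-replicates the module on every forward pass, so the
        # RNN weights may no longer be contiguous in GPU memory; calling
        # flatten_parameters() here restores the fast cuDNN path.
        self.rnn.flatten_parameters()
        out, _ = self.rnn(x)
        return out

block = BiGRUBlock(64)
y = block(torch.randn(2, 10, 64))  # (batch, time, dims) -> (batch, time, 2 * dims)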
@@ -214,7 +214,7 @@ class LSA(nn.Module):
         self.attention = None
 
     def init_attention(self, encoder_seq_proj):
-        device = next(self.parameters()).device # use same device as parameters
+        device = encoder_seq_proj.device # use same device as parameters
         b, t, c = encoder_seq_proj.size()
         self.cumulative = torch.zeros(b, t, device=device)
         self.attention = torch.zeros(b, t, device=device)
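In init_attention the device is now read from the incoming encoder projection rather than from the parameter iterator, so the zero buffers are created wherever DataParallel placed that replica's inputs. A self-contained sketch of the same device-selection change (AttentionState is illustrative, not the repository's LSA class):

import torch
import torch.nn as nn

class AttentionState(nn.Module):
    """Illustrative module mirroring only the device-selection change, not the full LSA."""
    def __init__(self, attn_dims=128):
        super().__init__()
        self.W = nn.Linear(attn_dims, attn_dims)

    def init_attention(self, encoder_seq_proj):
        # Previously: device = next(self.parameters()).device
        # Now the device comes straight from the input tensor, so the zero
        # buffers land on the same device as the data this replica was given.
        device = encoder_seq_proj.device
        b, t, _ = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

state = AttentionState()
state.init_attention(torch.randn(2, 10, 128))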
@@ -265,9 +265,8 @@ class Decoder(nn.Module):
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
         self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
 
-    def zoneout(self, prev, current, p=0.1):
-        device = next(self.parameters()).device # Use same device as parameters
-        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
+    def zoneout(self, prev, current, device, p=0.1):
+        mask = torch.zeros(prev.size(),device=device).bernoulli_(p)
         return prev * mask + current * (1 - mask)
 
     def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
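zoneout now receives the device as an argument instead of calling next(self.parameters()).device on every decoder step. A standalone sketch of the same regularizer as a plain function, under the assumption that the caller supplies the device of its hidden state:

import torch

def zoneout(prev, current, device, p=0.1):
    # Keep each hidden unit at its previous value with probability p and take
    # the new value otherwise; the Bernoulli mask is built directly on the
    # caller-supplied device, so no parameter lookup is needed.
    mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
    return prev * mask + current * (1 - mask)

prev, current = torch.randn(4, 256), torch.randn(4, 256)
mixed = zoneout(prev, current, device=prev.device)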
@@ -275,7 +274,7 @@ class Decoder(nn.Module):
 
         # Need this for reshaping mels
         batch_size = encoder_seq.size(0)
-
+        device = encoder_seq.device
         # Unpack the hidden and cell states
         attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
         rnn1_cell, rnn2_cell = cell_states
@@ -301,7 +300,7 @@ class Decoder(nn.Module):
         # Compute first Residual RNN
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
-            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
+            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next,device=device)
         else:
             rnn1_hidden = rnn1_hidden_next
         x = x + rnn1_hidden
@@ -309,7 +308,7 @@ class Decoder(nn.Module):
         # Compute second Residual RNN
         rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
         if self.training:
-            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
+            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
         else:
             rnn2_hidden = rnn2_hidden_next
         x = x + rnn2_hidden
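The two decoder hunks above thread the device captured at the top of Decoder.forward into both zoneout call sites. A compressed sketch of that pattern with a simplified residual LSTM step (ResidualStep is illustrative, not the repository's decoder):

import torch
import torch.nn as nn

class ResidualStep(nn.Module):
    """Simplified stand-in for the decoder's residual LSTM step."""
    def __init__(self, dims=128):
        super().__init__()
        self.cell = nn.LSTMCell(dims, dims)

    def zoneout(self, prev, current, device, p=0.1):
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)

    def forward(self, x, hidden, cell):
        device = x.device  # captured once from the input, as in the diff
        hidden_next, cell = self.cell(x, (hidden, cell))
        if self.training:
            hidden = self.zoneout(hidden, hidden_next, device=device)
        else:
            hidden = hidden_next
        return x + hidden, hidden, cell

step = ResidualStep()
x, h, c = torch.randn(2, 128), torch.zeros(2, 128), torch.zeros(2, 128)
y, h, c = step(x, h, c)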
@@ -374,7 +373,7 @@ class Tacotron(nn.Module):
         return outputs
 
     def forward(self, texts, mels, speaker_embedding):
-        device = next(self.parameters()).device # use same device as parameters
+        device = texts.device # use same device as parameters
 
         self.step += 1
         batch_size, _, steps = mels.size()
@@ -440,7 +439,7 @@ class Tacotron(nn.Module):
 
     def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = next(self.parameters()).device # use same device as parameters
+        device = x.device # use same device as parameters
 
         batch_size, _ = x.size()
 
@@ -542,8 +541,7 @@ class Tacotron(nn.Module):
 
     def load(self, path, optimizer=None):
         # Use device of model params as location for loaded state
-        device = next(self.parameters()).device
-        checkpoint = torch.load(str(path), map_location=device)
+        checkpoint = torch.load(str(path))
         self.load_state_dict(checkpoint["model_state"], strict=False)
 
         if "optimizer_state" in checkpoint and optimizer is not None:
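The load hunk drops both the parameter-device lookup and the explicit map_location, so torch.load restores each tensor to the device recorded in the checkpoint. If placement still needs to be forced at load time, map_location can be passed explicitly; a runnable sketch with a throwaway checkpoint (the file name and tiny model are illustrative, not from the repository):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # throwaway model, only to produce a checkpoint
torch.save({"model_state": model.state_dict()}, "checkpoint.pt")  # illustrative file name

# Without map_location, torch.load restores each tensor to the device it was
# saved from; passing map_location redirects placement explicitly instead.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state"], strict=False)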