mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Fix inference on cpu device (#241)
This commit is contained in:
parent
a4daf42868
commit
4728863f9d
@ -62,7 +62,7 @@ class Synthesizer:
|
|||||||
stop_threshold=hparams.tts_stop_threshold,
|
stop_threshold=hparams.tts_stop_threshold,
|
||||||
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
|
speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
|
||||||
|
|
||||||
self._model.load(self.model_fpath)
|
self._model.load(self.model_fpath, self.device)
|
||||||
self._model.eval()
|
self._model.eval()
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -470,7 +470,9 @@ class Tacotron(nn.Module):
|
|||||||
# put after encoder
|
# put after encoder
|
||||||
if hparams.use_gst and self.gst is not None:
|
if hparams.use_gst and self.gst is not None:
|
||||||
if style_idx >= 0 and style_idx < 10:
|
if style_idx >= 0 and style_idx < 10:
|
||||||
query = torch.zeros(1, 1, self.gst.stl.attention.num_units).cuda()
|
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
||||||
|
if device.type == 'cuda':
|
||||||
|
query = query.cuda()
|
||||||
gst_embed = torch.tanh(self.gst.stl.embed)
|
gst_embed = torch.tanh(self.gst.stl.embed)
|
||||||
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
||||||
style_embed = self.gst.stl.attention(query, key)
|
style_embed = self.gst.stl.attention(query, key)
|
||||||
@ -539,9 +541,9 @@ class Tacotron(nn.Module):
|
|||||||
with open(path, "a") as f:
|
with open(path, "a") as f:
|
||||||
print(msg, file=f)
|
print(msg, file=f)
|
||||||
|
|
||||||
def load(self, path, optimizer=None):
|
def load(self, path, device, optimizer=None):
|
||||||
# Use device of model params as location for loaded state
|
# Use device of model params as location for loaded state
|
||||||
checkpoint = torch.load(str(path))
|
checkpoint = torch.load(str(path), map_location=device)
|
||||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||||
|
|
||||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||||
|
@ -45,7 +45,7 @@ def run_synthesis(in_dir, out_dir, model_dir, hparams):
|
|||||||
model_dir = Path(model_dir)
|
model_dir = Path(model_dir)
|
||||||
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
|
model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
|
||||||
print("\nLoading weights at %s" % model_fpath)
|
print("\nLoading weights at %s" % model_fpath)
|
||||||
model.load(model_fpath)
|
model.load(model_fpath, device)
|
||||||
print("Tacotron weights loaded from step %d" % model.step)
|
print("Tacotron weights loaded from step %d" % model.step)
|
||||||
|
|
||||||
# Synthesize using same reduction factor as the model is currently trained
|
# Synthesize using same reduction factor as the model is currently trained
|
||||||
|
@ -111,7 +111,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
print("\nLoading weights at %s" % weights_fpath)
|
print("\nLoading weights at %s" % weights_fpath)
|
||||||
model.load(weights_fpath, optimizer)
|
model.load(weights_fpath, device, optimizer)
|
||||||
print("Tacotron weights loaded from step %d" % model.step)
|
print("Tacotron weights loaded from step %d" % model.step)
|
||||||
|
|
||||||
# Initialize the dataset
|
# Initialize the dataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user