mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Fix bug of importing GST and add more parameters in toolbox
This commit is contained in:
parent
aa35fb3139
commit
31bc6656c3
|
@ -70,7 +70,7 @@ class Synthesizer:
|
||||||
|
|
||||||
def synthesize_spectrograms(self, texts: List[str],
|
def synthesize_spectrograms(self, texts: List[str],
|
||||||
embeddings: Union[np.ndarray, List[np.ndarray]],
|
embeddings: Union[np.ndarray, List[np.ndarray]],
|
||||||
return_alignments=False, style_idx=0):
|
return_alignments=False, style_idx=0, min_stop_token=5):
|
||||||
"""
|
"""
|
||||||
Synthesizes mel spectrograms from texts and speaker embeddings.
|
Synthesizes mel spectrograms from texts and speaker embeddings.
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ class Synthesizer:
|
||||||
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
_, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx)
|
_, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token)
|
||||||
mels = mels.detach().cpu().numpy()
|
mels = mels.detach().cpu().numpy()
|
||||||
for m in mels:
|
for m in mels:
|
||||||
# Trim silence from end of each spectrogram
|
# Trim silence from end of each spectrogram
|
||||||
|
|
|
@ -419,7 +419,7 @@ class Tacotron(nn.Module):
|
||||||
|
|
||||||
return mel_outputs, linear, attn_scores, stop_outputs
|
return mel_outputs, linear, attn_scores, stop_outputs
|
||||||
|
|
||||||
def generate(self, x, speaker_embedding=None, steps=200, style_idx=0):
|
def generate(self, x, speaker_embedding=None, steps=200, style_idx=0, min_stop_token=5):
|
||||||
self.eval()
|
self.eval()
|
||||||
device = next(self.parameters()).device # use same device as parameters
|
device = next(self.parameters()).device # use same device as parameters
|
||||||
|
|
||||||
|
@ -454,9 +454,9 @@ class Tacotron(nn.Module):
|
||||||
scale[:] = 0.3
|
scale[:] = 0.3
|
||||||
speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
|
speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
|
||||||
speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
|
speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
|
||||||
style_embed = self.gst(speaker_embedding)
|
style_embed = self.gst(speaker_embedding)
|
||||||
style_embed = style_embed.expand_as(encoder_seq)
|
style_embed = style_embed.expand_as(encoder_seq)
|
||||||
encoder_seq = encoder_seq + style_embed
|
encoder_seq = encoder_seq + style_embed
|
||||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||||
|
|
||||||
# Need a couple of lists for outputs
|
# Need a couple of lists for outputs
|
||||||
|
@ -472,7 +472,7 @@ class Tacotron(nn.Module):
|
||||||
attn_scores.append(scores)
|
attn_scores.append(scores)
|
||||||
stop_outputs.extend([stop_tokens] * self.r)
|
stop_outputs.extend([stop_tokens] * self.r)
|
||||||
# Stop the loop when all stop tokens in batch exceed threshold
|
# Stop the loop when all stop tokens in batch exceed threshold
|
||||||
if (stop_tokens > 0.5).all() and t > 10: break
|
if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
|
||||||
|
|
||||||
# Concat the mel outputs into sequence
|
# Concat the mel outputs into sequence
|
||||||
mel_outputs = torch.cat(mel_outputs, dim=2)
|
mel_outputs = torch.cat(mel_outputs, dim=2)
|
||||||
|
|
|
@ -234,7 +234,8 @@ class Toolbox:
|
||||||
texts = processed_texts
|
texts = processed_texts
|
||||||
embed = self.ui.selected_utterance.embed
|
embed = self.ui.selected_utterance.embed
|
||||||
embeds = [embed] * len(texts)
|
embeds = [embed] * len(texts)
|
||||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.slider.value()))
|
min_token = int(self.ui.token_slider.value())
|
||||||
|
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token)
|
||||||
breaks = [spec.shape[1] for spec in specs]
|
breaks = [spec.shape[1] for spec in specs]
|
||||||
spec = np.concatenate(specs, axis=1)
|
spec = np.concatenate(specs, axis=1)
|
||||||
|
|
||||||
|
|
|
@ -588,18 +588,36 @@ class UI(QDialog):
|
||||||
self.seed_textbox = QLineEdit()
|
self.seed_textbox = QLineEdit()
|
||||||
self.seed_textbox.setMaximumWidth(80)
|
self.seed_textbox.setMaximumWidth(80)
|
||||||
layout_seed.addWidget(self.seed_textbox, 0, 1)
|
layout_seed.addWidget(self.seed_textbox, 0, 1)
|
||||||
self.slider = QSlider(Qt.Horizontal)
|
|
||||||
self.slider.setTickInterval(1)
|
|
||||||
self.slider.setFocusPolicy(Qt.NoFocus)
|
|
||||||
self.slider.setSingleStep(1)
|
|
||||||
self.slider.setRange(-1, 9)
|
|
||||||
self.slider.setValue(-1)
|
|
||||||
layout_seed.addWidget(QLabel("Style:"), 0, 2)
|
|
||||||
layout_seed.addWidget(self.slider, 0, 3)
|
|
||||||
self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
|
self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
|
||||||
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
|
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
|
||||||
" This feature requires `webrtcvad` to be installed.")
|
" This feature requires `webrtcvad` to be installed.")
|
||||||
layout_seed.addWidget(self.trim_silences_checkbox, 0, 4, 1, 2)
|
layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
|
||||||
|
self.style_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.style_slider.setTickInterval(1)
|
||||||
|
self.style_slider.setFocusPolicy(Qt.NoFocus)
|
||||||
|
self.style_slider.setSingleStep(1)
|
||||||
|
self.style_slider.setRange(-1, 9)
|
||||||
|
self.style_value_label = QLabel("-1")
|
||||||
|
self.style_slider.setValue(-1)
|
||||||
|
layout_seed.addWidget(QLabel("Style:"), 1, 0)
|
||||||
|
|
||||||
|
self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
|
||||||
|
layout_seed.addWidget(self.style_value_label, 1, 1)
|
||||||
|
layout_seed.addWidget(self.style_slider, 1, 3)
|
||||||
|
|
||||||
|
self.token_slider = QSlider(Qt.Horizontal)
|
||||||
|
self.token_slider.setTickInterval(1)
|
||||||
|
self.token_slider.setFocusPolicy(Qt.NoFocus)
|
||||||
|
self.token_slider.setSingleStep(1)
|
||||||
|
self.token_slider.setRange(3, 9)
|
||||||
|
self.token_value_label = QLabel("5")
|
||||||
|
self.token_slider.setValue(4)
|
||||||
|
layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)
|
||||||
|
|
||||||
|
self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
|
||||||
|
layout_seed.addWidget(self.token_value_label, 2, 1)
|
||||||
|
layout_seed.addWidget(self.token_slider, 2, 3)
|
||||||
|
|
||||||
gen_layout.addLayout(layout_seed)
|
gen_layout.addLayout(layout_seed)
|
||||||
|
|
||||||
self.loading_bar = QProgressBar()
|
self.loading_bar = QProgressBar()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user