From 31bc6656c3f9152a11f27c155111335ed761c56a Mon Sep 17 00:00:00 2001
From: babysor00
Date: Thu, 21 Oct 2021 00:40:00 +0800
Subject: [PATCH] Fix bug of importing GST and add more parameters in toolbox

---
 synthesizer/inference.py       |  4 ++--
 synthesizer/models/tacotron.py | 10 +++++-----
 toolbox/__init__.py            |  3 ++-
 toolbox/ui.py                  | 36 +++++++++++++++++++++++++---------
 4 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/synthesizer/inference.py b/synthesizer/inference.py
index 2a62754..f89dbac 100644
--- a/synthesizer/inference.py
+++ b/synthesizer/inference.py
@@ -70,7 +70,7 @@ class Synthesizer:
 
     def synthesize_spectrograms(self, texts: List[str],
                                 embeddings: Union[np.ndarray, List[np.ndarray]],
-                                return_alignments=False, style_idx=0):
+                                return_alignments=False, style_idx=0, min_stop_token=5):
         """
         Synthesizes mel spectrograms from texts and speaker embeddings.
 
@@ -125,7 +125,7 @@ class Synthesizer:
         speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
 
         # Inference
-        _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx)
+        _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx, min_stop_token=min_stop_token)
         mels = mels.detach().cpu().numpy()
         for m in mels:
             # Trim silence from end of each spectrogram
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 818a17f..dd6e241 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -419,7 +419,7 @@ class Tacotron(nn.Module):
 
         return mel_outputs, linear, attn_scores, stop_outputs
 
-    def generate(self, x, speaker_embedding=None, steps=200, style_idx=0):
+    def generate(self, x, speaker_embedding=None, steps=200, style_idx=0, min_stop_token=5):
         self.eval()
         device = next(self.parameters()).device  # use same device as parameters
 
@@ -454,9 +454,9 @@ class Tacotron(nn.Module):
                 scale[:] = 0.3
                 speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
                 speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
-                style_embed = self.gst(speaker_embedding)
-                style_embed = style_embed.expand_as(encoder_seq)
-                encoder_seq = encoder_seq + style_embed
+            style_embed = self.gst(speaker_embedding)
+            style_embed = style_embed.expand_as(encoder_seq)
+            encoder_seq = encoder_seq + style_embed
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
@@ -472,7 +472,7 @@ class Tacotron(nn.Module):
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
             # Stop the loop when all stop tokens in batch exceed threshold
-            if (stop_tokens > 0.5).all() and t > 10: break
+            if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
 
         # Concat the mel outputs into sequence
         mel_outputs = torch.cat(mel_outputs, dim=2)
diff --git a/toolbox/__init__.py b/toolbox/__init__.py
index d162a78..4517270 100644
--- a/toolbox/__init__.py
+++ b/toolbox/__init__.py
@@ -234,7 +234,8 @@ class Toolbox:
             texts = processed_texts
         embed = self.ui.selected_utterance.embed
         embeds = [embed] * len(texts)
-        specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.slider.value()))
+        min_token = int(self.ui.token_slider.value())
+        specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_slider.value()), min_stop_token=min_token)
         breaks = [spec.shape[1] for spec in specs]
         spec = np.concatenate(specs, axis=1)
 
diff --git a/toolbox/ui.py b/toolbox/ui.py
index 8f0013c..ae5d2bc 100644
--- a/toolbox/ui.py
+++ b/toolbox/ui.py
@@ -588,18 +588,36 @@ class UI(QDialog):
         self.seed_textbox = QLineEdit()
         self.seed_textbox.setMaximumWidth(80)
         layout_seed.addWidget(self.seed_textbox, 0, 1)
-        self.slider = QSlider(Qt.Horizontal)
-        self.slider.setTickInterval(1)
-        self.slider.setFocusPolicy(Qt.NoFocus)
-        self.slider.setSingleStep(1)
-        self.slider.setRange(-1, 9)
-        self.slider.setValue(-1)
-        layout_seed.addWidget(QLabel("Style:"), 0, 2)
-        layout_seed.addWidget(self.slider, 0, 3)
         self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
         self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
             " This feature requires `webrtcvad` to be installed.")
-        layout_seed.addWidget(self.trim_silences_checkbox, 0, 4, 1, 2)
+        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
+        self.style_slider = QSlider(Qt.Horizontal)
+        self.style_slider.setTickInterval(1)
+        self.style_slider.setFocusPolicy(Qt.NoFocus)
+        self.style_slider.setSingleStep(1)
+        self.style_slider.setRange(-1, 9)
+        self.style_value_label = QLabel("-1")
+        self.style_slider.setValue(-1)
+        layout_seed.addWidget(QLabel("Style:"), 1, 0)
+
+        self.style_slider.valueChanged.connect(lambda s: self.style_value_label.setNum(s))
+        layout_seed.addWidget(self.style_value_label, 1, 1)
+        layout_seed.addWidget(self.style_slider, 1, 3)
+
+        self.token_slider = QSlider(Qt.Horizontal)
+        self.token_slider.setTickInterval(1)
+        self.token_slider.setFocusPolicy(Qt.NoFocus)
+        self.token_slider.setSingleStep(1)
+        self.token_slider.setRange(3, 9)
+        self.token_value_label = QLabel("5")
+        self.token_slider.setValue(4)
+        layout_seed.addWidget(QLabel("Accuracy(精度):"), 2, 0)
+
+        self.token_slider.valueChanged.connect(lambda s: self.token_value_label.setNum(s))
+        layout_seed.addWidget(self.token_value_label, 2, 1)
+        layout_seed.addWidget(self.token_slider, 2, 3)
+
         gen_layout.addLayout(layout_seed)
 
         self.loading_bar = QProgressBar()
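Reviewer note on the new stop-token threshold: `stop_tokens` are the decoder's sigmoid outputs in [0, 1], and the patched exit check `(stop_tokens * 10 > min_stop_token).all()` generalizes the old hard-coded `> 0.5` comparison, so an "Accuracy" slider value v (range 3..9) maps to an effective stop probability of v / 10. Below is a minimal, self-contained sketch of that check; the helper name `should_stop` and the example tensors are illustrative, not part of the patch.

import torch

def should_stop(stop_tokens: torch.Tensor, t: int, min_stop_token: int = 5) -> bool:
    # Mirrors the patched loop exit in Tacotron.generate: stop once every
    # stop token in the batch exceeds min_stop_token / 10, but never before
    # decode step 10 (the t > 10 warm-up is unchanged from the old code).
    return bool((stop_tokens * 10 > min_stop_token).all()) and t > 10

# Hypothetical stop probabilities for a batch of two utterances:
probs = torch.tensor([0.62, 0.71])
print(should_stop(probs, t=42, min_stop_token=5))  # True: both exceed 0.5
print(should_stop(probs, t=42, min_stop_token=8))  # False: neither exceeds 0.8

Higher slider values therefore make the decoder keep generating longer before it decides the utterance has ended. One small observation on the UI wiring: `token_slider` is initialized with `setValue(4)` while its label is created as `QLabel("5")`, and the `valueChanged` connection is made after `setValue`, so the displayed number and the actual threshold disagree until the slider is first moved.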