From 43c86eb411f81d005726beb86edc5cb01d7ab05c Mon Sep 17 00:00:00 2001
From: babysor00
Date: Tue, 12 Oct 2021 09:12:58 +0800
Subject: [PATCH] Make it backward compatible

---
 .vscode/launch.json                        |  6 ++---
 synthesizer/inference.py                   |  4 ++--
 .../{ => models}/global_style_token.py     |  2 +-
 synthesizer/models/tacotron.py             | 23 ++++++++++++-------
 toolbox/__init__.py                        |  3 ++-
 toolbox/ui.py                              |  6 ++++-
 6 files changed, 28 insertions(+), 16 deletions(-)
 rename synthesizer/{ => models}/global_style_token.py (99%)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 0f6b728..3b264f6 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -17,7 +17,7 @@
             "request": "launch",
             "program": "vocoder_preprocess.py",
             "console": "integratedTerminal",
-            "args": ["..\\..\\chs1"]
+            "args": ["..\\audiodata"]
         },
         {
             "name": "Python: Vocoder Train",
@@ -25,7 +25,7 @@
             "request": "launch",
             "program": "vocoder_train.py",
             "console": "integratedTerminal",
-            "args": ["dev", "..\\..\\chs1"]
+            "args": ["dev", "..\\audiodata"]
         },
         {
             "name": "Python: Demo Box",
@@ -33,7 +33,7 @@
             "request": "launch",
             "program": "demo_toolbox.py",
             "console": "integratedTerminal",
-            "args": ["-d","..\\..\\chs"]
+            "args": ["-d","..\\audiodata"]
         },
         {
             "name": "Python: Synth Train",
diff --git a/synthesizer/inference.py b/synthesizer/inference.py
index 987a70d..2a62754 100644
--- a/synthesizer/inference.py
+++ b/synthesizer/inference.py
@@ -70,7 +70,7 @@ class Synthesizer:
 
     def synthesize_spectrograms(self, texts: List[str],
                                 embeddings: Union[np.ndarray, List[np.ndarray]],
-                                return_alignments=False):
+                                return_alignments=False, style_idx=0):
         """
         Synthesizes mel spectrograms from texts and speaker embeddings.
 
@@ -125,7 +125,7 @@ class Synthesizer:
             speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
 
             # Inference
-            _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+            _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx)
             mels = mels.detach().cpu().numpy()
             for m in mels:
                 # Trim silence from end of each spectrogram
diff --git a/synthesizer/global_style_token.py b/synthesizer/models/global_style_token.py
similarity index 99%
rename from synthesizer/global_style_token.py
rename to synthesizer/models/global_style_token.py
index a884867..79282c2 100644
--- a/synthesizer/global_style_token.py
+++ b/synthesizer/models/global_style_token.py
@@ -90,7 +90,7 @@ class STL(nn.Module):
 
         keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
         style_embed = self.attention(query, keys)
-        return style_embed, keys
+        return style_embed
 
 
 class MultiHeadAttention(nn.Module):
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 44407e3..818a17f 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -3,7 +3,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from synthesizer.global_style_token import GlobalStyleToken
+from synthesizer.models.global_style_token import GlobalStyleToken
 
 
 class HighwayNetwork(nn.Module):
@@ -385,10 +385,10 @@
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts, speaker_embedding)
         # put after encoder
-        style_embed, _ = self.gst(speaker_embedding) # [N, 256]
-        style_embed = style_embed.expand_as(encoder_seq)
-        encoder_seq = encoder_seq + style_embed
-
+        if self.gst is not None:
+            style_embed = self.gst(speaker_embedding)
+            style_embed = style_embed.expand_as(encoder_seq)
+            encoder_seq = encoder_seq + style_embed
         encoder_seq_proj = self.encoder_proj(encoder_seq)
 
         # Need a couple of lists for outputs
@@ -419,7 +419,7 @@
 
         return mel_outputs, linear, attn_scores, stop_outputs
 
-    def generate(self, x, speaker_embedding=None, steps=200):
+    def generate(self, x, speaker_embedding=None, steps=200, style_idx=0):
         self.eval()
         device = next(self.parameters()).device  # use same device as parameters
 
@@ -447,7 +447,14 @@
 
         encoder_seq = self.encoder(x, speaker_embedding)
         # put after encoder
-        style_embed = self.gst(speaker_embedding) # [N, 256]
+        if self.gst is not None and style_idx >= 0 and style_idx < 10:
+            gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
+            gst_embed = np.tile(gst_embed, (1, 8))
+            scale = np.zeros(512)
+            scale[:] = 0.3
+            speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
+            speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
+        style_embed = self.gst(speaker_embedding)
         style_embed = style_embed.expand_as(encoder_seq)
         encoder_seq = encoder_seq + style_embed
         encoder_seq_proj = self.encoder_proj(encoder_seq)
@@ -504,7 +511,7 @@
         # Use device of model params as location for loaded state
         device = next(self.parameters()).device
         checkpoint = torch.load(str(path), map_location=device)
-        self.load_state_dict(checkpoint["model_state"])
+        self.load_state_dict(checkpoint["model_state"], strict=False)
         if "optimizer_state" in checkpoint and optimizer is not None:
             optimizer.load_state_dict(checkpoint["optimizer_state"])
 
diff --git a/toolbox/__init__.py b/toolbox/__init__.py
index 08994b6..7d67b52 100644
--- a/toolbox/__init__.py
+++ b/toolbox/__init__.py
@@ -71,6 +71,7 @@ class Toolbox:
 
         # Initialize the events and the interface
         self.ui = UI()
+        self.style_idx = 0
         self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
         self.setup_events()
         self.ui.start()
@@ -233,7 +234,7 @@ class Toolbox:
             texts = processed_texts
         embed = self.ui.selected_utterance.embed
         embeds = [embed] * len(texts)
-        specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
+        specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_idx_textbox.text()))
         breaks = [spec.shape[1] for spec in specs]
         spec = np.concatenate(specs, axis=1)
 
diff --git a/toolbox/ui.py b/toolbox/ui.py
index 6ae6a7e..0dfded3 100644
--- a/toolbox/ui.py
+++ b/toolbox/ui.py
@@ -574,10 +574,14 @@ class UI(QDialog):
         self.seed_textbox = QLineEdit()
         self.seed_textbox.setMaximumWidth(80)
         layout_seed.addWidget(self.seed_textbox, 0, 1)
+        layout_seed.addWidget(QLabel("Style#:(0~9)"), 0, 2)
+        self.style_idx_textbox = QLineEdit("-1")
+        self.style_idx_textbox.setMaximumWidth(80)
+        layout_seed.addWidget(self.style_idx_textbox, 0, 3)
         self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
         self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
                                                " This feature requires `webrtcvad` to be installed.")
-        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
+        layout_seed.addWidget(self.trim_silences_checkbox, 0, 4, 1, 2)
         gen_layout.addLayout(layout_seed)
 
         self.loading_bar = QProgressBar()
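
Note on the new API surface: after this patch, style_idx flows from the toolbox textbox
through Synthesizer.synthesize_spectrograms into Tacotron.generate. A minimal usage
sketch follows; the checkpoint path, input text, and the 256-dim embedding size are
placeholder assumptions, not part of the patch:

    from pathlib import Path
    import numpy as np
    from synthesizer.inference import Synthesizer

    # Checkpoint path is a placeholder assumption.
    synthesizer = Synthesizer(Path("synthesizer/saved_models/pretrained.pt"))

    texts = ["hello world"]
    # Dummy speaker embedding; a real one comes from the speaker encoder.
    # 256 dims is assumed from the repo's default hparams.
    embeds = [np.zeros(256, dtype=np.float32)]

    # style_idx in 0..9 selects one of the learned GST tokens as the style
    # source; the toolbox default of -1 (or any out-of-range value) keeps
    # the old behavior of conditioning on the speaker embedding alone.
    specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=3)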
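
The token-override block added to generate() is easier to review with the shapes
spelled out. Below is a standalone sketch of the same arithmetic, assuming the
default GST configuration of 10 style tokens and 8 attention heads (64 dims per
head, 512 after tiling); variable names mirror the patch but the sketch is
illustrative only:

    import numpy as np

    token_num, head_dim, num_heads = 10, 64, 8  # assumed GST defaults

    # self.gst.stl.embed stores one learned [head_dim] vector per style token.
    gst_embed = np.random.randn(token_num, head_dim).astype(np.float32)

    # Tile across heads to reach full embedding width: (10, 64) -> (10, 512).
    gst_embed = np.tile(gst_embed, (1, num_heads))

    # Damp the chosen token by 0.3 so it sits in a range comparable to a real
    # speaker embedding, then use it as a pseudo speaker embedding.
    scale = np.full(512, 0.3, dtype=np.float32)
    pseudo_speaker_embedding = (gst_embed[3] * scale).astype(np.float32)  # (512,)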
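
The strict=False change in load() is what the subject line refers to: checkpoints
saved before the GST module existed carry no gst.* weights, and strict loading would
raise a missing-key error. A self-contained illustration of that loading behavior,
with stand-in module names rather than the real Tacotron layout:

    import torch.nn as nn

    class OldModel(nn.Module):            # stands in for a pre-GST checkpoint
        def __init__(self):
            super().__init__()
            self.core = nn.Linear(4, 4)

    class NewModel(nn.Module):            # same model after the GST addition
        def __init__(self):
            super().__init__()
            self.core = nn.Linear(4, 4)
            self.gst = nn.Linear(4, 4)    # newly introduced submodule

    state = OldModel().state_dict()       # contains no "gst.*" entries
    model = NewModel()

    # strict=True would raise: Missing key(s) in state_dict: "gst.weight", ...
    # strict=False loads what matches and leaves the new module at its init:
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(missing)     # ['gst.weight', 'gst.bias']
    print(unexpected)  # []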