mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Make it backward compatible
This commit is contained in:
parent
37f11ab9ce
commit
43c86eb411
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
|
@ -17,7 +17,7 @@
|
|||
"request": "launch",
|
||||
"program": "vocoder_preprocess.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["..\\..\\chs1"]
|
||||
"args": ["..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Vocoder Train",
|
||||
|
@ -25,7 +25,7 @@
|
|||
"request": "launch",
|
||||
"program": "vocoder_train.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["dev", "..\\..\\chs1"]
|
||||
"args": ["dev", "..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Demo Box",
|
||||
|
@ -33,7 +33,7 @@
|
|||
"request": "launch",
|
||||
"program": "demo_toolbox.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-d","..\\..\\chs"]
|
||||
"args": ["-d","..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Synth Train",
|
||||
|
|
|
@ -70,7 +70,7 @@ class Synthesizer:
|
|||
|
||||
def synthesize_spectrograms(self, texts: List[str],
|
||||
embeddings: Union[np.ndarray, List[np.ndarray]],
|
||||
return_alignments=False):
|
||||
return_alignments=False, style_idx=0):
|
||||
"""
|
||||
Synthesizes mel spectrograms from texts and speaker embeddings.
|
||||
|
||||
|
@ -125,7 +125,7 @@ class Synthesizer:
|
|||
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
||||
|
||||
# Inference
|
||||
_, mels, alignments = self._model.generate(chars, speaker_embeddings)
|
||||
_, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx)
|
||||
mels = mels.detach().cpu().numpy()
|
||||
for m in mels:
|
||||
# Trim silence from end of each spectrogram
|
||||
|
|
|
@ -90,7 +90,7 @@ class STL(nn.Module):
|
|||
keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads]
|
||||
style_embed = self.attention(query, keys)
|
||||
|
||||
return style_embed, keys
|
||||
return style_embed
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from synthesizer.global_style_token import GlobalStyleToken
|
||||
from synthesizer.models.global_style_token import GlobalStyleToken
|
||||
|
||||
|
||||
class HighwayNetwork(nn.Module):
|
||||
|
@ -385,10 +385,10 @@ class Tacotron(nn.Module):
|
|||
# The projection avoids unnecessary matmuls in the decoder loop
|
||||
encoder_seq = self.encoder(texts, speaker_embedding)
|
||||
# put after encoder
|
||||
style_embed, _ = self.gst(speaker_embedding) # [N, 256]
|
||||
style_embed = style_embed.expand_as(encoder_seq)
|
||||
encoder_seq = encoder_seq + style_embed
|
||||
|
||||
if self.gst is not None:
|
||||
style_embed = self.gst(speaker_embedding)
|
||||
style_embed = style_embed.expand_as(encoder_seq)
|
||||
encoder_seq = encoder_seq + style_embed
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
||||
# Need a couple of lists for outputs
|
||||
|
@ -419,7 +419,7 @@ class Tacotron(nn.Module):
|
|||
|
||||
return mel_outputs, linear, attn_scores, stop_outputs
|
||||
|
||||
def generate(self, x, speaker_embedding=None, steps=200):
|
||||
def generate(self, x, speaker_embedding=None, steps=200, style_idx=0):
|
||||
self.eval()
|
||||
device = next(self.parameters()).device # use same device as parameters
|
||||
|
||||
|
@ -447,7 +447,14 @@ class Tacotron(nn.Module):
|
|||
encoder_seq = self.encoder(x, speaker_embedding)
|
||||
|
||||
# put after encoder
|
||||
style_embed = self.gst(speaker_embedding) # [N, 256]
|
||||
if self.gst is not None and style_idx >= 0 and style_idx < 10:
|
||||
gst_embed = self.gst.stl.embed.cpu().data.numpy() #[0, number_token]
|
||||
gst_embed = np.tile(gst_embed, (1, 8))
|
||||
scale = np.zeros(512)
|
||||
scale[:] = 0.3
|
||||
speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
|
||||
speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
|
||||
style_embed = self.gst(speaker_embedding)
|
||||
style_embed = style_embed.expand_as(encoder_seq)
|
||||
encoder_seq = encoder_seq + style_embed
|
||||
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
||||
|
@ -504,7 +511,7 @@ class Tacotron(nn.Module):
|
|||
# Use device of model params as location for loaded state
|
||||
device = next(self.parameters()).device
|
||||
checkpoint = torch.load(str(path), map_location=device)
|
||||
self.load_state_dict(checkpoint["model_state"])
|
||||
self.load_state_dict(checkpoint["model_state"], strict=False)
|
||||
|
||||
if "optimizer_state" in checkpoint and optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
||||
|
|
|
@ -71,6 +71,7 @@ class Toolbox:
|
|||
|
||||
# Initialize the events and the interface
|
||||
self.ui = UI()
|
||||
self.style_idx = 0
|
||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
|
||||
self.setup_events()
|
||||
self.ui.start()
|
||||
|
@ -233,7 +234,7 @@ class Toolbox:
|
|||
texts = processed_texts
|
||||
embed = self.ui.selected_utterance.embed
|
||||
embeds = [embed] * len(texts)
|
||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
|
||||
specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_idx_textbox.text()))
|
||||
breaks = [spec.shape[1] for spec in specs]
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
|
||||
|
|
|
@ -574,10 +574,14 @@ class UI(QDialog):
|
|||
self.seed_textbox = QLineEdit()
|
||||
self.seed_textbox.setMaximumWidth(80)
|
||||
layout_seed.addWidget(self.seed_textbox, 0, 1)
|
||||
layout_seed.addWidget(QLabel("Style#:(0~9)"), 0, 2)
|
||||
self.style_idx_textbox = QLineEdit("-1")
|
||||
self.style_idx_textbox.setMaximumWidth(80)
|
||||
layout_seed.addWidget(self.style_idx_textbox, 0, 3)
|
||||
self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
|
||||
self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
|
||||
" This feature requires `webrtcvad` to be installed.")
|
||||
layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
|
||||
layout_seed.addWidget(self.trim_silences_checkbox, 0, 4, 1, 2)
|
||||
gen_layout.addLayout(layout_seed)
|
||||
|
||||
self.loading_bar = QProgressBar()
|
||||
|
|
Loading…
Reference in New Issue
Block a user