Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)

Add gst (#137)

* Commit with working GST
* Make it backward compatible
* Add readme

Commit 2a99f0ff05, parent a824b54122.
.vscode/launch.json (vendored, 16 changed lines):

@@ -17,7 +17,7 @@
             "request": "launch",
             "program": "vocoder_preprocess.py",
             "console": "integratedTerminal",
-            "args": ["..\\..\\chs1"]
+            "args": ["..\\audiodata"]
         },
         {
             "name": "Python: Vocoder Train",

@@ -25,7 +25,7 @@
             "request": "launch",
             "program": "vocoder_train.py",
             "console": "integratedTerminal",
-            "args": ["dev", "..\\..\\chs1"]
+            "args": ["dev", "..\\audiodata"]
         },
         {
             "name": "Python: Demo Box",

@@ -33,7 +33,15 @@
             "request": "launch",
             "program": "demo_toolbox.py",
             "console": "integratedTerminal",
-            "args": ["-d","..\\..\\chs"]
-        }
+            "args": ["-d","..\\audiodata"]
+        },
+        {
+            "name": "Python: Synth Train",
+            "type": "python",
+            "request": "launch",
+            "program": "synthesizer_train.py",
+            "console": "integratedTerminal",
+            "args": ["my_run", "..\\"]
+        },
     ]
 }
README (Chinese), references table (本代码库 = "this repo"):

@@ -121,8 +121,9 @@

 | URL | Designation | 标题 | 实现源码 |
 | --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | 本代码库 |
 | [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | 本代码库 |
-|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
+|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
 |[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
 |[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
 |[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | 本代码库 |
README (English), references table:

@@ -77,6 +77,7 @@ You can then try the toolbox:

 | URL | Designation | Title | Implementation source |
 | --- | ----------- | ----- | --------------------- |
+| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
 | [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
 |[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | **SV2TTS** | **Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis** | This repo |
 |[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
Requirements file:

@@ -19,4 +19,5 @@ flask
 flask_wtf
 flask_cors
 gevent==21.8.0
 flask_restx
+tensorboard
New file: synthesizer/gst_hyperparameters.py (13 lines):

class GSTHyperparameters():
    E = 512

    # reference encoder
    ref_enc_filters = [32, 32, 64, 64, 128, 128]

    # style token layer
    token_num = 10
    # token_emb_size = 256
    num_heads = 8

    n_mels = 256  # Number of Mel banks to generate
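A reader's aside, not part of the commit: the values above pin down the sizes used by the new GST module added below. Each style token is E // num_heads = 512 // 8 = 64-dimensional, and the reference encoder's GRU input works out to 512 once its six stride-2 convolutions have shrunk the 256 mel bins. A minimal sketch of that arithmetic, assuming the repo root is importable:

```python
# Reader's sketch: derive the sizes implied by GSTHyperparameters (not part of the commit).
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp

token_dim = hp.E // hp.num_heads        # 512 // 8 = 64 dims per learned style token
L = hp.n_mels                           # 256 mel bins fed to the reference encoder
for _ in hp.ref_enc_filters:            # each conv: kernel 3, stride 2, padding 1
    L = (L - 3 + 2 * 1) // 2 + 1        # 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4
gru_input = hp.ref_enc_filters[-1] * L  # 128 * 4 = 512, the GRU input size used below
print(token_dim, gru_input)             # 64 512
```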
Synthesizer inference (class Synthesizer):

@@ -70,7 +70,7 @@ class Synthesizer:

     def synthesize_spectrograms(self, texts: List[str],
                                 embeddings: Union[np.ndarray, List[np.ndarray]],
-                                return_alignments=False):
+                                return_alignments=False, style_idx=0):
         """
         Synthesizes mel spectrograms from texts and speaker embeddings.

@@ -125,7 +125,7 @@ class Synthesizer:
             speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)

             # Inference
-            _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+            _, mels, alignments = self._model.generate(chars, speaker_embeddings, style_idx=style_idx)
             mels = mels.detach().cpu().numpy()
             for m in mels:
                 # Trim silence from end of each spectrogram
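With the extra keyword argument, callers can steer synthesis toward one of the ten learned style tokens. A hedged usage sketch (a reader's example; `synthesizer`, `texts` and `embeds` are set up as elsewhere in this repo, and a negative index, like the toolbox's default of "-1" further down, bypasses the style override):

```python
# Hedged sketch: synthesize with learned style token #3; a negative style_idx skips the GST override.
specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=3)
```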
New file: synthesizer/models/global_style_token.py (135 lines):

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as tFunctional
from synthesizer.gst_hyperparameters import GSTHyperparameters as hp


class GlobalStyleToken(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = ReferenceEncoder()
        self.stl = STL()

    def forward(self, inputs):
        enc_out = self.encoder(inputs)
        style_embed = self.stl(enc_out)

        return style_embed


class ReferenceEncoder(nn.Module):
    '''
    inputs --- [N, Ty/r, n_mels*r]  mels
    outputs --- [N, ref_enc_gru_size]
    '''

    def __init__(self):
        super().__init__()
        K = len(hp.ref_enc_filters)
        filters = [1] + hp.ref_enc_filters
        convs = [nn.Conv2d(in_channels=filters[i],
                           out_channels=filters[i + 1],
                           kernel_size=(3, 3),
                           stride=(2, 2),
                           padding=(1, 1)) for i in range(K)]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=hp.ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(hp.n_mels, 3, 2, 1, K)
        self.gru = nn.GRU(input_size=hp.ref_enc_filters[-1] * out_channels,
                          hidden_size=hp.E // 2,
                          batch_first=True)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, hp.n_mels)  # [N, 1, Ty, n_mels]
        for conv, bn in zip(self.convs, self.bns):
            out = conv(out)
            out = bn(out)
            out = tFunctional.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, E//2]

        return out.squeeze(0)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class STL(nn.Module):
    '''
    inputs --- [N, E//2]
    '''

    def __init__(self):
        super().__init__()
        self.embed = nn.Parameter(torch.FloatTensor(hp.token_num, hp.E // hp.num_heads))
        d_q = hp.E // 2
        d_k = hp.E // hp.num_heads
        # self.attention = MultiHeadAttention(hp.num_heads, d_model, d_q, d_v)
        self.attention = MultiHeadAttention(query_dim=d_q, key_dim=d_k, num_units=hp.E, num_heads=hp.num_heads)

        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        N = inputs.size(0)
        query = inputs.unsqueeze(1)  # [N, 1, E//2]
        keys = tFunctional.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)  # [N, token_num, E // num_heads]
        style_embed = self.attention(query, keys)

        return style_embed


class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = tFunctional.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out
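A minimal sketch of exercising the new module on its own (a reader's example, not part of the commit; it assumes the repo root is on PYTHONPATH and feeds one 256-dim vector per item, which is how Tacotron.forward below passes the speaker embedding):

```python
import torch
from synthesizer.models.global_style_token import GlobalStyleToken

gst = GlobalStyleToken()
reference = torch.randn(2, 256)   # batch of two 256-dim reference vectors
style = gst(reference)            # ReferenceEncoder -> attention over the 10 learned tokens
print(style.shape)                # torch.Size([2, 1, 512]): one E-dim style embedding per item
```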
Tacotron model:

@@ -3,8 +3,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from pathlib import Path
-from typing import Union
+from synthesizer.models.global_style_token import GlobalStyleToken


 class HighwayNetwork(nn.Module):

@@ -338,6 +337,7 @@ class Tacotron(nn.Module):
         self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                                encoder_K, num_highways, dropout)
         self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
+        self.gst = GlobalStyleToken()
         self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                                dropout, speaker_embedding_size)
         self.postnet = CBHG(postnet_K, n_mels, postnet_dims,

@@ -358,11 +358,11 @@ class Tacotron(nn.Module):
     def r(self, value):
         self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

-    def forward(self, x, m, speaker_embedding):
+    def forward(self, texts, mels, speaker_embedding):
         device = next(self.parameters()).device  # use same device as parameters

         self.step += 1
-        batch_size, _, steps = m.size()
+        batch_size, _, steps = mels.size()

         # Initialise all hidden states and pack into tuple
         attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)

@@ -383,7 +383,12 @@ class Tacotron(nn.Module):

         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(x, speaker_embedding)
+        encoder_seq = self.encoder(texts, speaker_embedding)
+        # put after encoder
+        if self.gst is not None:
+            style_embed = self.gst(speaker_embedding)
+            style_embed = style_embed.expand_as(encoder_seq)
+            encoder_seq = encoder_seq + style_embed
         encoder_seq_proj = self.encoder_proj(encoder_seq)

         # Need a couple of lists for outputs
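Reader's note on the training-time wiring above: the GST module reuses the speaker embedding as its reference input, yields one style vector of shape [N, 1, E], and `expand_as` broadcasts it across the encoder's time axis before adding it to the encoder outputs. This relies on E (512) matching `encoder_dims + speaker_embedding_size`, which is an assumption here since those hparams are not shown in this diff. A hedged shape sketch:

```python
# Hedged shape sketch of the broadcast above (assumes encoder_dims + speaker_embedding_size == hp.E == 512).
import torch

style_embed = torch.randn(4, 1, 512)    # [N, 1, E], from self.gst(speaker_embedding)
encoder_seq = torch.randn(4, 37, 512)   # [N, T, encoder_dims + speaker_embedding_size]
encoder_seq = encoder_seq + style_embed.expand_as(encoder_seq)  # same style added at every timestep
```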
@@ -391,10 +396,10 @@ class Tacotron(nn.Module):

         # Run the decoder loop
         for t in range(0, steps, self.r):
-            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
+            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
             mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                 self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-                             hidden_states, cell_states, context_vec, t, x)
+                             hidden_states, cell_states, context_vec, t, texts)
             mel_outputs.append(mel_frames)
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)

@@ -414,7 +419,7 @@ class Tacotron(nn.Module):

         return mel_outputs, linear, attn_scores, stop_outputs

-    def generate(self, x, speaker_embedding=None, steps=2000):
+    def generate(self, x, speaker_embedding=None, steps=200, style_idx=0):
         self.eval()
         device = next(self.parameters()).device  # use same device as parameters

@@ -440,6 +445,18 @@ class Tacotron(nn.Module):
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(x, speaker_embedding)
+
+        # put after encoder
+        if self.gst is not None and style_idx >= 0 and style_idx < 10:
+            gst_embed = self.gst.stl.embed.cpu().data.numpy()  # [0, number_token]
+            gst_embed = np.tile(gst_embed, (1, 8))
+            scale = np.zeros(512)
+            scale[:] = 0.3
+            speaker_embedding = (gst_embed[style_idx] * scale).astype(np.float32)
+            speaker_embedding = torch.from_numpy(np.tile(speaker_embedding, (x.shape[0], 1))).to(device)
+            style_embed = self.gst(speaker_embedding)
+            style_embed = style_embed.expand_as(encoder_seq)
+            encoder_seq = encoder_seq + style_embed
         encoder_seq_proj = self.encoder_proj(encoder_seq)

         # Need a couple of lists for outputs
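Reader's note on the new inference-time path: for 0 <= style_idx < 10, generate() builds a pseudo-reference from the chosen token rather than from the caller's speaker embedding, which the encoder has already consumed at this point. The [10, 64] token table is tiled across the 8 attention heads to [10, 512], the selected row is damped by 0.3, replicated across the batch, and pushed through the GST module as if it were a speaker embedding. A hedged numpy sketch of just that arithmetic:

```python
# Hedged sketch of the style-override math; the table shape follows gst_hyperparameters ([token_num, E // num_heads]).
import numpy as np

style_idx = 3
tokens = np.random.randn(10, 64).astype(np.float32)   # stands in for self.gst.stl.embed
pseudo_ref = np.tile(tokens, (1, 8))[style_idx] * 0.3  # [512]: chosen token repeated over 8 heads, scaled down
```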
@@ -494,7 +511,7 @@ class Tacotron(nn.Module):
         # Use device of model params as location for loaded state
         device = next(self.parameters()).device
         checkpoint = torch.load(str(path), map_location=device)
-        self.load_state_dict(checkpoint["model_state"])
+        self.load_state_dict(checkpoint["model_state"], strict=False)

         if "optimizer_state" in checkpoint and optimizer is not None:
             optimizer.load_state_dict(checkpoint["optimizer_state"])
Toolbox:

@@ -71,6 +71,7 @@ class Toolbox:

         # Initialize the events and the interface
         self.ui = UI()
+        self.style_idx = 0
         self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
         self.setup_events()
         self.ui.start()

@@ -233,7 +234,7 @@ class Toolbox:
            texts = processed_texts
        embed = self.ui.selected_utterance.embed
        embeds = [embed] * len(texts)
-       specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
+       specs = self.synthesizer.synthesize_spectrograms(texts, embeds, style_idx=int(self.ui.style_idx_textbox.text()))
        breaks = [spec.shape[1] for spec in specs]
        spec = np.concatenate(specs, axis=1)
Toolbox UI:

@@ -576,10 +576,14 @@ class UI(QDialog):
         self.seed_textbox = QLineEdit()
         self.seed_textbox.setMaximumWidth(80)
         layout_seed.addWidget(self.seed_textbox, 0, 1)
+        layout_seed.addWidget(QLabel("Style#:(0~9)"), 0, 2)
+        self.style_idx_textbox = QLineEdit("-1")
+        self.style_idx_textbox.setMaximumWidth(80)
+        layout_seed.addWidget(self.style_idx_textbox, 0, 3)
         self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
         self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
                                                " This feature requires `webrtcvad` to be installed.")
-        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
+        layout_seed.addWidget(self.trim_silences_checkbox, 0, 4, 1, 2)
         gen_layout.addLayout(layout_seed)

         self.loading_bar = QProgressBar()
Model path check:

@@ -11,7 +11,6 @@ def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path:

     # If none of the paths exist, remind the user to download models if needed
     print("********************************************************************************")
-    print("Error: Model files not found. Follow these instructions to get and install the models:")
-    print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models")
+    print("Error: Model files not found. Please download the models")
     print("********************************************************************************\n")
     quit(-1)