Add web GUI for training and restructure Tacotron model methods

babysor00 2022-06-26 23:21:32 +08:00
parent a39b6d3117
commit 6f023e313d
5 changed files with 333 additions and 188 deletions


@@ -815,6 +815,9 @@ def getOpyrator(mode: str) -> Opyrator:
     if mode == None or mode.startswith('模型训练'):
         from mkgui.train import train
         return Opyrator(train)
+    if mode == None or mode.startswith('模型训练(VC)'):
+        from mkgui.train_vc import train_vc
+        return Opyrator(train_vc)
     from mkgui.app import synthesize
     return Opyrator(synthesize)
@@ -829,7 +832,7 @@ def render_streamlit_ui() -> None:
     with st.spinner("Loading MockingBird GUI. Please wait..."):
         session_state.mode = st.sidebar.selectbox(
             '模式选择',
-            ( "AI拟音", "VC拟音", "预处理", "模型训练")
+            ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
        )
     if "mode" in session_state:
         mode = session_state.mode


@@ -2,66 +2,55 @@ from pydantic import BaseModel, Field
 import os
 from pathlib import Path
 from enum import Enum
-from typing import Any, Tuple
-import numpy as np
-from utils.load_yaml import HpsYaml
-from utils.util import AttrDict
-import torch
-# TODO: seperator for *unix systems
+from typing import Any
+from synthesizer.hparams import hparams
+from synthesizer.train import train as synt_train

 # Constants
-EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
-CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
 ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
+# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
+# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"

-if os.path.isdir(EXT_MODELS_DIRT):
-    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
-    print("Loaded extractor models: " + str(len(extractors)))
-else:
-    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
-
-if os.path.isdir(CONV_MODELS_DIRT):
-    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
-    print("Loaded convertor models: " + str(len(convertors)))
+# Pre-Load models
+if os.path.isdir(SYN_MODELS_DIRT):
+    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded synthesizer models: " + str(len(synthesizers)))
 else:
-    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
+    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

 if os.path.isdir(ENC_MODELS_DIRT):
     encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
     print("Loaded encoders models: " + str(len(encoders)))
 else:
     raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

 class Model(str, Enum):
-    VC_PPG2MEL = "ppg2mel"
-
-class Dataset(str, Enum):
-    AIDATATANG_200ZH = "aidatatang_200zh"
-    AIDATATANG_200ZH_S = "aidatatang_200zh_s"
+    DEFAULT = "default"

 class Input(BaseModel):
-    # def render_input_ui(st, input) -> Dict:
-    #     input["selected_dataset"] = st.selectbox(
-    #         '选择数据集',
-    #         ("aidatatang_200zh", "aidatatang_200zh_s")
-    #     )
-    #     return input
     model: Model = Field(
-        Model.VC_PPG2MEL, title="模型类型",
+        Model.DEFAULT, title="模型类型",
     )
     # datasets_root: str = Field(
     #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
     #     format=True,
     #     example="..\\trainning_data\\"
     # )
-    output_root: str = Field(
-        ..., alias="输出目录(可选)", description="建议不填,保持默认",
+    input_root: str = Field(
+        ..., alias="输入目录", description="预处理数据根目录",
         format=True,
-        example=""
+        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
     )
-    continue_mode: bool = Field(
-        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+    run_id: str = Field(
+        "", alias="新模型名/运行ID", description="使用新ID进行重新训练否则选择下面的模型进行继续训练",
+    )
+    synthesizer: synthesizers = Field(
+        ..., alias="已有合成模型",
+        description="选择语音合成模型文件."
     )
     gpu: bool = Field(
         True, alias="GPU训练", description="选择“是”则使用GPU训练",
@@ -69,32 +58,18 @@ class Input(BaseModel):
     verbose: bool = Field(
         True, alias="打印详情", description="选择“是”,输出更多详情",
     )
-    # TODO: Move to hiden fields by default
-    convertor: convertors = Field(
-        ..., alias="转换模型",
-        description="选择语音转换模型文件."
-    )
-    extractor: extractors = Field(
-        ..., alias="特征提取模型",
-        description="选择PPG特征提取模型文件."
-    )
     encoder: encoders = Field(
         ..., alias="语音编码模型",
         description="选择语音编码模型文件."
     )
-    njobs: int = Field(
-        8, alias="进程数", description="适用于ppg2mel",
+    save_every: int = Field(
+        1000, alias="更新间隔", description="每隔n步则更新一次模型",
     )
-    seed: int = Field(
-        default=0, alias="初始随机数", description="适用于ppg2mel",
+    backup_every: int = Field(
+        10000, alias="保存间隔", description="每隔n步则保存一次模型",
     )
-    model_name: str = Field(
-        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
-        example="test"
-    )
-    model_config: str = Field(
-        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
-        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    log_every: int = Field(
+        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
     )

 class AudioEntity(BaseModel):
@@ -102,55 +77,30 @@ class AudioEntity(BaseModel):
     mel: Any

 class Output(BaseModel):
-    __root__: Tuple[str, int]
+    __root__: int

-    def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
+    def render_output_ui(self, streamlit_app) -> None: # type: ignore
         """Custom output UI.
         If this method is implmeneted, it will be used instead of the default Output UI renderer.
         """
-        sr, count = self.__root__
-        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
+        streamlit_app.subheader(f"Training started with code: {self.__root__}")

 def train(input: Input) -> Output:
     """Train(训练)"""
-    print(">>> OneShot VC training ...")
-    params = AttrDict()
-    params.update({
-        "gpu": input.gpu,
-        "cpu": not input.gpu,
-        "njobs": input.njobs,
-        "seed": input.seed,
-        "verbose": input.verbose,
-        "load": input.convertor.value,
-        "warm_start": False,
-    })
-    if input.continue_mode:
-        # trace old model and config
-        p = Path(input.convertor.value)
-        params.name = p.parent.name
-        # search a config file
-        model_config_fpaths = list(p.parent.rglob("*.yaml"))
-        if len(model_config_fpaths) == 0:
-            raise "No model yaml config found for convertor"
-        config = HpsYaml(model_config_fpaths[0])
-        params.ckpdir = p.parent.parent
-        params.config = model_config_fpaths[0]
-        params.logdir = os.path.join(p.parent, "log")
-    else:
-        # Make the config dict dot visitable
-        config = HpsYaml(input.config)
-    np.random.seed(input.seed)
-    torch.manual_seed(input.seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(input.seed)
-    mode = "train"
-    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
-    solver = Solver(config, params, mode)
-    solver.load_data()
-    solver.set_model()
-    solver.exec()
-    print(">>> Oneshot VC train finished!")
-    # TODO: pass useful return code
-    return Output(__root__=(input.dataset, 0))
+    print(">>> Start training ...")
+    force_restart = len(input.run_id) > 0
+    if not force_restart:
+        input.run_id = Path(input.synthesizer.value).name.split('.')[0]
+
+    synt_train(
+        input.run_id,
+        input.input_root,
+        f"synthesizer{os.sep}saved_models",
+        input.save_every,
+        input.backup_every,
+        input.log_every,
+        force_restart,
+        hparams
+    )
+    return Output(__root__=0)
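Read end to end, the rewritten train() is a thin adapter from the web form to the existing synthesizer trainer. Roughly, one submission boils down to the sketch below (illustrative, not part of the commit; the checkpoint name "mandarin.pt" and the data path are placeholders):

    import os
    from pathlib import Path
    from synthesizer.hparams import hparams
    from synthesizer.train import train as synt_train

    run_id = ""                                               # empty run id -> continue training
    selected = Path(f"synthesizer{os.sep}saved_models{os.sep}mandarin.pt")  # placeholder checkpoint
    force_restart = len(run_id) > 0
    if not force_restart:
        run_id = selected.name.split('.')[0]                  # fall back to the checkpoint's stem

    synt_train(run_id,
               f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer",  # preprocessed data root
               f"synthesizer{os.sep}saved_models",                       # models_dir
               1000,     # save_every
               10000,    # backup_every
               500,      # log_every
               force_restart,
               hparams)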

mkgui/train_vc.py (new file, 155 lines)

@@ -0,0 +1,155 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch

# Constants
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"

if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    output_root: str = Field(
        ..., alias="输出目录(可选)", description="建议不填,保持默认",
        format=True,
        example=""
    )
    continue_mode: bool = Field(
        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # TODO: Move to hiden fields by default
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    njobs: int = Field(
        8, alias="进程数", description="适用于ppg2mel",
    )
    seed: int = Field(
        default=0, alias="初始随机数", description="适用于ppg2mel",
    )
    model_name: str = Field(
        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
        example="test"
    )
    model_config: str = Field(
        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
        """Custom output UI.
        If this method is implmeneted, it will be used instead of the default Output UI renderer.
        """
        sr, count = self.__root__
        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")

def train_vc(input: Input) -> Output:
    """Train VC(训练 VC)"""
    print(">>> OneShot VC training ...")
    params = AttrDict()
    params.update({
        "gpu": input.gpu,
        "cpu": not input.gpu,
        "njobs": input.njobs,
        "seed": input.seed,
        "verbose": input.verbose,
        "load": input.convertor.value,
        "warm_start": False,
    })
    if input.continue_mode:
        # trace old model and config
        p = Path(input.convertor.value)
        params.name = p.parent.name
        # search a config file
        model_config_fpaths = list(p.parent.rglob("*.yaml"))
        if len(model_config_fpaths) == 0:
            raise "No model yaml config found for convertor"
        config = HpsYaml(model_config_fpaths[0])
        params.ckpdir = p.parent.parent
        params.config = model_config_fpaths[0]
        params.logdir = os.path.join(p.parent, "log")
    else:
        # Make the config dict dot visitable
        config = HpsYaml(input.config)
    np.random.seed(input.seed)
    torch.manual_seed(input.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(input.seed)
    mode = "train"
    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
    solver = Solver(config, params, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")
    # TODO: pass useful return code
    return Output(__root__=(input.dataset, 0))
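In continue-training mode, train_vc() derives every path from the location of the selected ppg2mel checkpoint. A condensed sketch of that lookup (illustrative, not part of the commit; the layout ppg2mel/saved_models/<run_name>/<checkpoint>.pth and the file name are assumptions):

    import os
    from pathlib import Path

    ckpt = Path(f"ppg2mel{os.sep}saved_models{os.sep}demo_run{os.sep}best_loss.pth")  # placeholder
    run_name = ckpt.parent.name                   # "demo_run"
    ckpdir = ckpt.parent.parent                   # ppg2mel/saved_models
    logdir = os.path.join(ckpt.parent, "log")
    configs = list(ckpt.parent.rglob("*.yaml"))   # the run's YAML config must sit beside the checkpoint
    if not configs:
        raise FileNotFoundError("No model yaml config found for convertor")
    config_path = configs[0]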


@@ -1,4 +1,5 @@
 import os
+from matplotlib.pyplot import step
 import numpy as np
 import torch
 import torch.nn as nn
@@ -297,7 +298,7 @@ class Decoder(nn.Module):
         x = torch.cat([context_vec, attn_hidden], dim=1)
         x = self.rnn_input(x)

-        # Compute first Residual RNN
+        # Compute first Residual RNN, training with fixed zoneout rate 0.1
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
             rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
@@ -372,11 +373,15 @@ class Tacotron(nn.Module):
         outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs

-    def forward(self, texts, mels, speaker_embedding):
+    def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         device = texts.device  # use same device as parameters

-        self.step += 1
-        batch_size, _, steps = mels.size()
+        if self.training:
+            self.step += 1
+            batch_size, _, steps = mels.size()
+        else:
+            batch_size, _ = texts.size()

         # Initialise all hidden states and pack into tuple
         attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -401,11 +406,22 @@ class Tacotron(nn.Module):
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts, speaker_embedding)
+        # put after encoder
         if hparams.use_gst and self.gst is not None:
-            style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            if self.training:
+                style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
+                # style_embed = style_embed.expand_as(encoder_seq)
+                # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            elif style_idx >= 0 and style_idx < 10:
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+                if device.type == 'cuda':
+                    query = query.cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
+            else:
+                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
             encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
@@ -414,13 +430,17 @@ class Tacotron(nn.Module):
         # Run the decoder loop
         for t in range(0, steps, self.r):
-            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
+            if self.training:
+                prenet_in = mels[:, :, t -1] if t > 0 else go_frame
+            else:
+                prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
             mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                 self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                              hidden_states, cell_states, context_vec, t, texts)
             mel_outputs.append(mel_frames)
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
+            if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break

         # Concat the mel outputs into sequence
         mel_outputs = torch.cat(mel_outputs, dim=2)
@@ -435,87 +455,93 @@ class Tacotron(nn.Module):
         # attn_scores = attn_scores.cpu().data.numpy()
         stop_outputs = torch.cat(stop_outputs, 1)

+        if self.training:
+            self.train()
+
         return mel_outputs, linear, attn_scores, stop_outputs

     def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = x.device # use same device as parameters
-        batch_size, _ = x.size()
-        # Need to initialise all hidden states and pack into tuple for tidyness
-        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
-        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
-        # Need to initialise all lstm cell states and pack into tuple for tidyness
-        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        cell_states = (rnn1_cell, rnn2_cell)
-        # Need a <GO> Frame for start of decoder loop
-        go_frame = torch.zeros(batch_size, self.n_mels, device=device)
-        # Need an initial context vector
-        size = self.encoder_dims + self.speaker_embedding_size
-        if hparams.use_gst:
-            size += gst_hp.E
-        context_vec = torch.zeros(batch_size, size, device=device)
-        # SV2TTS: Run the encoder with the speaker embedding
-        # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(x, speaker_embedding)
-        # put after encoder
-        if hparams.use_gst and self.gst is not None:
-            if style_idx >= 0 and style_idx < 10:
-                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
-                if device.type == 'cuda':
-                    query = query.cuda()
-                gst_embed = torch.tanh(self.gst.stl.embed)
-                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
-                style_embed = self.gst.stl.attention(query, key)
-            else:
-                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
-            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
-        encoder_seq_proj = self.encoder_proj(encoder_seq)
-        # Need a couple of lists for outputs
-        mel_outputs, attn_scores, stop_outputs = [], [], []
-        # Run the decoder loop
-        for t in range(0, steps, self.r):
-            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
-            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
-                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-                             hidden_states, cell_states, context_vec, t, x)
-            mel_outputs.append(mel_frames)
-            attn_scores.append(scores)
-            stop_outputs.extend([stop_tokens] * self.r)
-            # Stop the loop when all stop tokens in batch exceed threshold
-            if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
-        # Concat the mel outputs into sequence
-        mel_outputs = torch.cat(mel_outputs, dim=2)
-        # Post-Process for Linear Spectrograms
-        postnet_out = self.postnet(mel_outputs)
-        linear = self.post_proj(postnet_out)
-        linear = linear.transpose(1, 2)
-        # For easy visualisation
-        attn_scores = torch.cat(attn_scores, 1)
-        stop_outputs = torch.cat(stop_outputs, 1)
-        self.train()
+        mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)

         return mel_outputs, linear, attn_scores
+        # device = x.device # use same device as parameters
+        # batch_size, _ = x.size()
+        # # Need to initialise all hidden states and pack into tuple for tidyness
+        # attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
+        # rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+        # # Need to initialise all lstm cell states and pack into tuple for tidyness
+        # rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # cell_states = (rnn1_cell, rnn2_cell)
+        # # Need a <GO> Frame for start of decoder loop
+        # go_frame = torch.zeros(batch_size, self.n_mels, device=device)
+        # # Need an initial context vector
+        # size = self.encoder_dims + self.speaker_embedding_size
+        # if hparams.use_gst:
+        #     size += gst_hp.E
+        # context_vec = torch.zeros(batch_size, size, device=device)
+        # # SV2TTS: Run the encoder with the speaker embedding
+        # # The projection avoids unnecessary matmuls in the decoder loop
+        # encoder_seq = self.encoder(x, speaker_embedding)
+        # # put after encoder
+        # if hparams.use_gst and self.gst is not None:
+        #     if style_idx >= 0 and style_idx < 10:
+        #         query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+        #         if device.type == 'cuda':
+        #             query = query.cuda()
+        #         gst_embed = torch.tanh(self.gst.stl.embed)
+        #         key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+        #         style_embed = self.gst.stl.attention(query, key)
+        #     else:
+        #         speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+        #         style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+        #     encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+        #     # style_embed = style_embed.expand_as(encoder_seq)
+        #     # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+        # encoder_seq_proj = self.encoder_proj(encoder_seq)
+        # # Need a couple of lists for outputs
+        # mel_outputs, attn_scores, stop_outputs = [], [], []
+        # # Run the decoder loop
+        # for t in range(0, steps, self.r):
+        #     prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
+        #     mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
+        #         self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
+        #                      hidden_states, cell_states, context_vec, t, x)
+        #     mel_outputs.append(mel_frames)
+        #     attn_scores.append(scores)
+        #     stop_outputs.extend([stop_tokens] * self.r)
+        #     # Stop the loop when all stop tokens in batch exceed threshold
+        #     if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
+        # # Concat the mel outputs into sequence
+        # mel_outputs = torch.cat(mel_outputs, dim=2)
+        # # Post-Process for Linear Spectrograms
+        # postnet_out = self.postnet(mel_outputs)
+        # linear = self.post_proj(postnet_out)
+        # linear = linear.transpose(1, 2)
+        # # For easy visualisation
+        # attn_scores = torch.cat(attn_scores, 1)
+        # stop_outputs = torch.cat(stop_outputs, 1)
+        # self.train()
+        # return mel_outputs, linear, attn_scores

     def init_model(self):
         for p in self.parameters():

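The net effect of the Tacotron changes is that inference no longer keeps its own decoder loop: generate() switches to eval mode and calls forward() without a target mel, relying on the new stop-token early exit. A usage sketch (illustrative, not part of the commit; preparing the model, text ids and speaker embedding is assumed to happen in the existing synthesizer front end):

    import torch

    def synthesize_one(model, text_ids, speaker_embed, style_idx=0):
        model.eval()                               # generate() also switches to eval mode itself
        with torch.no_grad():
            texts = text_ids.unsqueeze(0)          # (1, T_text)
            embed = speaker_embed.unsqueeze(0)     # (1, speaker_embedding_size)
            # generate() forwards to forward(x, None, embed, ...) and stops once all
            # stop tokens exceed min_stop_token, instead of running a separate loop.
            mels, linear, attn = model.generate(texts, embed, steps=2000,
                                                style_idx=style_idx, min_stop_token=5)
        return mels, linear, attn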

@@ -15,9 +15,8 @@ from datetime import datetime
 import json
 import numpy as np
 from pathlib import Path
-import sys
 import time
 import os

 def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
@@ -265,7 +264,19 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                            loss=loss,
                            hparams=hparams,
                            sw=sw)
+            MAX_SAVED_COUNT = 20
+            if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT:
+                # clean up and save last MAX_SAVED_COUNT;
+                plots = next(os.walk(plot_dir), (None, None, []))[2]
+                for plot in plots[-MAX_SAVED_COUNT:]:
+                    os.remove(plot_dir.joinpath(plot))
+                mel_files = next(os.walk(mel_output_dir), (None, None, []))[2]
+                for mel_file in mel_files[-MAX_SAVED_COUNT:]:
+                    os.remove(mel_output_dir.joinpath(mel_file))
+                wavs = next(os.walk(wav_dir), (None, None, []))[2]
+                for w in wavs[-MAX_SAVED_COUNT:]:
+                    os.remove(wav_dir.joinpath(w))

         # Break out of loop to update training schedule
         if step >= max_step:
             break
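The block added above prunes evaluation artifacts (plots, mels, wavs) as training runs. It deletes the files at the tail of the os.walk() listing, so which artifacts survive depends on that listing's order; if the comment's "save last MAX_SAVED_COUNT" means keeping only the newest files, an mtime-based variant is unambiguous. An illustrative sketch, not part of the commit:

    import os
    from pathlib import Path

    def prune_dir(directory: Path, keep: int = 20) -> None:
        # Keep the `keep` most recently modified files, delete the rest.
        files = sorted((p for p in directory.iterdir() if p.is_file()),
                       key=lambda p: p.stat().st_mtime)
        for stale in files[:-keep]:
            os.remove(stale)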