Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22 13:11:31 +08:00)

Commit: 6f023e313d "Add web gui of training and reconstruct taco model methods"
Parent: a39b6d3117
@@ -815,6 +815,9 @@ def getOpyrator(mode: str) -> Opyrator:
     if mode == None or mode.startswith('模型训练'):
         from mkgui.train import train
         return Opyrator(train)
+    if mode == None or mode.startswith('模型训练(VC)'):
+        from mkgui.train_vc import train_vc
+        return Opyrator(train_vc)
     from mkgui.app import synthesize
     return Opyrator(synthesize)
@@ -829,7 +832,7 @@ def render_streamlit_ui() -> None:
     with st.spinner("Loading MockingBird GUI. Please wait..."):
         session_state.mode = st.sidebar.selectbox(
             '模式选择',
-            ( "AI拟音", "VC拟音", "预处理", "模型训练")
+            ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
         )
     if "mode" in session_state:
         mode = session_state.mode
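
A note on the dispatch above (an observation about the hunk, not a change made by the commit): since '模型训练(VC)'.startswith('模型训练') is also true, the new branch is only reachable if the more specific prefix is tested before the plain one; with the ordering shown, selecting "模型训练(VC)" in the sidebar would still land on the synthesizer-training page. A minimal reordering sketch, assuming the same mode strings and Opyrator wrappers as in the hunks above:

    # Sketch only; mirrors getOpyrator() from the first hunk with the VC check moved up.
    if mode == None or mode.startswith('模型训练(VC)'):
        from mkgui.train_vc import train_vc
        return Opyrator(train_vc)
    if mode.startswith('模型训练'):
        from mkgui.train import train
        return Opyrator(train)
    from mkgui.app import synthesize
    return Opyrator(synthesize)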
mkgui/train.py (148 changed lines)
@@ -2,66 +2,55 @@ from pydantic import BaseModel, Field
 import os
 from pathlib import Path
 from enum import Enum
-from typing import Any, Tuple
-import numpy as np
-from utils.load_yaml import HpsYaml
-from utils.util import AttrDict
-import torch
+from typing import Any
+from synthesizer.hparams import hparams
+from synthesizer.train import train as synt_train
 
-# TODO: seperator for *unix systems
 # Constants
-EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
-CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
 ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
 
 
-if os.path.isdir(EXT_MODELS_DIRT):
-    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
-    print("Loaded extractor models: " + str(len(extractors)))
-else:
-    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
+# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
+# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
 
-if os.path.isdir(CONV_MODELS_DIRT):
-    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
-    print("Loaded convertor models: " + str(len(convertors)))
+# Pre-Load models
+if os.path.isdir(SYN_MODELS_DIRT):
+    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded synthesizer models: " + str(len(synthesizers)))
 else:
-    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
+    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
 
 if os.path.isdir(ENC_MODELS_DIRT):
     encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
     print("Loaded encoders models: " + str(len(encoders)))
 else:
     raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
 
 class Model(str, Enum):
-    VC_PPG2MEL = "ppg2mel"
+    DEFAULT = "default"
 
-class Dataset(str, Enum):
-    AIDATATANG_200ZH = "aidatatang_200zh"
-    AIDATATANG_200ZH_S = "aidatatang_200zh_s"
-
 class Input(BaseModel):
-    # def render_input_ui(st, input) -> Dict:
-    #     input["selected_dataset"] = st.selectbox(
-    #         '选择数据集',
-    #         ("aidatatang_200zh", "aidatatang_200zh_s")
-    #     )
-    #     return input
     model: Model = Field(
-        Model.VC_PPG2MEL, title="模型类型",
+        Model.DEFAULT, title="模型类型",
     )
     # datasets_root: str = Field(
     #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
     #     format=True,
     #     example="..\\trainning_data\\"
     # )
-    output_root: str = Field(
-        ..., alias="输出目录(可选)", description="建议不填,保持默认",
+    input_root: str = Field(
+        ..., alias="输入目录", description="预处理数据根目录",
         format=True,
-        example=""
+        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
     )
-    continue_mode: bool = Field(
-        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+    run_id: str = Field(
+        "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
+    )
+    synthesizer: synthesizers = Field(
+        ..., alias="已有合成模型",
+        description="选择语音合成模型文件."
     )
     gpu: bool = Field(
         True, alias="GPU训练", description="选择“是”,则使用GPU训练",

@@ -69,32 +58,18 @@ class Input(BaseModel):
     verbose: bool = Field(
         True, alias="打印详情", description="选择“是”,输出更多详情",
     )
-    # TODO: Move to hiden fields by default
-    convertor: convertors = Field(
-        ..., alias="转换模型",
-        description="选择语音转换模型文件."
-    )
-    extractor: extractors = Field(
-        ..., alias="特征提取模型",
-        description="选择PPG特征提取模型文件."
-    )
     encoder: encoders = Field(
         ..., alias="语音编码模型",
         description="选择语音编码模型文件."
     )
-    njobs: int = Field(
-        8, alias="进程数", description="适用于ppg2mel",
+    save_every: int = Field(
+        1000, alias="更新间隔", description="每隔n步则更新一次模型",
     )
-    seed: int = Field(
-        default=0, alias="初始随机数", description="适用于ppg2mel",
+    backup_every: int = Field(
+        10000, alias="保存间隔", description="每隔n步则保存一次模型",
     )
-    model_name: str = Field(
-        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
-        example="test"
-    )
-    model_config: str = Field(
-        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
-        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    log_every: int = Field(
+        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
     )
 
 class AudioEntity(BaseModel):

@@ -102,55 +77,30 @@ class AudioEntity(BaseModel):
     mel: Any
 
 class Output(BaseModel):
-    __root__: Tuple[str, int]
+    __root__: int
 
-    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
         """Custom output UI.
         If this method is implmeneted, it will be used instead of the default Output UI renderer.
         """
-        sr, count = self.__root__
-        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
+        streamlit_app.subheader(f"Training started with code: {self.__root__}")
 
 def train(input: Input) -> Output:
     """Train(训练)"""
-    print(">>> OneShot VC training ...")
-    params = AttrDict()
-    params.update({
-        "gpu": input.gpu,
-        "cpu": not input.gpu,
-        "njobs": input.njobs,
-        "seed": input.seed,
-        "verbose": input.verbose,
-        "load": input.convertor.value,
-        "warm_start": False,
-    })
-    if input.continue_mode:
-        # trace old model and config
-        p = Path(input.convertor.value)
-        params.name = p.parent.name
-        # search a config file
-        model_config_fpaths = list(p.parent.rglob("*.yaml"))
-        if len(model_config_fpaths) == 0:
-            raise "No model yaml config found for convertor"
-        config = HpsYaml(model_config_fpaths[0])
-        params.ckpdir = p.parent.parent
-        params.config = model_config_fpaths[0]
-        params.logdir = os.path.join(p.parent, "log")
-    else:
-        # Make the config dict dot visitable
-        config = HpsYaml(input.config)
-    np.random.seed(input.seed)
-    torch.manual_seed(input.seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(input.seed)
-    mode = "train"
-    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
-    solver = Solver(config, params, mode)
-    solver.load_data()
-    solver.set_model()
-    solver.exec()
-    print(">>> Oneshot VC train finished!")
-
-    # TODO: pass useful return code
-    return Output(__root__=(input.dataset, 0))
+    print(">>> Start training ...")
+    force_restart = len(input.run_id) > 0
+    if not force_restart:
+        input.run_id = Path(input.synthesizer.value).name.split('.')[0]
+
+    synt_train(
+        input.run_id,
+        input.input_root,
+        f"synthesizer{os.sep}saved_models",
+        input.save_every,
+        input.backup_every,
+        input.log_every,
+        force_restart,
+        hparams
+    )
+    return Output(__root__=0)
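
The reworked train() delegates synthesizer training to synthesizer.train.train (imported as synt_train): an empty run_id means "continue from the selected synthesizer checkpoint" (the run ID is derived from the checkpoint file name), while a non-empty run_id sets force_restart and starts a fresh run under that ID. A hedged sketch of driving the same entry point outside the Streamlit GUI, assuming the saved_models folders exist and that a checkpoint named "pretrained.pt" is present (the file name and the data path are hypothetical here):

    # Sketch only; importing mkgui.train runs the folder checks at the top of the module.
    import os
    from mkgui.train import Input, train, synthesizers, encoders

    inp = Input(**{
        "输入目录": f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer",  # input_root, by alias
        "新模型名/运行ID": "",                        # empty -> resume from the chosen synthesizer
        "已有合成模型": synthesizers["pretrained.pt"],  # hypothetical checkpoint name
        "语音编码模型": encoders["pretrained.pt"],      # hypothetical checkpoint name
    })
    train(inp)   # returns Output(__root__=0) once synt_train() finishes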
mkgui/train_vc.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+from pydantic import BaseModel, Field
+import os
+from pathlib import Path
+from enum import Enum
+from typing import Any, Tuple
+import numpy as np
+from utils.load_yaml import HpsYaml
+from utils.util import AttrDict
+import torch
+
+# Constants
+EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
+CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
+
+
+if os.path.isdir(EXT_MODELS_DIRT):
+    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded extractor models: " + str(len(extractors)))
+else:
+    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
+
+if os.path.isdir(CONV_MODELS_DIRT):
+    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
+    print("Loaded convertor models: " + str(len(convertors)))
+else:
+    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
+
+if os.path.isdir(ENC_MODELS_DIRT):
+    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded encoders models: " + str(len(encoders)))
+else:
+    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
+
+class Model(str, Enum):
+    VC_PPG2MEL = "ppg2mel"
+
+class Dataset(str, Enum):
+    AIDATATANG_200ZH = "aidatatang_200zh"
+    AIDATATANG_200ZH_S = "aidatatang_200zh_s"
+
+class Input(BaseModel):
+    # def render_input_ui(st, input) -> Dict:
+    #     input["selected_dataset"] = st.selectbox(
+    #         '选择数据集',
+    #         ("aidatatang_200zh", "aidatatang_200zh_s")
+    #     )
+    #     return input
+    model: Model = Field(
+        Model.VC_PPG2MEL, title="模型类型",
+    )
+    # datasets_root: str = Field(
+    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
+    #     format=True,
+    #     example="..\\trainning_data\\"
+    # )
+    output_root: str = Field(
+        ..., alias="输出目录(可选)", description="建议不填,保持默认",
+        format=True,
+        example=""
+    )
+    continue_mode: bool = Field(
+        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+    )
+    gpu: bool = Field(
+        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
+    )
+    verbose: bool = Field(
+        True, alias="打印详情", description="选择“是”,输出更多详情",
+    )
+    # TODO: Move to hiden fields by default
+    convertor: convertors = Field(
+        ..., alias="转换模型",
+        description="选择语音转换模型文件."
+    )
+    extractor: extractors = Field(
+        ..., alias="特征提取模型",
+        description="选择PPG特征提取模型文件."
+    )
+    encoder: encoders = Field(
+        ..., alias="语音编码模型",
+        description="选择语音编码模型文件."
+    )
+    njobs: int = Field(
+        8, alias="进程数", description="适用于ppg2mel",
+    )
+    seed: int = Field(
+        default=0, alias="初始随机数", description="适用于ppg2mel",
+    )
+    model_name: str = Field(
+        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
+        example="test"
+    )
+    model_config: str = Field(
+        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
+        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    )
+
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
+class Output(BaseModel):
+    __root__: Tuple[str, int]
+
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implmeneted, it will be used instead of the default Output UI renderer.
+        """
+        sr, count = self.__root__
+        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
+
+def train_vc(input: Input) -> Output:
+    """Train VC(训练 VC)"""
+    print(">>> OneShot VC training ...")
+    params = AttrDict()
+    params.update({
+        "gpu": input.gpu,
+        "cpu": not input.gpu,
+        "njobs": input.njobs,
+        "seed": input.seed,
+        "verbose": input.verbose,
+        "load": input.convertor.value,
+        "warm_start": False,
+    })
+    if input.continue_mode:
+        # trace old model and config
+        p = Path(input.convertor.value)
+        params.name = p.parent.name
+        # search a config file
+        model_config_fpaths = list(p.parent.rglob("*.yaml"))
+        if len(model_config_fpaths) == 0:
+            raise "No model yaml config found for convertor"
+        config = HpsYaml(model_config_fpaths[0])
+        params.ckpdir = p.parent.parent
+        params.config = model_config_fpaths[0]
+        params.logdir = os.path.join(p.parent, "log")
+    else:
+        # Make the config dict dot visitable
+        config = HpsYaml(input.config)
+    np.random.seed(input.seed)
+    torch.manual_seed(input.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(input.seed)
+    mode = "train"
+    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+    solver = Solver(config, params, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+
+    # TODO: pass useful return code
+    return Output(__root__=(input.dataset, 0))
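
Two issues are visible in the code carried over into mkgui/train_vc.py (observations about the diff, not changes made by this commit): `raise "No model yaml config found for convertor"` raises a plain string, which in Python 3 fails with a TypeError instead of surfacing the intended message, and the fallback branch plus the return statement reference `input.config` and `input.dataset`, neither of which is a field of the Input model shown above (it defines `model_config` and has no dataset field). A minimal hedged correction sketch for those three lines:

    # Sketch only, using the field names the Input model above actually defines.
    raise Exception("No model yaml config found for convertor")    # instead of raising a bare string
    config = HpsYaml(input.model_config)                           # instead of the undefined input.config
    return Output(__root__=(input.model.value, 0))                 # instead of the undefined input.dataset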
@@ -1,4 +1,5 @@
 import os
+from matplotlib.pyplot import step
 import numpy as np
 import torch
 import torch.nn as nn
@@ -297,7 +298,7 @@ class Decoder(nn.Module):
         x = torch.cat([context_vec, attn_hidden], dim=1)
         x = self.rnn_input(x)
 
-        # Compute first Residual RNN
+        # Compute first Residual RNN, training with fixed zoneout rate 0.1
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
             rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
@@ -372,11 +373,15 @@ class Tacotron(nn.Module):
             outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs
 
-    def forward(self, texts, mels, speaker_embedding):
+    def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         device = texts.device  # use same device as parameters
 
-        self.step += 1
-        batch_size, _, steps = mels.size()
+        if self.training:
+            self.step += 1
+            batch_size, _, steps = mels.size()
+        else:
+            batch_size, _ = texts.size()
 
         # Initialise all hidden states and pack into tuple
         attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -401,11 +406,22 @@
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts, speaker_embedding)
-        # put after encoder
+
         if hparams.use_gst and self.gst is not None:
-            style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            if self.training:
+                style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
+                # style_embed = style_embed.expand_as(encoder_seq)
+                # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            elif style_idx >= 0 and style_idx < 10:
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+                if device.type == 'cuda':
+                    query = query.cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
+            else:
+                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
             encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
         encoder_seq_proj = self.encoder_proj(encoder_seq)
@@ -414,13 +430,17 @@
 
         # Run the decoder loop
         for t in range(0, steps, self.r):
-            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
+            if self.training:
+                prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
+            else:
+                prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
             mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                 self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                              hidden_states, cell_states, context_vec, t, texts)
             mel_outputs.append(mel_frames)
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
+            if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break
 
         # Concat the mel outputs into sequence
         mel_outputs = torch.cat(mel_outputs, dim=2)
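
The branch on self.training above is what lets a single forward() serve both phases (a reading of the hunk): in training the prenet input is the previous ground-truth mel frame (teacher forcing) and the step count comes from mels, while in eval mode the last generated frame is fed back and the loop breaks once every stop token in the batch clears min_stop_token. A small sketch of the teacher-forced call, assuming model, texts, mels and embed are prepared elsewhere:

    # Sketch only; the return order follows the forward() shown above.
    model.train()
    mel_out, linear_out, attn, stop = model(texts, mels, speaker_embedding=embed)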
@@ -435,87 +455,93 @@
         # attn_scores = attn_scores.cpu().data.numpy()
         stop_outputs = torch.cat(stop_outputs, 1)
 
+        if self.training:
+            self.train()
+
         return mel_outputs, linear, attn_scores, stop_outputs
 
     def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = x.device  # use same device as parameters
+        mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
-
-        batch_size, _ = x.size()
-
-        # Need to initialise all hidden states and pack into tuple for tidyness
-        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
-        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
-
-        # Need to initialise all lstm cell states and pack into tuple for tidyness
-        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        cell_states = (rnn1_cell, rnn2_cell)
-
-        # Need a <GO> Frame for start of decoder loop
-        go_frame = torch.zeros(batch_size, self.n_mels, device=device)
-
-        # Need an initial context vector
-        size = self.encoder_dims + self.speaker_embedding_size
-        if hparams.use_gst:
-            size += gst_hp.E
-        context_vec = torch.zeros(batch_size, size, device=device)
-
-        # SV2TTS: Run the encoder with the speaker embedding
-        # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(x, speaker_embedding)
-
-        # put after encoder
-        if hparams.use_gst and self.gst is not None:
-            if style_idx >= 0 and style_idx < 10:
-                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
-                if device.type == 'cuda':
-                    query = query.cuda()
-                gst_embed = torch.tanh(self.gst.stl.embed)
-                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
-                style_embed = self.gst.stl.attention(query, key)
-            else:
-                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
-            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
-        encoder_seq_proj = self.encoder_proj(encoder_seq)
-
-        # Need a couple of lists for outputs
-        mel_outputs, attn_scores, stop_outputs = [], [], []
-
-        # Run the decoder loop
-        for t in range(0, steps, self.r):
-            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
-            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
-                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-                             hidden_states, cell_states, context_vec, t, x)
-            mel_outputs.append(mel_frames)
-            attn_scores.append(scores)
-            stop_outputs.extend([stop_tokens] * self.r)
-            # Stop the loop when all stop tokens in batch exceed threshold
-            if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
-
-        # Concat the mel outputs into sequence
-        mel_outputs = torch.cat(mel_outputs, dim=2)
-
-        # Post-Process for Linear Spectrograms
-        postnet_out = self.postnet(mel_outputs)
-        linear = self.post_proj(postnet_out)
-
-        linear = linear.transpose(1, 2)
-
-        # For easy visualisation
-        attn_scores = torch.cat(attn_scores, 1)
-        stop_outputs = torch.cat(stop_outputs, 1)
-
-        self.train()
-
         return mel_outputs, linear, attn_scores
+        # device = x.device  # use same device as parameters
+        # batch_size, _ = x.size()
+        # # Need to initialise all hidden states and pack into tuple for tidyness
+        # attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
+        # rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+        # # Need to initialise all lstm cell states and pack into tuple for tidyness
+        # rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # cell_states = (rnn1_cell, rnn2_cell)
+        # # Need a <GO> Frame for start of decoder loop
+        # go_frame = torch.zeros(batch_size, self.n_mels, device=device)
+        # # Need an initial context vector
+        # size = self.encoder_dims + self.speaker_embedding_size
+        # if hparams.use_gst:
+        #     size += gst_hp.E
+        # context_vec = torch.zeros(batch_size, size, device=device)
+        # # SV2TTS: Run the encoder with the speaker embedding
+        # # The projection avoids unnecessary matmuls in the decoder loop
+        # encoder_seq = self.encoder(x, speaker_embedding)
+        # # put after encoder
+        # if hparams.use_gst and self.gst is not None:
+        #     if style_idx >= 0 and style_idx < 10:
+        #         query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+        #         if device.type == 'cuda':
+        #             query = query.cuda()
+        #         gst_embed = torch.tanh(self.gst.stl.embed)
+        #         key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+        #         style_embed = self.gst.stl.attention(query, key)
+        #     else:
+        #         speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+        #         style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+        #     encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+        #     # style_embed = style_embed.expand_as(encoder_seq)
+        #     # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+        # encoder_seq_proj = self.encoder_proj(encoder_seq)
+        # # Need a couple of lists for outputs
+        # mel_outputs, attn_scores, stop_outputs = [], [], []
+        # # Run the decoder loop
+        # for t in range(0, steps, self.r):
+        #     prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
+        #     mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
+        #         self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
+        #                      hidden_states, cell_states, context_vec, t, x)
+        #     mel_outputs.append(mel_frames)
+        #     attn_scores.append(scores)
+        #     stop_outputs.extend([stop_tokens] * self.r)
+        #     # Stop the loop when all stop tokens in batch exceed threshold
+        #     if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
+        # # Concat the mel outputs into sequence
+        # mel_outputs = torch.cat(mel_outputs, dim=2)
+        # # Post-Process for Linear Spectrograms
+        # postnet_out = self.postnet(mel_outputs)
+        # linear = self.post_proj(postnet_out)
+        # linear = linear.transpose(1, 2)
+        # # For easy visualisation
+        # attn_scores = torch.cat(attn_scores, 1)
+        # stop_outputs = torch.cat(stop_outputs, 1)
+        # self.train()
+        # return mel_outputs, linear, attn_scores
 
     def init_model(self):
         for p in self.parameters():
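
After this change generate() is a thin wrapper: it switches to eval mode, delegates to the shared forward() pass (which now covers autoregressive decoding, GST style selection and the stop-token early exit), and drops the stop outputs from the returned tuple, leaving the old inference body behind as comments. A hedged usage sketch, assuming a constructed Tacotron instance, a batch of encoded text ids and a speaker embedding tensor:

    # Sketch only; argument names follow the new generate() signature above.
    with torch.no_grad():
        mels, linear, attn = model.generate(texts, speaker_embedding=embed,
                                            steps=2000,        # decoding budget in decoder steps
                                            style_idx=0,       # GST token to condition on (0-9)
                                            min_stop_token=5)  # early-exit threshold on stop tokens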
@@ -15,9 +15,8 @@ from datetime import datetime
 import json
 import numpy as np
 from pathlib import Path
-import sys
 import time
+import os
 
 def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
@@ -265,7 +264,19 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                            loss=loss,
                            hparams=hparams,
                            sw=sw)
+                MAX_SAVED_COUNT = 20
+                if (step / hparams.tts_eval_interval) % MAX_SAVED_COUNT:
+                    # clean up and save last MAX_SAVED_COUNT;
+                    plots = next(os.walk(plot_dir), (None, None, []))[2]
+                    for plot in plots[-MAX_SAVED_COUNT:]:
+                        os.remove(plot_dir.joinpath(plot))
+                    mel_files = next(os.walk(mel_output_dir), (None, None, []))[2]
+                    for mel_file in mel_files[-MAX_SAVED_COUNT:]:
+                        os.remove(mel_output_dir.joinpath(mel_file))
+                    wavs = next(os.walk(wav_dir), (None, None, []))[2]
+                    for w in wavs[-MAX_SAVED_COUNT:]:
+                        os.remove(wav_dir.joinpath(w))
 
             # Break out of loop to update training schedule
             if step >= max_step:
                 break
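
The new housekeeping above keeps the plot, mel and wav output folders from growing without bound. Two details read oddly against the inline comment (observations, not changes in the commit): the gate uses true division, so `(step / hparams.tts_eval_interval) % MAX_SAVED_COUNT` is a float truthiness check rather than an every-N test, and the slice `[-MAX_SAVED_COUNT:]` selects the last 20 entries of the directory listing, which is the set the comment says should be kept. A hedged sketch of a keep-last-N variant, assuming step-numbered file names so lexicographic order tracks age:

    # Sketch only; plot_dir, mel_output_dir and wav_dir are the Path objects used in the hunk above.
    MAX_SAVED_COUNT = 20
    if (step // hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
        for d in (plot_dir, mel_output_dir, wav_dir):
            files = sorted(next(os.walk(d), (None, None, []))[2])   # sorted name as a proxy for age
            for name in files[:-MAX_SAVED_COUNT]:                   # drop everything except the newest N
                os.remove(d.joinpath(name))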