2022-05-09 18:44:02 +08:00
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from enum import Enum
|
2022-06-26 23:21:32 +08:00
|
|
|
|
from typing import Any
|
|
|
|
|
from synthesizer.hparams import hparams
|
|
|
|
|
from synthesizer.train import train as synt_train
|
2022-05-09 18:44:02 +08:00
|
|
|
|
|
|
|
|
|
# Constants
|
2022-06-26 23:21:32 +08:00
|
|
|
|
# Relative folders that are scanned for saved synthesizer / encoder checkpoints.
SYN_MODELS_DIRT = os.path.join("synthesizer", "saved_models")
ENC_MODELS_DIRT = os.path.join("encoder", "saved_models")
|
2022-05-09 18:44:02 +08:00
|
|
|
|
|
|
|
|
|
|
2022-06-26 23:21:32 +08:00
|
|
|
|
# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
|
|
|
|
|
# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
|
|
|
|
|
# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
|
2022-05-09 18:44:02 +08:00
|
|
|
|
|
2022-06-26 23:21:32 +08:00
|
|
|
|
# Pre-load models
def _load_model_enum(enum_name, models_dir):
    """Build an Enum whose members are the ``*.pt`` files found under *models_dir*.

    Args:
        enum_name: Name given to the generated Enum class.
        models_dir: Folder that is searched recursively for ``*.pt`` files.

    Returns:
        An Enum class mapping a display name to the checkpoint ``Path``.
        The display name is the bare file name; when the same file name
        occurs in several sub-folders, the path relative to *models_dir*
        is used instead so that member names stay unique.

    Raises:
        FileNotFoundError: If *models_dir* is not an existing directory.
    """
    if not os.path.isdir(models_dir):
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")

    root = Path(models_dir)
    # Sorted so the member order does not depend on filesystem listing order.
    checkpoints = sorted(root.glob("**/*.pt"))

    # Count bare file names: Enum() raises TypeError on a reused member name,
    # which the recursive glob can produce across sub-folders.
    occurrences = {}
    for checkpoint in checkpoints:
        occurrences[checkpoint.name] = occurrences.get(checkpoint.name, 0) + 1

    members = []
    for checkpoint in checkpoints:
        if occurrences[checkpoint.name] > 1:
            label = checkpoint.relative_to(root).as_posix()
        else:
            label = checkpoint.name
        members.append((label, checkpoint))
    return Enum(enum_name, members)


# Both enums are built at import time; a missing folder aborts the import.
synthesizers = _load_model_enum("synthesizers", SYN_MODELS_DIRT)
print("Loaded synthesizer models: " + str(len(synthesizers)))

encoders = _load_model_enum("encoders", ENC_MODELS_DIRT)
print("Loaded encoders models: " + str(len(encoders)))
|
|
|
|
|
|
|
|
|
|
class Model(str, Enum):
    # String-valued enum of selectable model types; only one option is defined.
    # (Documented with comments rather than a docstring so the enum's
    # __doc__ stays unchanged.)
    DEFAULT = "default"
|
2022-05-09 18:44:02 +08:00
|
|
|
|
|
|
|
|
|
class Input(BaseModel):
    # Form schema describing one synthesizer training request.
    # Documented with comments (not a docstring) so the model's __doc__
    # stays unchanged. Field order is preserved from the original.

    # Model type selector; only Model.DEFAULT exists.
    model: Model = Field(
        default=Model.DEFAULT,
        title="模型类型",
    )

    # NOTE: the `datasets_root` field below is disabled (commented out).
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )

    # Required: root directory of the preprocessed data.
    input_root: str = Field(
        default=...,
        alias="输入目录",
        description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer",
    )

    # New model name / run ID. Empty means "continue training the
    # synthesizer selected below".
    run_id: str = Field(
        default="",
        alias="新模型名/运行ID",
        description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )

    # Required: an existing synthesizer checkpoint.
    synthesizer: synthesizers = Field(
        default=...,
        alias="已有合成模型",
        description="选择语音合成模型文件.",
    )

    # Whether to train on GPU.
    gpu: bool = Field(
        default=True,
        alias="GPU训练",
        description="选择“是”,则使用GPU训练",
    )

    # Whether to print more details.
    verbose: bool = Field(
        default=True,
        alias="打印详情",
        description="选择“是”,输出更多详情",
    )

    # Required: a speech encoder checkpoint.
    encoder: encoders = Field(
        default=...,
        alias="语音编码模型",
        description="选择语音编码模型文件.",
    )

    # Update the model every n steps.
    save_every: int = Field(
        default=1000,
        alias="更新间隔",
        description="每隔n步则更新一次模型",
    )

    # Save (back up) the model every n steps.
    backup_every: int = Field(
        default=10000,
        alias="保存间隔",
        description="每隔n步则保存一次模型",
    )

    # Print training statistics every n steps.
    log_every: int = Field(
        default=500,
        alias="打印间隔",
        description="每隔n步则打印一次训练统计",
    )
|
|
|
|
|
|
|
|
|
|
class AudioEntity(BaseModel):
    # Pydantic model pairing a bytes payload with a `mel` value.
    # It is not referenced by any other code visible in this file.
    content: bytes  # raw payload; presumably encoded audio — TODO confirm
    mel: Any  # untyped; presumably a mel spectrogram — TODO confirm
|
|
|
|
|
|
|
|
|
|
class Output(BaseModel):
    # Custom-root model: the whole payload is a single integer code.
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.

        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        # Show the wrapped integer as a sub-header in the given app object.
        status_line = f"Training started with code: {self.__root__}"
        streamlit_app.subheader(status_line)
|
2022-05-09 18:44:02 +08:00
|
|
|
|
|
|
|
|
|
def train(input: Input) -> Output:
    """Train(训练)"""
    # NOTE(review): the docstring above is kept verbatim; it is presumably
    # surfaced by the UI framework as the page description — confirm before
    # rewording it.
    #
    # Starts a synthesizer training run via `synt_train`.
    #
    # Args:
    #     input: Form values. Only run_id, synthesizer, input_root and the
    #         three *_every intervals are read here; model, gpu, verbose and
    #         encoder are not used by this function.
    #
    # Returns:
    #     Output wrapping the constant code 0 once `synt_train` returns.
    print(">>> Start training ...")

    # Work on a local copy instead of mutating the caller's model. A run ID
    # of only whitespace is treated as "not provided".
    run_id = input.run_id.strip()

    # A user-supplied run ID means "train a new model from scratch".
    force_restart = bool(run_id)
    if not force_restart:
        # Resume: derive the run ID from the selected checkpoint's file name
        # (everything before the first dot).
        run_id = Path(input.synthesizer.value).name.split('.')[0]

    synt_train(
        run_id,
        input.input_root,
        SYN_MODELS_DIRT,  # same folder the synthesizer enum was loaded from
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams,
    )
    return Output(__root__=0)
|