mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
106 lines
3.5 KiB
Python
106 lines
3.5 KiB
Python
from pydantic import BaseModel, Field
|
||
import os
|
||
from pathlib import Path
|
||
from enum import Enum
|
||
from typing import Any
|
||
from synthesizer.hparams import hparams
|
||
from synthesizer.train import train as synt_train
|
||
|
||
# Constants: folders (relative to the current working directory) that hold
# the saved checkpoints offered for each model family. Checkpoint folders for
# the ppg_extractor / ppg2mel models are not needed by this trainer.
SYN_MODELS_DIRT = os.path.join("synthesizer", "saved_models")
ENC_MODELS_DIRT = os.path.join("encoder", "saved_models")
|
||
|
||
# Pre-load models: expose every ``*.pt`` checkpoint found (recursively) under
# each model folder as an Enum, so the form can offer them as choices.


def _scan_models(enum_name: str, models_dir: str) -> "type[Enum]":
    """Build an Enum whose members map checkpoint names to their paths.

    Member names are checkpoint file names (e.g. ``pretrained.pt``) and
    values are the ``Path`` of each file. Files are visited in sorted order
    so the member order is deterministic. Enum member names must be unique,
    so when a file name was already used by a checkpoint in another
    sub-folder, the later one is named by its path relative to *models_dir*
    instead of crashing the import.

    Args:
        enum_name: name given to the generated Enum class.
        models_dir: folder to search for ``*.pt`` files.

    Raises:
        FileNotFoundError: if *models_dir* is not an existing directory.
    """
    root = Path(models_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")
    members = {}
    for path in sorted(root.glob("**/*.pt")):
        name = path.name
        if name in members:
            # A relative path contains "/", so it can never clash with a bare
            # file name already used as a member name.
            name = path.relative_to(root).as_posix()
        members[name] = path
    return Enum(enum_name, list(members.items()))


synthesizers = _scan_models('synthesizers', SYN_MODELS_DIRT)
print("Loaded synthesizer models: " + str(len(synthesizers)))

encoders = _scan_models('encoders', ENC_MODELS_DIRT)
print("Loaded encoders models: " + str(len(encoders)))
|
||
|
||
class Model(str, Enum):
    # Selectable model types for training; currently only the default one.
    # Mixing in ``str`` makes each member compare equal to its string value.
    # NOTE: deliberately no docstring — pydantic v1 uses an Enum's docstring
    # as its JSON-schema description.
    DEFAULT = "default"
|
||
|
||
class Input(BaseModel):
    # Form fields for launching a synthesizer training run. The aliases,
    # titles and descriptions are the (Chinese) labels presumably rendered by
    # the UI, so they are kept verbatim. No class docstring on purpose:
    # pydantic would publish it as the schema description.

    # Model type (only Model.DEFAULT exists).
    model: Model = Field(default=Model.DEFAULT, title="模型类型")

    # Required: root folder of the preprocessed training data ("input dir").
    input_root: str = Field(
        default=...,
        alias="输入目录",
        description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer",
    )

    # New model name / run ID. A new ID retrains from scratch; leaving it
    # empty continues training the synthesizer chosen below.
    run_id: str = Field(
        default="",
        alias="新模型名/运行ID",
        description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )

    # Required: an existing synthesizer checkpoint (an Enum member built
    # from the synthesizer model folder).
    synthesizer: synthesizers = Field(
        default=...,
        alias="已有合成模型",
        description="选择语音合成模型文件.",
    )

    # Whether to train on the GPU ("yes" means use the GPU).
    gpu: bool = Field(
        default=True,
        alias="GPU训练",
        description="选择“是”,则使用GPU训练",
    )

    # Whether to print more detail ("yes" means verbose output).
    verbose: bool = Field(
        default=True,
        alias="打印详情",
        description="选择“是”,输出更多详情",
    )

    # Required: an existing speaker-encoder checkpoint.
    encoder: encoders = Field(
        default=...,
        alias="语音编码模型",
        description="选择语音编码模型文件.",
    )

    # Update the model every n steps.
    save_every: int = Field(
        default=1000,
        alias="更新间隔",
        description="每隔n步则更新一次模型",
    )

    # Save (back up) the model every n steps.
    backup_every: int = Field(
        default=10000,
        alias="保存间隔",
        description="每隔n步则保存一次模型",
    )

    # Print training statistics every n steps.
    log_every: int = Field(
        default=500,
        alias="打印间隔",
        description="每隔n步则打印一次训练统计",
    )
|
||
|
||
class AudioEntity(BaseModel):
    # An audio payload together with an associated mel value. No class
    # docstring on purpose: pydantic would publish it as the schema
    # description.

    # Raw bytes of the audio (presumably an encoded audio file — confirm).
    content: bytes
    # Mel data; deliberately untyped, so any value is accepted.
    mel: Any
|
||
|
||
class Output(BaseModel):
    # Result of a training request: a bare integer status code, modelled as
    # a pydantic (v1) custom root type. No class docstring on purpose:
    # pydantic would publish it as the schema description.
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.

        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        status_code = self.__root__
        streamlit_app.subheader(f"Training started with code: {status_code}")
|
||
|
||
def train(input: Input) -> Output:
    """Train(训练)"""
    # The docstring above is kept verbatim: UI frameworks (e.g. opyrator)
    # may display a function's docstring, so the details live in comments.
    #
    # Starts synthesizer training from the submitted form and returns status
    # code 0 once synt_train() returns.
    # - A non-blank run_id forces a fresh run (force_restart=True) under that
    #   ID.
    # - Otherwise the run ID is the selected checkpoint's file name without
    #   its extension, and synt_train() is called with force_restart=False
    #   (presumably resuming that checkpoint).
    # The caller's `input` object is not modified.
    #
    # NOTE(review): the gpu, verbose and encoder fields are not forwarded to
    # synt_train() — confirm whether that is intended.
    # NOTE(review): checkpoints found in sub-folders are passed only by their
    # file stem together with the top-level models folder; confirm that
    # synt_train() can locate them.
    print(">>> Start training ...")
    # Treat a whitespace-only ID as empty instead of creating a blank-named run.
    run_id = input.run_id.strip()
    force_restart = bool(run_id)
    if not force_restart:
        # Path.stem keeps inner dots ("my.model.pt" -> "my.model"), whereas
        # splitting on the first "." would truncate the name to "my".
        run_id = Path(input.synthesizer.value).stem

    synt_train(
        run_id,
        input.input_root,
        SYN_MODELS_DIRT,  # the same folder the synthesizer list is scanned from
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams,
    )

    return Output(__root__=0)