import os
from enum import Enum
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field

from models.synthesizer.hparams import hparams
from models.synthesizer.train import train as synt_train

# Constants
# Directories (relative to the working directory) scanned for checkpoint
# files when this module is imported.
SYN_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}synthesizer"
ENC_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}encoder"
# EXT_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg_extractor"
# CONV_MODELS_DIRT = f"data{os.sep}ckpt{os.sep}ppg2mel"


def _load_model_enum(enum_name: str, models_dir: str, label: str) -> type:
    """Build an Enum with one member per ``*.pt`` file found under *models_dir*.

    The search is recursive (``**/*.pt``). Member names are the bare file
    names and member values are the corresponding ``Path`` objects. Files are
    sorted so the resulting choices have a stable order. The number of models
    found is printed.

    Args:
        enum_name: Name given to the generated Enum class.
        models_dir: Directory to scan.
        label: Word used in the "Loaded ... models" log line.

    Returns:
        The generated Enum class.

    Raises:
        FileNotFoundError: If *models_dir* is not an existing directory.
    """
    if not os.path.isdir(models_dir):
        raise FileNotFoundError(f"Model folder {models_dir} doesn't exist.")
    # NOTE(review): two checkpoints with the same file name in different
    # sub-folders produce duplicate member names, and Enum creation raises
    # TypeError. Confirm file names are unique.
    members = [(file.name, file) for file in sorted(Path(models_dir).glob("**/*.pt"))]
    model_enum = Enum(enum_name, members)
    print(f"Loaded {label} models: {len(model_enum)}")
    return model_enum


# Pre-Load models. The generated Enums are used below as pydantic field types,
# so each checkpoint file on disk becomes one selectable choice. The
# synthesizer folder is checked first, then the encoder folder.
synthesizers = _load_model_enum('synthesizers', SYN_MODELS_DIRT, "synthesizer")
encoders = _load_model_enum('encoders', ENC_MODELS_DIRT, "encoders")


# Selectable model type; currently a single option.
# (No class docstring: pydantic would publish it as the schema description.)
class Model(str, Enum):
    DEFAULT = "default"


# Training form parameters. Titles, aliases and descriptions are user-facing
# labels rendered in the UI, so they are kept verbatim. Deliberately no class
# docstring, because pydantic uses it as the schema description.
class Input(BaseModel):
    # Model type.
    model: Model = Field(
        Model.DEFAULT, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    # Root directory of the preprocessed training data.
    input_root: str = Field(
        ..., alias="输入目录", description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
    )
    # Non-empty: start a fresh run under this name. Empty: continue training
    # the synthesizer selected below.
    run_id: str = Field(
        "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )
    # Existing synthesizer checkpoint (one of the files found at import time).
    synthesizer: synthesizers = Field(
        ..., alias="已有合成模型",
        description="选择语音合成模型文件."
    )
    # NOTE(review): gpu, verbose and encoder are collected from the form but
    # are not forwarded by train() below. Confirm whether that is intended.
    # Train on GPU.
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    # Print more details.
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # Speaker-encoder checkpoint (one of the files found at import time).
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    # Step intervals forwarded to synt_train: model update, backup and
    # statistics-logging frequency respectively.
    save_every: int = Field(
        1000, alias="更新间隔", description="每隔n步则更新一次模型",
    )
    backup_every: int = Field(
        10000, alias="保存间隔", description="每隔n步则保存一次模型",
    )
    log_every: int = Field(
        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
    )


class AudioEntity(BaseModel):
    content: bytes
    mel: Any


# Custom-root (pydantic v1 ``__root__``) model wrapping the integer status
# code returned by train().
class Output(BaseModel):
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.

        If this method is implemented, it will be used instead of the default
        Output UI renderer.
        """
        streamlit_app.subheader(f"Training started with code: {self.__root__}")


def train(input: Input) -> Output:
    """Train(训练)"""
    # The docstring above is kept verbatim because the UI framework presumably
    # displays it. Documentation lives in these comments instead.
    #
    # Run synthesizer training with the form parameters and return status 0.
    # A non-empty run_id starts a new run (force_restart=True). Otherwise the
    # run id is derived from the selected checkpoint's file name, up to the
    # first '.' (e.g. "mandarin.pt" -> "mandarin"), and training continues.
    # The caller's ``input`` is not modified.
    print(">>> Start training ...")
    force_restart = bool(input.run_id)
    if force_restart:
        run_id = input.run_id
    else:
        run_id = Path(input.synthesizer.value).name.split('.')[0]
    synt_train(
        run_id,
        input.input_root,
        SYN_MODELS_DIRT,
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams,
    )
    return Output(__root__=0)