From 96993a5c610f024ca1dccccc8e188319c75e1abb Mon Sep 17 00:00:00 2001
From: babysor00
Date: Tue, 3 May 2022 10:24:39 +0800
Subject: [PATCH] Add training mode

---
 mkgui/base/ui/streamlit_ui.py |  26 +++---
 mkgui/train.py                | 148 ++++++++++++++++++++++++++++++++++
 ppg2mel/train.py              |  11 +--
 ppg2mel/train/solver.py       |   3 +-
 utils/util.py                 |   6 ++
 vocoder/hifigan/env.py        |   7 --
 vocoder/hifigan/inference.py  |   2 +-
 vocoder/hifigan/train.py      |   1 -
 vocoder_train.py              |   2 +-
 9 files changed, 178 insertions(+), 28 deletions(-)
 create mode 100644 mkgui/train.py

diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py
index 08232f7..2e5159d 100644
--- a/mkgui/base/ui/streamlit_ui.py
+++ b/mkgui/base/ui/streamlit_ui.py
@@ -51,9 +51,13 @@ def launch_ui(port: int = 8501) -> None:
         python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
         if system() == "Windows":
             python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
+        subprocess.run(
+            f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
+            shell=True,
+        )
 
         subprocess.run(
-            f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
+            f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
             shell=True,
         )
@@ -122,10 +126,11 @@ class InputUI:
                 property["title"] = name_to_title(property_key)
 
             try:
-                self._store_value(
-                    property_key,
-                    self._render_property(streamlit_app_root, property_key, property),
-                )
+                if "input_data" in self._session_state:
+                    self._store_value(
+                        property_key,
+                        self._render_property(streamlit_app_root, property_key, property),
+                    )
             except Exception as e:
                 print("Exception!", e)
                 pass
@@ -807,6 +812,9 @@ def getOpyrator(mode: str) -> Opyrator:
     if mode == None or mode.startswith('预处理'):
         from mkgui.preprocess import preprocess
         return Opyrator(preprocess)
+    if mode == None or mode.startswith('模型训练'):
+        from mkgui.train import train
+        return Opyrator(train)
     from mkgui.app import synthesize
     return Opyrator(synthesize)
@@ -815,11 +823,13 @@ def render_streamlit_ui() -> None:
     # init
     session_state = st.session_state
     session_state.input_data = {}
+    # Add custom css settings
+    st.markdown(f"", unsafe_allow_html=True)
 
     with st.spinner("Loading MockingBird GUI. Please wait..."):
         session_state.mode = st.sidebar.selectbox(
             '模式选择',
-            ( "AI拟音", "VC拟音", "预处理")
+            ( "AI拟音", "VC拟音", "预处理", "模型训练")
         )
     if "mode" in session_state:
         mode = session_state.mode
@@ -872,6 +882,4 @@ def render_streamlit_ui() -> None:
         # placeholder
         st.caption("请使用左侧控制板进行输入并运行获得结果")
 
-    # Add custom css settings
-    st.markdown(f"", unsafe_allow_html=True)
-
+
diff --git a/mkgui/train.py b/mkgui/train.py
new file mode 100644
index 0000000..02287d7
--- /dev/null
+++ b/mkgui/train.py
@@ -0,0 +1,148 @@
+from pydantic import BaseModel, Field
+import os
+from pathlib import Path
+from enum import Enum
+from typing import Any
+import numpy as np
+from utils.load_yaml import HpsYaml
+from utils.util import AttrDict
+import torch
+
+# TODO: separator for *unix systems
+# Constants
+EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
+CONV_MODELS_DIRT = "ppg2mel\\saved_models"
+ENC_MODELS_DIRT = "encoder\\saved_models"
+
+
+if os.path.isdir(EXT_MODELS_DIRT):
+    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded extractor models: " + str(len(extractors)))
+if os.path.isdir(CONV_MODELS_DIRT):
+    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
+    print("Loaded convertor models: " + str(len(convertors)))
+if os.path.isdir(ENC_MODELS_DIRT):
+    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded encoder models: " + str(len(encoders)))
+
+class Model(str, Enum):
+    VC_PPG2MEL = "ppg2mel"
+
+class Dataset(str, Enum):
+    AIDATATANG_200ZH = "aidatatang_200zh"
+    AIDATATANG_200ZH_S = "aidatatang_200zh_s"
+
+class Input(BaseModel):
+    # def render_input_ui(st, input) -> Dict:
+    #     input["selected_dataset"] = st.selectbox(
+    #         '选择数据集',
+    #         ("aidatatang_200zh", "aidatatang_200zh_s")
+    #     )
+    #     return input
+    model: Model = Field(
+        Model.VC_PPG2MEL, title="模型类型",
+    )
+    # datasets_root: str = Field(
+    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
+    #     format=True,
+    #     example="..\\trainning_data\\"
+    # )
+    output_root: str = Field(
+        ..., alias="输出目录(可选)", description="建议不填,保持默认",
+        format=True,
+        example=""
+    )
+    continue_mode: bool = Field(
+        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+    )
+    gpu: bool = Field(
+        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
+    )
+    verbose: bool = Field(
+        True, alias="打印详情", description="选择“是”,输出更多详情",
+    )
+    # TODO: Move to hidden fields by default
+    convertor: convertors = Field(
+        ..., alias="转换模型",
+        description="选择语音转换模型文件."
+    )
+    extractor: extractors = Field(
+        ..., alias="特征提取模型",
+        description="选择PPG特征提取模型文件."
+    )
+    encoder: encoders = Field(
+        ..., alias="语音编码模型",
+        description="选择语音编码模型文件."
+    )
+    njobs: int = Field(
+        8, alias="进程数", description="适用于ppg2mel",
+    )
+    seed: int = Field(
+        default=0, alias="初始随机数", description="适用于ppg2mel",
+    )
+    model_name: str = Field(
+        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
+        example="test"
+    )
+    model_config: str = Field(
+        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
+        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    )
+
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
+class Output(BaseModel):
+    __root__: tuple[str, int]
+
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implemented, it will be used instead of the default Output UI renderer.
+ """ + sr, count = self.__root__ + streamlit_app.subheader(f"Dataset {sr} done processed total of {count}") + +def train(input: Input) -> Output: + """Train(训练)""" + + print(">>> OneShot VC training ...") + params = AttrDict() + params.update({ + "gpu": input.gpu, + "cpu": not input.gpu, + "njobs": input.njobs, + "seed": input.seed, + "verbose": input.verbose, + "load": input.convertor.value, + "warm_start": False, + }) + if input.continue_mode: + # trace old model and config + p = Path(input.convertor.value) + params.name = p.parent.name + # search a config file + model_config_fpaths = list(p.parent.rglob("*.yaml")) + if len(model_config_fpaths) == 0: + raise "No model yaml config found for convertor" + config = HpsYaml(model_config_fpaths[0]) + params.ckpdir = p.parent.parent + params.config = model_config_fpaths[0] + params.logdir = os.path.join(p.parent, "log") + else: + # Make the config dict dot visitable + config = HpsYaml(input.config) + np.random.seed(input.seed) + torch.manual_seed(input.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(input.seed) + mode = "train" + from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + solver = Solver(config, params, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + + # TODO: pass useful return code + return Output(__root__=(input.dataset, 0)) \ No newline at end of file diff --git a/ppg2mel/train.py b/ppg2mel/train.py index fed7501..d3ef729 100644 --- a/ppg2mel/train.py +++ b/ppg2mel/train.py @@ -31,15 +31,10 @@ def main(): parser.add_argument('--njobs', default=8, type=int, help='Number of threads for dataloader/decoding.', required=False) parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') - parser.add_argument('--no-pin', action='store_true', - help='Disable pin-memory for dataloader') - parser.add_argument('--test', action='store_true', help='Test the model.') + # parser.add_argument('--no-pin', action='store_true', + # help='Disable pin-memory for dataloader') parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') - parser.add_argument('--finetune', action='store_true', help='Finetune model') - parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') - parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') - parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') - + ### paras = parser.parse_args() diff --git a/ppg2mel/train/solver.py b/ppg2mel/train/solver.py index 264a91c..9ca71cb 100644 --- a/ppg2mel/train/solver.py +++ b/ppg2mel/train/solver.py @@ -93,6 +93,7 @@ class BaseSolver(): def load_ckpt(self): ''' Load ckpt if --load option is specified ''' + print(self.paras) if self.paras.load is not None: if self.paras.warm_start: self.verbose(f"Warm starting model from checkpoint {self.paras.load}.") @@ -100,7 +101,7 @@ class BaseSolver(): self.paras.load, map_location=self.device if self.mode == 'train' else 'cpu') model_dict = ckpt['model'] - if len(self.config.model.ignore_layers) > 0: + if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0: model_dict = {k:v for k, v in model_dict.items() if k not in self.config.model.ignore_layers} dummy_dict = self.model.state_dict() diff --git a/utils/util.py b/utils/util.py index 5227538..34bcffd 100644 --- a/utils/util.py +++ b/utils/util.py @@ -42,3 +42,9 @@ def human_format(num): # add more suffixes if you need them return 
+
+# provide easy access of attribute from dict, such as abc.key
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
diff --git a/vocoder/hifigan/env.py b/vocoder/hifigan/env.py
index 2bdbc95..8f0d306 100644
--- a/vocoder/hifigan/env.py
+++ b/vocoder/hifigan/env.py
@@ -1,13 +1,6 @@
 import os
 import shutil

-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
 def build_env(config, config_name, path):
     t_path = os.path.join(path, config_name)
     if config != t_path:
diff --git a/vocoder/hifigan/inference.py b/vocoder/hifigan/inference.py
index 423cbc6..8caf348 100644
--- a/vocoder/hifigan/inference.py
+++ b/vocoder/hifigan/inference.py
@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import json
 import torch
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
 from vocoder.hifigan.models import Generator

 generator = None   # type: Generator
diff --git a/vocoder/hifigan/train.py b/vocoder/hifigan/train.py
index 987bcca..8760274 100644
--- a/vocoder/hifigan/train.py
+++ b/vocoder/hifigan/train.py
@@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
 import torch.multiprocessing as mp
 from torch.distributed import init_process_group
 from torch.nn.parallel import DistributedDataParallel
-from vocoder.hifigan.env import AttrDict, build_env
 from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
 from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
     discriminator_loss
diff --git a/vocoder_train.py b/vocoder_train.py
index d3ad0f5..1ef0e30 100644
--- a/vocoder_train.py
+++ b/vocoder_train.py
@@ -1,7 +1,7 @@
 from utils.argutils import print_args
 from vocoder.wavernn.train import train
 from vocoder.hifigan.train import train as train_hifigan
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
 from pathlib import Path
 import argparse
 import json