diff --git a/.gitignore b/.gitignore
index 7df88c7..5302980 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,7 +13,6 @@
 *.bbl
 *.bcf
 *.toc
-*.wav
 *.sh
 */saved_models
 !vocoder/saved_models/pretrained/**
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 23e5203..85cf175 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -61,5 +61,13 @@
             "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth",
             "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\" ]
         },
+        {
+            "name": "GUI",
+            "type": "python",
+            "request": "launch",
+            "program": "mkgui\\base\\_cli.py",
+            "console": "integratedTerminal",
+            "args": []
+        },
     ]
}
diff --git a/README-CN.md b/README-CN.md
index d0bfffa..038deb5 100644
--- a/README-CN.md
+++ b/README-CN.md
@@ -18,6 +18,15 @@
 🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用
 
+### 进行中的工作
+* GUI/客户端大升级与合并
+  - [X] 初始化框架 `./mkgui`(基于 streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
+  - [X] 增加 Voice Cloning and Conversion 的演示页面
+  - [X] 增加 Voice Conversion 的预处理(preprocessing)和训练(training)页面
+  - [ ] 增加其他的预处理(preprocessing)和训练(training)页面
+* 模型后端基于 ESPnet2 升级
+
+
 ## 开始
 ### 1. 安装要求
 > 按照原始存储库测试您是否已准备好所有环境。
@@ -82,15 +91,10 @@
 ### 3. 启动程序或工具箱
 您可以尝试使用以下命令:
 
-### 3.1 启动Web程序:
+### 3.1 启动Web程序(v2):
 `python web.py`
 运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
-![123](https://user-images.githubusercontent.com/12797292/135494044-ae59181c-fe3a-406f-9c7d-d21d12fdb4cb.png)
-> 注:目前界面比较buggy,
-> * 第一次点击`录制`要等待几秒浏览器正常启动录音,否则会有重音
-> * 录制结束不要再点`录制`而是`停止`
 > * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
-> * 默认使用第一个找到的模型,有动手能力的可以看代码修改 `web\__init__.py`。
 
 ### 3.2 启动工具箱:
 `python demo_toolbox.py -d <datasets_root>`
diff --git a/README.md b/README.md
index 9bfcd78..443dcf0 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,14 @@
 
 ### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
 
+### Ongoing Works (Help Needed)
+* Major upgrade on GUI/client, unifying web and toolbox
+  - [X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
+  - [X] Add demo part of Voice Cloning and Conversion
+  - [X] Add preprocessing and training for Voice Conversion
+  - [ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
+* Major upgrade on model backend based on ESPnet2 (not yet started)
+
 ## Quick Start
 ### 1. Install Requirements
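For anyone reviewing this patch who wants to smoke-test the new GUI, a minimal sketch (not part of the patch; assumes the repository root is on `PYTHONPATH`; `launch_ui` is defined in `mkgui/base/ui/streamlit_ui.py` below, and the port is arbitrary):

```python
# Hypothetical smoke test for the new mkgui package (not in this patch).
# launch_ui writes a small Streamlit runner script to a temp file and
# spawns `streamlit run` on it, so this opens the GUI in a browser.
from mkgui.base.ui.streamlit_ui import launch_ui

if __name__ == "__main__":
    launch_ui(port=8501)
```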
diff --git a/gui/___init__.py b/mkgui/__init__.py
similarity index 100%
rename from gui/___init__.py
rename to mkgui/__init__.py
diff --git a/gui/app.py b/mkgui/app.py
similarity index 68%
rename from gui/app.py
rename to mkgui/app.py
index d753126..9487284 100644
--- a/gui/app.py
+++ b/mkgui/app.py
@@ -8,9 +8,11 @@ import librosa
 from scipy.io.wavfile import write
 import re
 import numpy as np
-from opyrator.components.types import FileContent
+from mkgui.base.components.types import FileContent
 from vocoder.hifigan import inference as gan_vocoder
 from synthesizer.inference import Synthesizer
+from typing import Any
+import matplotlib.pyplot as plt
 
 # Constants
 AUDIO_SAMPLES_DIR = 'samples\\'
@@ -27,20 +29,31 @@ if os.path.isdir(AUDIO_SAMPLES_DIR):
 if os.path.isdir(SYN_MODELS_DIRT):
     synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
     print("Loaded synthesizer models: " + str(len(synthesizers)))
+else:
+    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
+
 if os.path.isdir(ENC_MODELS_DIRT):
     encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
     print("Loaded encoder models: " + str(len(encoders)))
+else:
+    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
+
 if os.path.isdir(VOC_MODELS_DIRT):
     vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
     print("Loaded vocoder models: " + str(len(vocoders)))
+else:
+    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
 
 class Input(BaseModel):
+    message: str = Field(
+        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
+    )
     local_audio_file: audio_input_selection = Field(
         ..., alias="输入语音(本地wav)",
         description="选择本地语音文件."
     )
-    upload_audio_file: FileContent = Field(..., alias="或上传语音",
+    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
         description="拖拽或点击上传.", mime_type="audio/wav")
     encoder: encoders = Field(
         ..., alias="编码模型",
@@ -48,37 +61,48 @@ class Input(BaseModel):
     )
     synthesizer: synthesizers = Field(
         ..., alias="合成模型",
-        description="选择语音编码模型文件."
+        description="选择语音合成模型文件."
     )
     vocoder: vocoders = Field(
-        ..., alias="语音编码模型",
-        description="选择语音编码模型文件(目前只支持HifiGan类型)."
-    )
-    message: str = Field(
-        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="输出文本内容"
+        ..., alias="语音解码模型",
+        description="选择语音解码模型文件(目前只支持HifiGan类型)."
     )
 
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
 class Output(BaseModel):
-    result_file: FileContent = Field(
-        ...,
-        mime_type="audio/wav",
-        description="输出音频",
-    )
-    source_file: FileContent = Field(
-        ...,
-        mime_type="audio/wav",
-        description="原始音频.",
-    )
+    __root__: tuple[AudioEntity, AudioEntity]
 
-def mocking_bird(input: Input) -> Output:
-    """欢迎使用MockingBird Web 2"""
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implemented, it will be used instead of the default Output UI renderer.
+        """
+        src, result = self.__root__
+
+        streamlit_app.subheader("Synthesized Audio")
+        streamlit_app.audio(result.content, format="audio/wav")
+
+        fig, ax = plt.subplots()
+        ax.imshow(src.mel, aspect="equal", interpolation="none")
+        ax.set_title("mel spectrogram (Source Audio)")
+        streamlit_app.pyplot(fig)
+        fig, ax = plt.subplots()
+        ax.imshow(result.mel, aspect="equal", interpolation="none")
+        ax.set_title("mel spectrogram (Result Audio)")
+        streamlit_app.pyplot(fig)
+
+
+def synthesize(input: Input) -> Output:
+    """synthesize(合成)"""
     # load models
     encoder.load_model(Path(input.encoder.value))
     current_synt = Synthesizer(Path(input.synthesizer.value))
     gan_vocoder.load_model(Path(input.vocoder.value))
 
     # load file
-    if input.upload_audio_file != NULL:
+    if input.upload_audio_file is not None:
         with open(TEMP_SOURCE_AUDIO, "w+b") as f:
             f.write(input.upload_audio_file.as_bytes())
             f.seek(0)
@@ -87,6 +111,8 @@ def mocking_bird(input: Input) -> Output:
         wav, sample_rate = librosa.load(input.local_audio_file.value)
         write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
 
+    source_spec = Synthesizer.make_spectrogram(wav)
+
     # preprocess
     encoder_wav = encoder.preprocess_wav(wav, sample_rate)
     embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
@@ -114,4 +140,4 @@ def mocking_bird(input: Input) -> Output:
         source_file = f.read()
     with open(TEMP_RESULT_AUDIO, "rb") as f:
         result_file = f.read()
-    return Output(source_file=source_file, result_file=result_file)
\ No newline at end of file
+    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
\ No newline at end of file
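Every page added under `mkgui/` follows the contract that `mkgui/base/core.py` (later in this patch) enforces: a plain function whose single `input` parameter and return value are pydantic models, with the docstring used as the action-button label. A minimal sketch of that contract (names hypothetical, not part of the patch):

```python
# Hypothetical minimal mkgui page. Opyrator introspects the `input`
# annotation to build the form and the return annotation to render results.
from pydantic import BaseModel, Field

class EchoInput(BaseModel):
    message: str = Field(..., example="hello", alias="文本内容")

class EchoOutput(BaseModel):
    result: str

def echo(input: EchoInput) -> EchoOutput:
    """Echo(演示)"""  # shown as the button label in the UI
    return EchoOutput(result=input.message)
```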
+ """ + src, result = self.__root__ + + streamlit_app.subheader("Synthesized Audio") + streamlit_app.audio(result.content, format="audio/wav") + + fig, ax = plt.subplots() + ax.imshow(src.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Source Audio)") + streamlit_app.pyplot(fig) + fig, ax = plt.subplots() + ax.imshow(result.mel, aspect="equal", interpolation="none") + ax.set_title("mel spectrogram(Result Audio)") + streamlit_app.pyplot(fig) + + +def synthesize(input: Input) -> Output: + """synthesize(合成)""" # load models encoder.load_model(Path(input.encoder.value)) current_synt = Synthesizer(Path(input.synthesizer.value)) gan_vocoder.load_model(Path(input.vocoder.value)) # load file - if input.upload_audio_file != NULL: + if input.upload_audio_file != None: with open(TEMP_SOURCE_AUDIO, "w+b") as f: f.write(input.upload_audio_file.as_bytes()) f.seek(0) @@ -87,6 +111,8 @@ def mocking_bird(input: Input) -> Output: wav, sample_rate = librosa.load(input.local_audio_file.value) write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav + source_spec = Synthesizer.make_spectrogram(wav) + # preprocess encoder_wav = encoder.preprocess_wav(wav, sample_rate) embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True) @@ -114,4 +140,4 @@ def mocking_bird(input: Input) -> Output: source_file = f.read() with open(TEMP_RESULT_AUDIO, "rb") as f: result_file = f.read() - return Output(source_file=source_file, result_file=result_file) \ No newline at end of file + return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec))) \ No newline at end of file diff --git a/mkgui/app_vc.py b/mkgui/app_vc.py new file mode 100644 index 0000000..3e4c793 --- /dev/null +++ b/mkgui/app_vc.py @@ -0,0 +1,167 @@ +from asyncio.windows_events import NULL +from synthesizer.inference import Synthesizer +from pydantic import BaseModel, Field +from encoder import inference as speacker_encoder +import torch +import os +from pathlib import Path +from enum import Enum +import ppg_extractor as Extractor +import ppg2mel as Convertor +import librosa +from scipy.io.wavfile import write +import re +import numpy as np +from mkgui.base.components.types import FileContent +from vocoder.hifigan import inference as gan_vocoder +from typing import Any +import matplotlib.pyplot as plt + + +# Constants +AUDIO_SAMPLES_DIR = 'samples\\' +EXT_MODELS_DIRT = "ppg_extractor\\saved_models" +CONV_MODELS_DIRT = "ppg2mel\\saved_models" +VOC_MODELS_DIRT = "vocoder\\saved_models" +TEMP_SOURCE_AUDIO = "wavs/temp_source.wav" +TEMP_TARGET_AUDIO = "wavs/temp_target.wav" +TEMP_RESULT_AUDIO = "wavs/temp_result.wav" + +# Load local sample audio as options TODO: load dataset +if os.path.isdir(AUDIO_SAMPLES_DIR): + audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav"))) +# Pre-Load models +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(CONV_MODELS_DIRT): + convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth"))) + print("Loaded convertor models: " + str(len(convertors))) +else: + raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(VOC_MODELS_DIRT): + vocoders = 
+class Input(BaseModel):
+    local_audio_file: audio_input_selection = Field(
+        ..., alias="输入语音(本地wav)",
+        description="选择本地语音文件."
+    )
+    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
+        description="拖拽或点击上传.", mime_type="audio/wav")
+    local_audio_file_target: audio_input_selection = Field(
+        ..., alias="目标语音(本地wav)",
+        description="选择本地语音文件."
+    )
+    upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
+        description="拖拽或点击上传.", mime_type="audio/wav")
+    extractor: extractors = Field(
+        ..., alias="特征提取模型",
+        description="选择PPG特征提取模型文件."
+    )
+    convertor: convertors = Field(
+        ..., alias="转换模型",
+        description="选择语音转换模型文件."
+    )
+    vocoder: vocoders = Field(
+        ..., alias="语音解码模型",
+        description="选择语音解码模型文件(目前只支持HifiGan类型)."
+    )
+
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
+class Output(BaseModel):
+    __root__: tuple[AudioEntity, AudioEntity, AudioEntity]
+
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implemented, it will be used instead of the default Output UI renderer.
+        """
+        src, target, result = self.__root__
+
+        streamlit_app.subheader("Converted Audio")
+        streamlit_app.audio(result.content, format="audio/wav")
+
+        fig, ax = plt.subplots()
+        ax.imshow(src.mel, aspect="equal", interpolation="none")
+        ax.set_title("mel spectrogram (Source Audio)")
+        streamlit_app.pyplot(fig)
+        fig, ax = plt.subplots()
+        ax.imshow(target.mel, aspect="equal", interpolation="none")
+        ax.set_title("mel spectrogram (Target Audio)")
+        streamlit_app.pyplot(fig)
+        fig, ax = plt.subplots()
+        ax.imshow(result.mel, aspect="equal", interpolation="none")
+        ax.set_title("mel spectrogram (Result Audio)")
+        streamlit_app.pyplot(fig)
+
+def convert(input: Input) -> Output:
+    """convert(转换)"""
+    # load models
+    extractor = Extractor.load_model(Path(input.extractor.value))
+    convertor = Convertor.load_model(Path(input.convertor.value))
+    # current_synt = Synthesizer(Path(input.synthesizer.value))
+    gan_vocoder.load_model(Path(input.vocoder.value))
+
+    # load file
+    if input.upload_audio_file is not None:
+        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
+            f.write(input.upload_audio_file.as_bytes())
+            f.seek(0)
+        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
+    else:
+        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
+        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav)  # Make sure we get the correct wav
+
+    if input.upload_audio_file_target is not None:
+        with open(TEMP_TARGET_AUDIO, "w+b") as f:
+            f.write(input.upload_audio_file_target.as_bytes())
+            f.seek(0)
+        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
+    else:
+        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
+        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav)  # Make sure we get the correct wav
+
+    ppg = extractor.extract_from_wav(src_wav)
+    # Import necessary dependency of Voice Conversion
+    from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
+    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
+    speaker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
+    embed = speaker_encoder.embed_utterance(ref_wav)
+    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
+    min_len = min(ppg.shape[1], len(lf0_uv))
+ ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + _, mel_pred, att_ws = convertor.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=torch.from_numpy(embed).unsqueeze(0).to(device), + ) + mel_pred= mel_pred.transpose(0, 1) + breaks = [mel_pred.shape[1]] + mel_pred= mel_pred.detach().cpu().numpy() + + # synthesize and vocode + wav, sample_rate = gan_vocoder.infer_waveform(mel_pred) + + # write and output + write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav + with open(TEMP_SOURCE_AUDIO, "rb") as f: + source_file = f.read() + with open(TEMP_TARGET_AUDIO, "rb") as f: + target_file = f.read() + with open(TEMP_RESULT_AUDIO, "rb") as f: + result_file = f.read() + + + return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav)))) \ No newline at end of file diff --git a/mkgui/base/__init__.py b/mkgui/base/__init__.py new file mode 100644 index 0000000..6905fa0 --- /dev/null +++ b/mkgui/base/__init__.py @@ -0,0 +1,2 @@ + +from .core import Opyrator diff --git a/mkgui/base/api/__init__.py b/mkgui/base/api/__init__.py new file mode 100644 index 0000000..a0c4102 --- /dev/null +++ b/mkgui/base/api/__init__.py @@ -0,0 +1 @@ +from .fastapi_app import create_api diff --git a/mkgui/base/api/fastapi_utils.py b/mkgui/base/api/fastapi_utils.py new file mode 100644 index 0000000..adf582a --- /dev/null +++ b/mkgui/base/api/fastapi_utils.py @@ -0,0 +1,102 @@ +"""Collection of utilities for FastAPI apps.""" + +import inspect +from typing import Any, Type + +from fastapi import FastAPI, Form +from pydantic import BaseModel + + +def as_form(cls: Type[BaseModel]) -> Any: + """Adds an as_form class method to decorated models. + + The as_form class method can be used with FastAPI endpoints + """ + new_params = [ + inspect.Parameter( + field.alias, + inspect.Parameter.POSITIONAL_ONLY, + default=(Form(field.default) if not field.required else Form(...)), + ) + for field in cls.__fields__.values() + ] + + async def _as_form(**data): # type: ignore + return cls(**data) + + sig = inspect.signature(_as_form) + sig = sig.replace(parameters=new_params) + _as_form.__signature__ = sig # type: ignore + setattr(cls, "as_form", _as_form) + return cls + + +def patch_fastapi(app: FastAPI) -> None: + """Patch function to allow relative url resolution. + + This patch is required to make fastapi fully functional with a relative url path. + This code snippet can be copy-pasted to any Fastapi application. 
+ """ + from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html + from starlette.requests import Request + from starlette.responses import HTMLResponse + + async def redoc_ui_html(req: Request) -> HTMLResponse: + assert app.openapi_url is not None + redoc_ui = get_redoc_html( + openapi_url="./" + app.openapi_url.lstrip("/"), + title=app.title + " - Redoc UI", + ) + + return HTMLResponse(redoc_ui.body.decode("utf-8")) + + async def swagger_ui_html(req: Request) -> HTMLResponse: + assert app.openapi_url is not None + swagger_ui = get_swagger_ui_html( + openapi_url="./" + app.openapi_url.lstrip("/"), + title=app.title + " - Swagger UI", + oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url, + ) + + # insert request interceptor to have all request run on relativ path + request_interceptor = ( + "requestInterceptor: (e) => {" + "\n\t\t\tvar url = window.location.origin + window.location.pathname" + '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);' + "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605 + "\n\t\t\te.contextUrl = url" + "\n\t\t\te.url = url" + "\n\t\t\treturn e;}" + ) + + return HTMLResponse( + swagger_ui.body.decode("utf-8").replace( + "dom_id: '#swagger-ui',", + "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",", + ) + ) + + # remove old docs route and add our patched route + routes_new = [] + for app_route in app.routes: + if app_route.path == "/docs": # type: ignore + continue + + if app_route.path == "/redoc": # type: ignore + continue + + routes_new.append(app_route) + + app.router.routes = routes_new + + assert app.docs_url is not None + app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False) + assert app.redoc_url is not None + app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False) + + # Make graphql realtive + from starlette import graphql + + graphql.GRAPHIQL = graphql.GRAPHIQL.replace( + "({{REQUEST_PATH}}", '("." 
+ {{REQUEST_PATH}}' + ) diff --git a/mkgui/base/components/__init__.py b/mkgui/base/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mkgui/base/components/outputs.py b/mkgui/base/components/outputs.py new file mode 100644 index 0000000..f4859c6 --- /dev/null +++ b/mkgui/base/components/outputs.py @@ -0,0 +1,43 @@ +from typing import List + +from pydantic import BaseModel + + +class ScoredLabel(BaseModel): + label: str + score: float + + +class ClassificationOutput(BaseModel): + __root__: List[ScoredLabel] + + def __iter__(self): # type: ignore + return iter(self.__root__) + + def __getitem__(self, item): # type: ignore + return self.__root__[item] + + def render_output_ui(self, streamlit) -> None: # type: ignore + import plotly.express as px + + sorted_predictions = sorted( + [prediction.dict() for prediction in self.__root__], + key=lambda k: k["score"], + ) + + num_labels = len(sorted_predictions) + if len(sorted_predictions) > 10: + num_labels = streamlit.slider( + "Maximum labels to show: ", + min_value=1, + max_value=len(sorted_predictions), + value=len(sorted_predictions), + ) + fig = px.bar( + sorted_predictions[len(sorted_predictions) - num_labels :], + x="score", + y="label", + orientation="h", + ) + streamlit.plotly_chart(fig, use_container_width=True) + # fig.show() diff --git a/mkgui/base/components/types.py b/mkgui/base/components/types.py new file mode 100644 index 0000000..125809a --- /dev/null +++ b/mkgui/base/components/types.py @@ -0,0 +1,46 @@ +import base64 +from typing import Any, Dict, overload + + +class FileContent(str): + def as_bytes(self) -> bytes: + return base64.b64decode(self, validate=True) + + def as_str(self) -> str: + return self.as_bytes().decode() + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(format="byte") + + @classmethod + def __get_validators__(cls) -> Any: # type: ignore + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> "FileContent": + if isinstance(value, FileContent): + return value + elif isinstance(value, str): + return FileContent(value) + elif isinstance(value, (bytes, bytearray, memoryview)): + return FileContent(base64.b64encode(value).decode()) + else: + raise Exception("Wrong type") + +# # 暂时无法使用,因为浏览器中没有考虑选择文件夹 +# class DirectoryContent(FileContent): +# @classmethod +# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: +# field_schema.update(format="path") + +# @classmethod +# def validate(cls, value: Any) -> "DirectoryContent": +# if isinstance(value, DirectoryContent): +# return value +# elif isinstance(value, str): +# return DirectoryContent(value) +# elif isinstance(value, (bytes, bytearray, memoryview)): +# return DirectoryContent(base64.b64encode(value).decode()) +# else: +# raise Exception("Wrong type") diff --git a/mkgui/base/core.py b/mkgui/base/core.py new file mode 100644 index 0000000..8166a33 --- /dev/null +++ b/mkgui/base/core.py @@ -0,0 +1,203 @@ +import importlib +import inspect +import re +from typing import Any, Callable, Type, Union, get_type_hints + +from pydantic import BaseModel, parse_raw_as +from pydantic.tools import parse_obj_as + + +def name_to_title(name: str) -> str: + """Converts a camelCase or snake_case name to title case.""" + # If camelCase -> convert to snake case + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + # Convert to title case + return name.replace("_", " ").strip().title() + + +def 
is_compatible_type(type: Type) -> bool:
+    """Returns `True` if the type is opyrator-compatible."""
+    try:
+        if issubclass(type, BaseModel):
+            return True
+    except Exception:
+        pass
+
+    try:
+        # valid list type
+        if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
+            return True
+    except Exception:
+        pass
+
+    return False
+
+
+def get_input_type(func: Callable) -> Type:
+    """Returns the input type of a given function (callable).
+
+    Args:
+        func: The function for which to get the input type.
+
+    Raises:
+        ValueError: If the function does not have a valid input type annotation.
+    """
+    type_hints = get_type_hints(func)
+
+    if "input" not in type_hints:
+        raise ValueError(
+            "The callable MUST have a parameter with the name `input` with typing annotation. "
+            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
+        )
+
+    input_type = type_hints["input"]
+
+    if not is_compatible_type(input_type):
+        raise ValueError(
+            "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
+        )
+
+    # TODO: return warning if more than one input parameter
+
+    return input_type
+
+
+def get_output_type(func: Callable) -> Type:
+    """Returns the output type of a given function (callable).
+
+    Args:
+        func: The function for which to get the output type.
+
+    Raises:
+        ValueError: If the function does not have a valid output type annotation.
+    """
+    type_hints = get_type_hints(func)
+    if "return" not in type_hints:
+        raise ValueError(
+            "The return type of the callable MUST be annotated with type hints. "
+            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
+        )
+
+    output_type = type_hints["return"]
+
+    if not is_compatible_type(output_type):
+        raise ValueError(
+            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
+        )
+
+    return output_type
+
+
+def get_callable(import_string: str) -> Callable:
+    """Import a callable from a string."""
+    callable_separator = ":"
+    if callable_separator not in import_string:
+        # Use dot as separator
+        callable_separator = "."
+
+    if callable_separator not in import_string:
+        raise ValueError("The callable path MUST specify the function.")
+
+    mod_name, callable_name = import_string.rsplit(callable_separator, 1)
+    mod = importlib.import_module(mod_name)
+    return getattr(mod, callable_name)
+
+
+class Opyrator:
+    def __init__(self, func: Union[Callable, str]) -> None:
+        if isinstance(func, str):
+            # Try to load the function from a string notation
+            self.function = get_callable(func)
+        else:
+            self.function = func
+
+        self._action = "Execute"
+        self._input_type = None
+        self._output_type = None
+
+        if not callable(self.function):
+            raise ValueError("The provided function parameter is not a callable.")
+
+        if inspect.isclass(self.function):
+            raise ValueError(
+                "The provided callable is an uninitialized Class. This is not allowed."
+ ) + + if inspect.isfunction(self.function): + # The provided callable is a function + self._input_type = get_input_type(self.function) + self._output_type = get_output_type(self.function) + + try: + # Get name + self._name = name_to_title(self.function.__name__) + except Exception: + pass + + try: + # Get description from function + doc_string = inspect.getdoc(self.function) + if doc_string: + self._action = doc_string + except Exception: + pass + elif hasattr(self.function, "__call__"): + # The provided callable is a function + self._input_type = get_input_type(self.function.__call__) # type: ignore + self._output_type = get_output_type(self.function.__call__) # type: ignore + + try: + # Get name + self._name = name_to_title(type(self.function).__name__) + except Exception: + pass + + try: + # Get action from + doc_string = inspect.getdoc(self.function.__call__) # type: ignore + if doc_string: + self._action = doc_string + + if ( + not self._action + or self._action == "Call" + ): + # Get docstring from class instead of __call__ function + doc_string = inspect.getdoc(self.function) + if doc_string: + self._action = doc_string + except Exception: + pass + else: + raise ValueError("Unknown callable type.") + + @property + def name(self) -> str: + return self._name + + @property + def action(self) -> str: + return self._action + + @property + def input_type(self) -> Any: + return self._input_type + + @property + def output_type(self) -> Any: + return self._output_type + + def __call__(self, input: Any, **kwargs: Any) -> Any: + + input_obj = input + + if isinstance(input, str): + # Allow json input + input_obj = parse_raw_as(self.input_type, input) + + if isinstance(input, dict): + # Allow dict input + input_obj = parse_obj_as(self.input_type, input) + + return self.function(input_obj, **kwargs) diff --git a/mkgui/base/ui/__init__.py b/mkgui/base/ui/__init__.py new file mode 100644 index 0000000..593b254 --- /dev/null +++ b/mkgui/base/ui/__init__.py @@ -0,0 +1 @@ +from .streamlit_ui import render_streamlit_ui diff --git a/mkgui/base/ui/schema_utils.py b/mkgui/base/ui/schema_utils.py new file mode 100644 index 0000000..a2be43c --- /dev/null +++ b/mkgui/base/ui/schema_utils.py @@ -0,0 +1,129 @@ +from typing import Dict + + +def resolve_reference(reference: str, references: Dict) -> Dict: + return references[reference.split("/")[-1]] + + +def get_single_reference_item(property: Dict, references: Dict) -> Dict: + # Ref can either be directly in the properties or the first element of allOf + reference = property.get("$ref") + if reference is None: + reference = property["allOf"][0]["$ref"] + return resolve_reference(reference, references) + + +def is_single_string_property(property: Dict) -> bool: + return property.get("type") == "string" + + +def is_single_datetime_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + return property.get("format") in ["date-time", "time", "date"] + + +def is_single_boolean_property(property: Dict) -> bool: + return property.get("type") == "boolean" + + +def is_single_number_property(property: Dict) -> bool: + return property.get("type") in ["integer", "number"] + + +def is_single_file_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + # TODO: binary? 
+ return property.get("format") == "byte" + + +def is_single_directory_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + return property.get("format") == "path" + +def is_multi_enum_property(property: Dict, references: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("uniqueItems") is not True: + # Only relevant if it is a set or other datastructures with unique items + return False + + try: + _ = resolve_reference(property["items"]["$ref"], references)["enum"] + return True + except Exception: + return False + + +def is_single_enum_property(property: Dict, references: Dict) -> bool: + try: + _ = get_single_reference_item(property, references)["enum"] + return True + except Exception: + return False + + +def is_single_dict_property(property: Dict) -> bool: + if property.get("type") != "object": + return False + return "additionalProperties" in property + + +def is_single_reference(property: Dict) -> bool: + if property.get("type") is not None: + return False + + return bool(property.get("$ref")) + + +def is_multi_file_property(property: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("items") is None: + return False + + try: + # TODO: binary + return property["items"]["format"] == "byte" + except Exception: + return False + + +def is_single_object(property: Dict, references: Dict) -> bool: + try: + object_reference = get_single_reference_item(property, references) + if object_reference["type"] != "object": + return False + return "properties" in object_reference + except Exception: + return False + + +def is_property_list(property: Dict) -> bool: + if property.get("type") != "array": + return False + + if property.get("items") is None: + return False + + try: + return property["items"]["type"] in ["string", "number", "integer"] + except Exception: + return False + + +def is_object_list_property(property: Dict, references: Dict) -> bool: + if property.get("type") != "array": + return False + + try: + object_reference = resolve_reference(property["items"]["$ref"], references) + if object_reference["type"] != "object": + return False + return "properties" in object_reference + except Exception: + return False diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py new file mode 100644 index 0000000..2e5159d --- /dev/null +++ b/mkgui/base/ui/streamlit_ui.py @@ -0,0 +1,885 @@ +import datetime +import inspect +import mimetypes +import sys +from os import getcwd, unlink +from platform import system +from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, List, Type +from PIL import Image + +import pandas as pd +import streamlit as st +from fastapi.encoders import jsonable_encoder +from loguru import logger +from pydantic import BaseModel, ValidationError, parse_obj_as + +from mkgui.base import Opyrator +from mkgui.base.core import name_to_title +from mkgui.base.ui import schema_utils +from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS + +STREAMLIT_RUNNER_SNIPPET = """ +from mkgui.base.ui import render_streamlit_ui +from mkgui.base import Opyrator + +import streamlit as st + +# TODO: Make it configurable +# Page config can only be setup once +st.set_page_config( + page_title="MockingBird", + page_icon="🧊", + layout="wide") + +render_streamlit_ui() +""" + +# with st.spinner("Loading MockingBird GUI. 
Please wait..."): +# opyrator = Opyrator("{opyrator_path}") + + +def launch_ui(port: int = 8501) -> None: + with NamedTemporaryFile( + suffix=".py", mode="w", encoding="utf-8", delete=False + ) as f: + f.write(STREAMLIT_RUNNER_SNIPPET) + f.seek(0) + + import subprocess + + python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"' + if system() == "Windows": + python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&" + subprocess.run( + f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""", + shell=True, + ) + + subprocess.run( + f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""", + shell=True, + ) + + f.close() + unlink(f.name) + + +def function_has_named_arg(func: Callable, parameter: str) -> bool: + try: + sig = inspect.signature(func) + for param in sig.parameters.values(): + if param.name == "input": + return True + except Exception: + return False + return False + + +def has_output_ui_renderer(data_item: BaseModel) -> bool: + return hasattr(data_item, "render_output_ui") + + +def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool: + return hasattr(input_class, "render_input_ui") + + +def is_compatible_audio(mime_type: str) -> bool: + return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"] + + +def is_compatible_image(mime_type: str) -> bool: + return mime_type in ["image/png", "image/jpeg"] + + +def is_compatible_video(mime_type: str) -> bool: + return mime_type in ["video/mp4"] + + +class InputUI: + def __init__(self, session_state, input_class: Type[BaseModel]): + self._session_state = session_state + self._input_class = input_class + + self._schema_properties = input_class.schema(by_alias=True).get( + "properties", {} + ) + self._schema_references = input_class.schema(by_alias=True).get( + "definitions", {} + ) + + def render_ui(self, streamlit_app_root) -> None: + if has_input_ui_renderer(self._input_class): + # The input model has a rendering function + # The rendering also returns the current state of input data + self._session_state.input_data = self._input_class.render_input_ui( # type: ignore + st, self._session_state.input_data + ) + return + + # print(self._schema_properties) + for property_key in self._schema_properties.keys(): + property = self._schema_properties[property_key] + + if not property.get("title"): + # Set property key as fallback title + property["title"] = name_to_title(property_key) + + try: + if "input_data" in self._session_state: + self._store_value( + property_key, + self._render_property(streamlit_app_root, property_key, property), + ) + except Exception as e: + print("Exception!", e) + pass + + def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict: + streamlit_kwargs = { + "label": property.get("title"), + "key": key, + } + + if property.get("description"): + streamlit_kwargs["help"] = property.get("description") + return streamlit_kwargs + + def _store_value(self, key: str, value: Any) -> None: + data_element = self._session_state.input_data + key_elements = key.split(".") + for i, key_element in enumerate(key_elements): + if i == len(key_elements) - 1: + # add value to this element + data_element[key_element] = value + return + if key_element not in data_element: + data_element[key_element] = {} + data_element = data_element[key_element] + + def _get_value(self, key: str) -> Any: + data_element = self._session_state.input_data + key_elements = key.split(".") 
+ for i, key_element in enumerate(key_elements): + if i == len(key_elements) - 1: + # add value to this element + if key_element not in data_element: + return None + return data_element[key_element] + if key_element not in data_element: + data_element[key_element] = {} + data_element = data_element[key_element] + return None + + def _render_single_datetime_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("format") == "time": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + return streamlit_app.time_input(**streamlit_kwargs) + elif property.get("format") == "date": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + return streamlit_app.date_input(**streamlit_kwargs) + elif property.get("format") == "date-time": + if property.get("default"): + try: + streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore + property.get("default") + ) + except Exception: + pass + with streamlit_app.container(): + streamlit_app.subheader(streamlit_kwargs.get("label")) + if streamlit_kwargs.get("description"): + streamlit_app.text(streamlit_kwargs.get("description")) + selected_date = None + selected_time = None + date_col, time_col = streamlit_app.columns(2) + with date_col: + date_kwargs = {"label": "Date", "key": key + "-date-input"} + if streamlit_kwargs.get("value"): + try: + date_kwargs["value"] = streamlit_kwargs.get( # type: ignore + "value" + ).date() + except Exception: + pass + selected_date = streamlit_app.date_input(**date_kwargs) + + with time_col: + time_kwargs = {"label": "Time", "key": key + "-time-input"} + if streamlit_kwargs.get("value"): + try: + time_kwargs["value"] = streamlit_kwargs.get( # type: ignore + "value" + ).time() + except Exception: + pass + selected_time = streamlit_app.time_input(**time_kwargs) + return datetime.datetime.combine(selected_date, selected_time) + else: + streamlit_app.warning( + "Date format is not supported: " + str(property.get("format")) + ) + + def _render_single_file_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + file_extension = None + if "mime_type" in property: + file_extension = mimetypes.guess_extension(property["mime_type"]) + + uploaded_file = streamlit_app.file_uploader( + **streamlit_kwargs, accept_multiple_files=False, type=file_extension + ) + if uploaded_file is None: + return None + + bytes = uploaded_file.getvalue() + if property.get("mime_type"): + if is_compatible_audio(property["mime_type"]): + # Show audio + streamlit_app.audio(bytes, format=property.get("mime_type")) + if is_compatible_image(property["mime_type"]): + # Show image + streamlit_app.image(bytes) + if is_compatible_video(property["mime_type"]): + # Show video + streamlit_app.video(bytes, format=property.get("mime_type")) + return bytes + + def _render_single_string_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("default"): + streamlit_kwargs["value"] = property.get("default") + elif property.get("example"): + # TODO: also use example for other property types + # Use example as value 
if it is provided + streamlit_kwargs["value"] = property.get("example") + + if property.get("maxLength") is not None: + streamlit_kwargs["max_chars"] = property.get("maxLength") + + if ( + property.get("format") + or ( + property.get("maxLength") is not None + and int(property.get("maxLength")) < 140 # type: ignore + ) + or property.get("writeOnly") + ): + # If any format is set, use single text input + # If max chars is set to less than 140, use single text input + # If write only -> password field + if property.get("writeOnly"): + streamlit_kwargs["type"] = "password" + return streamlit_app.text_input(**streamlit_kwargs) + else: + # Otherwise use multiline text area + return streamlit_app.text_area(**streamlit_kwargs) + + def _render_multi_enum_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + reference_item = schema_utils.resolve_reference( + property["items"]["$ref"], self._schema_references + ) + # TODO: how to select defaults + return streamlit_app.multiselect( + **streamlit_kwargs, options=reference_item["enum"] + ) + + def _render_single_enum_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + reference_item = schema_utils.get_single_reference_item( + property, self._schema_references + ) + + if property.get("default") is not None: + try: + streamlit_kwargs["index"] = reference_item["enum"].index( + property.get("default") + ) + except Exception: + # Use default selection + pass + + return streamlit_app.selectbox( + **streamlit_kwargs, options=reference_item["enum"] + ) + + def _render_single_dict_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_dict = self._get_value(key) + if not current_dict: + current_dict = {} + + key_col, value_col = streamlit_app.columns(2) + + with key_col: + updated_key = streamlit_app.text_input( + "Key", value="", key=key + "-new-key" + ) + + with value_col: + # TODO: also add boolean? 
+ value_kwargs = {"label": "Value", "key": key + "-new-value"} + if property["additionalProperties"].get("type") == "integer": + value_kwargs["value"] = 0 # type: ignore + updated_value = streamlit_app.number_input(**value_kwargs) + elif property["additionalProperties"].get("type") == "number": + value_kwargs["value"] = 0.0 # type: ignore + value_kwargs["format"] = "%f" + updated_value = streamlit_app.number_input(**value_kwargs) + else: + value_kwargs["value"] = "" + updated_value = streamlit_app.text_input(**value_kwargs) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + + with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_dict = {} + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and updated_key + ): + current_dict[updated_key] = updated_value + + streamlit_app.write(current_dict) + + return current_dict + + def _render_single_reference( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + reference_item = schema_utils.get_single_reference_item( + property, self._schema_references + ) + return self._render_property(streamlit_app, key, reference_item) + + def _render_multi_file_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + file_extension = None + if "mime_type" in property: + file_extension = mimetypes.guess_extension(property["mime_type"]) + + uploaded_files = streamlit_app.file_uploader( + **streamlit_kwargs, accept_multiple_files=True, type=file_extension + ) + uploaded_files_bytes = [] + if uploaded_files: + for uploaded_file in uploaded_files: + uploaded_files_bytes.append(uploaded_file.read()) + return uploaded_files_bytes + + def _render_single_boolean_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + if property.get("default"): + streamlit_kwargs["value"] = property.get("default") + return streamlit_app.checkbox(**streamlit_kwargs) + + def _render_single_number_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + + number_transform = int + if property.get("type") == "number": + number_transform = float # type: ignore + streamlit_kwargs["format"] = "%f" + + if "multipleOf" in property: + # Set stepcount based on multiple of parameter + streamlit_kwargs["step"] = number_transform(property["multipleOf"]) + elif number_transform == int: + # Set step size to 1 as default + streamlit_kwargs["step"] = 1 + elif number_transform == float: + # Set step size to 0.01 as default + # TODO: adapt to default value + streamlit_kwargs["step"] = 0.01 + + if "minimum" in property: + streamlit_kwargs["min_value"] = number_transform(property["minimum"]) + if "exclusiveMinimum" in property: + streamlit_kwargs["min_value"] = number_transform( + property["exclusiveMinimum"] + streamlit_kwargs["step"] + ) + if "maximum" in property: + streamlit_kwargs["max_value"] = number_transform(property["maximum"]) + + if "exclusiveMaximum" in property: + streamlit_kwargs["max_value"] = number_transform( + property["exclusiveMaximum"] - streamlit_kwargs["step"] + ) + + if property.get("default") is not None: + streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore + else: + if "min_value" in streamlit_kwargs: + 
streamlit_kwargs["value"] = streamlit_kwargs["min_value"] + elif number_transform == int: + streamlit_kwargs["value"] = 0 + else: + # Set default value to step + streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"]) + + if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs: + # TODO: Only if less than X steps + return streamlit_app.slider(**streamlit_kwargs) + else: + return streamlit_app.number_input(**streamlit_kwargs) + + def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any: + properties = property["properties"] + object_inputs = {} + for property_key in properties: + property = properties[property_key] + if not property.get("title"): + # Set property key as fallback title + property["title"] = name_to_title(property_key) + # construct full key based on key parts -> required later to get the value + full_key = key + "." + property_key + object_inputs[property_key] = self._render_property( + streamlit_app, full_key, property + ) + return object_inputs + + def _render_single_object_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + # Add title and subheader + title = property.get("title") + streamlit_app.subheader(title) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + object_reference = schema_utils.get_single_reference_item( + property, self._schema_references + ) + return self._render_object_input(streamlit_app, key, object_reference) + + def _render_property_list_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_list = self._get_value(key) + if not current_list: + current_list = [] + + value_kwargs = {"label": "Value", "key": key + "-new-value"} + if property["items"]["type"] == "integer": + value_kwargs["value"] = 0 # type: ignore + new_value = streamlit_app.number_input(**value_kwargs) + elif property["items"]["type"] == "number": + value_kwargs["value"] = 0.0 # type: ignore + value_kwargs["format"] = "%f" + new_value = streamlit_app.number_input(**value_kwargs) + else: + value_kwargs["value"] = "" + new_value = streamlit_app.text_input(**value_kwargs) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + + with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_list = [] + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and new_value is not None + ): + current_list.append(new_value) + + streamlit_app.write(current_list) + + return current_list + + def _render_object_list_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + + # TODO: support max_items, and min_items properties + + # Add title and subheader + streamlit_app.subheader(property.get("title")) + if property.get("description"): + streamlit_app.markdown(property.get("description")) + + streamlit_app.markdown("---") + + current_list = self._get_value(key) + if not current_list: + current_list = [] + + object_reference = schema_utils.resolve_reference( + property["items"]["$ref"], self._schema_references + ) + input_data = self._render_object_input(streamlit_app, key, object_reference) + + streamlit_app.markdown("---") + + with streamlit_app.container(): + clear_col, add_col = streamlit_app.columns([1, 2]) + 
+ with clear_col: + if streamlit_app.button("Clear Items", key=key + "-clear-items"): + current_list = [] + + with add_col: + if ( + streamlit_app.button("Add Item", key=key + "-add-item") + and input_data + ): + current_list.append(input_data) + + streamlit_app.write(current_list) + return current_list + + def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any: + if schema_utils.is_single_enum_property(property, self._schema_references): + return self._render_single_enum_input(streamlit_app, key, property) + + if schema_utils.is_multi_enum_property(property, self._schema_references): + return self._render_multi_enum_input(streamlit_app, key, property) + + if schema_utils.is_single_file_property(property): + return self._render_single_file_input(streamlit_app, key, property) + + if schema_utils.is_multi_file_property(property): + return self._render_multi_file_input(streamlit_app, key, property) + + if schema_utils.is_single_datetime_property(property): + return self._render_single_datetime_input(streamlit_app, key, property) + + if schema_utils.is_single_boolean_property(property): + return self._render_single_boolean_input(streamlit_app, key, property) + + if schema_utils.is_single_dict_property(property): + return self._render_single_dict_input(streamlit_app, key, property) + + if schema_utils.is_single_number_property(property): + return self._render_single_number_input(streamlit_app, key, property) + + if schema_utils.is_single_string_property(property): + return self._render_single_string_input(streamlit_app, key, property) + + if schema_utils.is_single_object(property, self._schema_references): + return self._render_single_object_input(streamlit_app, key, property) + + if schema_utils.is_object_list_property(property, self._schema_references): + return self._render_object_list_input(streamlit_app, key, property) + + if schema_utils.is_property_list(property): + return self._render_property_list_input(streamlit_app, key, property) + + if schema_utils.is_single_reference(property): + return self._render_single_reference(streamlit_app, key, property) + + streamlit_app.warning( + "The type of the following property is currently not supported: " + + str(property.get("title")) + ) + raise Exception("Unsupported property") + + +class OutputUI: + def __init__(self, output_data: Any, input_data: Any): + self._output_data = output_data + self._input_data = input_data + + def render_ui(self, streamlit_app) -> None: + try: + if isinstance(self._output_data, BaseModel): + self._render_single_output(streamlit_app, self._output_data) + return + if type(self._output_data) == list: + self._render_list_output(streamlit_app, self._output_data) + return + except Exception as ex: + streamlit_app.exception(ex) + # Fallback to + streamlit_app.json(jsonable_encoder(self._output_data)) + + def _render_single_text_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + if value is None or value == "": + streamlit.info("No value returned!") + else: + streamlit.code(str(value), language="plain") + + def _render_single_file_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + if value 
is None or value == "": + streamlit.info("No value returned!") + else: + # TODO: Detect if it is a FileContent instance + # TODO: detect if it is base64 + file_extension = "" + if "mime_type" in property_schema: + mime_type = property_schema["mime_type"] + file_extension = mimetypes.guess_extension(mime_type) or "" + + if is_compatible_audio(mime_type): + streamlit.audio(value.as_bytes(), format=mime_type) + return + + if is_compatible_image(mime_type): + streamlit.image(value.as_bytes()) + return + + if is_compatible_video(mime_type): + streamlit.video(value.as_bytes(), format=mime_type) + return + + filename = ( + (property_schema["title"] + file_extension) + .lower() + .strip() + .replace(" ", "-") + ) + streamlit.markdown( + f'', + unsafe_allow_html=True, + ) + + def _render_single_complex_property( + self, streamlit: st, property_schema: Dict, value: Any + ) -> None: + # Add title and subheader + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + + streamlit.json(jsonable_encoder(value)) + + def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None: + try: + if has_output_ui_renderer(output_data): + if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore + # render method also requests the input data + output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore + else: + output_data.render_output_ui(streamlit) # type: ignore + return + except Exception: + # Use default auto-generation methods if the custom rendering throws an exception + logger.exception( + "Failed to execute custom render_output_ui function. Using auto-generation instead" + ) + + model_schema = output_data.schema(by_alias=False) + model_properties = model_schema.get("properties") + definitions = model_schema.get("definitions") + + if model_properties: + for property_key in output_data.__dict__: + property_schema = model_properties.get(property_key) + if not property_schema.get("title"): + # Set property key as fallback title + property_schema["title"] = property_key + + output_property_value = output_data.__dict__[property_key] + + if has_output_ui_renderer(output_property_value): + output_property_value.render_output_ui(streamlit) # type: ignore + continue + + if isinstance(output_property_value, BaseModel): + # Render output recursivly + streamlit.subheader(property_schema.get("title")) + if property_schema.get("description"): + streamlit.markdown(property_schema.get("description")) + self._render_single_output(streamlit, output_property_value) + continue + + if property_schema: + if schema_utils.is_single_file_property(property_schema): + self._render_single_file_property( + streamlit, property_schema, output_property_value + ) + continue + + if ( + schema_utils.is_single_string_property(property_schema) + or schema_utils.is_single_number_property(property_schema) + or schema_utils.is_single_datetime_property(property_schema) + or schema_utils.is_single_boolean_property(property_schema) + ): + self._render_single_text_property( + streamlit, property_schema, output_property_value + ) + continue + if definitions and schema_utils.is_single_enum_property( + property_schema, definitions + ): + self._render_single_text_property( + streamlit, property_schema, output_property_value.value + ) + continue + + # TODO: render dict as table + + self._render_single_complex_property( + streamlit, property_schema, output_property_value + ) + return + + def 
_render_list_output(self, streamlit: st, output_data: List) -> None:
+        try:
+            data_items: List = []
+            for data_item in output_data:
+                if has_output_ui_renderer(data_item):
+                    # Render using the render function
+                    data_item.render_output_ui(streamlit)  # type: ignore
+                    continue
+                data_items.append(data_item.dict())
+            # Try to show as dataframe
+            streamlit.table(pd.DataFrame(data_items))
+        except Exception:
+            # Fallback to JSON
+            streamlit.json(jsonable_encoder(output_data))
+
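+
+# Page registry: each mode lazily imports its page module, so a missing
+# saved_models folder only breaks the page that needs it (the page modules
+# raise at import time when their model folders are absent).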
+def getOpyrator(mode: str) -> Opyrator:
+    if mode is None or mode.startswith('VC'):
+        from mkgui.app_vc import convert
+        return Opyrator(convert)
+    if mode is None or mode.startswith('预处理'):
+        from mkgui.preprocess import preprocess
+        return Opyrator(preprocess)
+    if mode is None or mode.startswith('模型训练'):
+        from mkgui.train import train
+        return Opyrator(train)
+    from mkgui.app import synthesize
+    return Opyrator(synthesize)
+
+
+def render_streamlit_ui() -> None:
+    # init
+    session_state = st.session_state
+    session_state.input_data = {}
+    # Add custom css settings
+    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
+
+    with st.spinner("Loading MockingBird GUI. Please wait..."):
+        session_state.mode = st.sidebar.selectbox(
+            '模式选择',
+            ( "AI拟音", "VC拟音", "预处理", "模型训练")
+        )
+        if "mode" in session_state:
+            mode = session_state.mode
+        else:
+            mode = ""
+        opyrator = getOpyrator(mode)
+        title = opyrator.name + mode
+
+    col1, col2, _ = st.columns(3)
+    col2.title(title)
+    col2.markdown("欢迎使用MockingBird Web 2")
+
+    image = Image.open('.\\mkgui\\static\\mb.png')
+    col1.image(image)
+
+    st.markdown("---")
+    left, right = st.columns([0.4, 0.6])
+
+    with left:
+        st.header("Control 控制")
+        InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
+        execute_selected = st.button(opyrator.action)
+        if execute_selected:
+            with st.spinner("Executing operation. Please wait..."):
+                try:
+                    input_data_obj = parse_obj_as(
+                        opyrator.input_type, session_state.input_data
+                    )
+                    session_state.output_data = opyrator(input=input_data_obj)
+                    session_state.latest_operation_input = input_data_obj  # should this really be saved as an additional session object?
+                except ValidationError as ex:
+                    st.error(ex)
+                else:
+                    # st.success("Operation executed successfully.")
+                    pass
+
+    with right:
+        st.header("Result 结果")
+        if 'output_data' in session_state:
+            OutputUI(
+                session_state.output_data, session_state.latest_operation_input
+            ).render_ui(st)
+            if st.button("Clear"):
+                # Clear all state
+                for key in list(st.session_state.keys()):
+                    del st.session_state[key]
+                session_state.input_data = {}
+                st.experimental_rerun()
+        else:
+            # placeholder
+            st.caption("请使用左侧控制板进行输入并运行获得结果")
+
+ """ + sr, count = self.__root__ + streamlit_app.subheader(f"Dataset {sr} done processed total of {count}") + +def preprocess(input: Input) -> Output: + """Preprocess(预处理)""" + finished = 0 + if input.model == Model.VC_PPG2MEL: + from ppg2mel.preprocess import preprocess_dataset + finished = preprocess_dataset( + datasets_root=Path(input.datasets_root), + dataset=input.dataset, + out_dir=Path(input.output_root), + n_processes=input.n_processes, + ppg_encoder_model_fpath=Path(input.extractor.value), + speaker_encoder_model=Path(input.encoder.value) + ) + # TODO: pass useful return code + return Output(__root__=(input.dataset, finished)) \ No newline at end of file diff --git a/mkgui/static/mb.png b/mkgui/static/mb.png new file mode 100644 index 0000000..abd804c Binary files /dev/null and b/mkgui/static/mb.png differ diff --git a/mkgui/train.py b/mkgui/train.py new file mode 100644 index 0000000..7b85ecc --- /dev/null +++ b/mkgui/train.py @@ -0,0 +1,156 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from typing import Any +import numpy as np +from utils.load_yaml import HpsYaml +from utils.util import AttrDict +import torch + +# TODO: seperator for *unix systems +# Constants +EXT_MODELS_DIRT = "ppg_extractor\\saved_models" +CONV_MODELS_DIRT = "ppg2mel\\saved_models" +ENC_MODELS_DIRT = "encoder\\saved_models" + + +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(CONV_MODELS_DIRT): + convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth"))) + print("Loaded convertor models: " + str(len(convertors))) +else: + raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +class Model(str, Enum): + VC_PPG2MEL = "ppg2mel" + +class Dataset(str, Enum): + AIDATATANG_200ZH = "aidatatang_200zh" + AIDATATANG_200ZH_S = "aidatatang_200zh_s" + +class Input(BaseModel): + # def render_input_ui(st, input) -> Dict: + # input["selected_dataset"] = st.selectbox( + # '选择数据集', + # ("aidatatang_200zh", "aidatatang_200zh_s") + # ) + # return input + model: Model = Field( + Model.VC_PPG2MEL, title="模型类型", + ) + # datasets_root: str = Field( + # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型", + # format=True, + # example="..\\trainning_data\\" + # ) + output_root: str = Field( + ..., alias="输出目录(可选)", description="建议不填,保持默认", + format=True, + example="" + ) + continue_mode: bool = Field( + True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练", + ) + gpu: bool = Field( + True, alias="GPU训练", description="选择“是”,则使用GPU训练", + ) + verbose: bool = Field( + True, alias="打印详情", description="选择“是”,输出更多详情", + ) + # TODO: Move to hiden fields by default + convertor: convertors = Field( + ..., alias="转换模型", + description="选择语音转换模型文件." + ) + extractor: extractors = Field( + ..., alias="特征提取模型", + description="选择PPG特征提取模型文件." + ) + encoder: encoders = Field( + ..., alias="语音编码模型", + description="选择语音编码模型文件." 
+    )
+    njobs: int = Field(
+        8, alias="进程数", description="适用于ppg2mel",
+    )
+    seed: int = Field(
+        default=0, alias="初始随机数", description="适用于ppg2mel",
+    )
+    model_name: str = Field(
+        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
+        example="test"
+    )
+    model_config: str = Field(
+        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
+        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    )
+
+class AudioEntity(BaseModel):
+    content: bytes
+    mel: Any
+
+class Output(BaseModel):
+    __root__: tuple[str, int]
+
+    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+        """Custom output UI.
+        If this method is implemented, it will be used instead of the default Output UI renderer.
+        """
+        model, exit_code = self.__root__
+        streamlit_app.subheader(f"Training of {model} finished (code: {exit_code})")
+
+def train(input: Input) -> Output:
+    """Train(训练)"""
+
+    print(">>> OneShot VC training ...")
+    params = AttrDict()
+    params.update({
+        "gpu": input.gpu,
+        "cpu": not input.gpu,
+        "njobs": input.njobs,
+        "seed": input.seed,
+        "verbose": input.verbose,
+        "load": input.convertor.value,
+        "warm_start": False,
+    })
+    if input.continue_mode:
+        # trace old model and config
+        p = Path(input.convertor.value)
+        params.name = p.parent.name
+        # search a config file
+        model_config_fpaths = list(p.parent.rglob("*.yaml"))
+        if len(model_config_fpaths) == 0:
+            raise Exception("No model yaml config found for convertor")
+        config = HpsYaml(model_config_fpaths[0])
+        params.ckpdir = p.parent.parent
+        params.config = model_config_fpaths[0]
+        params.logdir = os.path.join(p.parent, "log")
+    else:
+        # Make the config dict dot visitable
+        config = HpsYaml(input.model_config)
+    np.random.seed(input.seed)
+    torch.manual_seed(input.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(input.seed)
+    mode = "train"
+    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+    solver = Solver(config, params, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+
+    # TODO: pass useful return code
+    return Output(__root__=(input.model.value, 0))
\ No newline at end of file
diff --git a/ppg2mel/__init__.py b/ppg2mel/__init__.py
index 53ee3b2..cc54db8 100644
--- a/ppg2mel/__init__.py
+++ b/ppg2mel/__init__.py
@@ -191,12 +191,15 @@ class MelDecoderMOLv2(AbsMelDecoder):
 
         return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
 
-def load_model(train_config, model_file, device=None):
-
+def load_model(model_file, device=None):
+    # search a config file
+    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
+    if len(model_config_fpaths) == 0:
+        raise Exception("No model yaml config found for convertor")
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    model_config = HpsYaml(train_config)
+    model_config = HpsYaml(model_config_fpaths[0])
     ppg2mel_model = MelDecoderMOLv2(
         **model_config["model"]
     ).to(device)
diff --git a/ppg2mel/preprocess.py b/ppg2mel/preprocess.py
index 6da9054..0feee6e 100644
--- a/ppg2mel/preprocess.py
+++ b/ppg2mel/preprocess.py
@@ -110,3 +110,4 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
     t_fid_file.close()
     d_fid_file.close()
     e_fid_file.close()
+    return len(wav_file_list)
diff --git a/ppg2mel/train.py b/ppg2mel/train.py
index fed7501..d3ef729 100644
--- a/ppg2mel/train.py
+++ b/ppg2mel/train.py
@@ -31,15 +31,10 @@ def main():
     parser.add_argument('--njobs', default=8, type=int,
                         help='Number of threads for dataloader/decoding.',
                        required=False)
     parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
-    parser.add_argument('--no-pin', action='store_true',
-                        help='Disable pin-memory for dataloader')
-    parser.add_argument('--test', action='store_true', help='Test the model.')
+    # parser.add_argument('--no-pin', action='store_true',
+    #                     help='Disable pin-memory for dataloader')
     parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
-    parser.add_argument('--finetune', action='store_true', help='Finetune model')
-    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
-    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
-    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
-
+    ###
 
     paras = parser.parse_args()
diff --git a/ppg2mel/train/solver.py b/ppg2mel/train/solver.py
index 264a91c..9ca71cb 100644
--- a/ppg2mel/train/solver.py
+++ b/ppg2mel/train/solver.py
@@ -93,6 +93,7 @@ class BaseSolver():
 
     def load_ckpt(self):
         ''' Load ckpt if --load option is specified '''
+        print(self.paras)
        if self.paras.load is not None:
            if self.paras.warm_start:
                self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
@@ -100,7 +101,7 @@ class BaseSolver():
                 self.paras.load, map_location=self.device if self.mode == 'train' else 'cpu')
             model_dict = ckpt['model']
-            if len(self.config.model.ignore_layers) > 0:
+            if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
                 model_dict = {k:v for k, v in model_dict.items()
                               if k not in self.config.model.ignore_layers}
                 dummy_dict = self.model.state_dict()
diff --git a/requirements.txt b/requirements.txt
index 21becf4..5c64e70 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,5 +21,7 @@ flask_cors==3.0.10
 gevent==21.8.0
 flask_restx
 tensorboard
-opyrator
-streamlit==1.3.1
\ No newline at end of file
+streamlit==1.8.0
+PyYAML==5.4.1
+torch_complex
+espnet
\ No newline at end of file
diff --git a/requirements_vc.txt b/requirements_vc.txt
deleted file mode 100644
index 871fdee..0000000
--- a/requirements_vc.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-PyYAML==5.4.1
-torch_complex
-espnet
\ No newline at end of file
diff --git a/samples/T0055G0013S0005.wav b/samples/T0055G0013S0005.wav
new file mode 100644
index 0000000..4fcc65c
Binary files /dev/null and b/samples/T0055G0013S0005.wav differ
diff --git a/toolbox/__init__.py b/toolbox/__init__.py
index 3d03397..76cd36a 100644
--- a/toolbox/__init__.py
+++ b/toolbox/__init__.py
@@ -405,16 +405,11 @@ class Toolbox:
         if self.ui.current_convertor_fpath is None:
             return
         model_fpath = self.ui.current_convertor_fpath
-        # search a config file
-        model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
-        if self.ui.current_convertor_fpath is None:
-            return
-        model_config_fpath = model_config_fpaths[0]
         self.ui.log("Loading the convertor %s... " % model_fpath)
         self.ui.set_loading(1)
         start = timer()
         import ppg2mel as convertor
-        self.convertor = convertor.load_model(model_config_fpath, model_fpath)
+        self.convertor = convertor.load_model(model_fpath)
         self.ui.log("Done (%dms)."
                    % int(1000 * (timer() - start)), "append")
         self.ui.set_loading(0)
diff --git a/utils/util.py b/utils/util.py
index 5227538..34bcffd 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -42,3 +42,9 @@ def human_format(num):
     # add more suffixes if you need them
     return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
+
+# Provide easy attribute access to dict entries, e.g. abc.key
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
diff --git a/vocoder/hifigan/env.py b/vocoder/hifigan/env.py
index 2bdbc95..8f0d306 100644
--- a/vocoder/hifigan/env.py
+++ b/vocoder/hifigan/env.py
@@ -1,13 +1,6 @@
 import os
 import shutil
 
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
 def build_env(config, config_name, path):
     t_path = os.path.join(path, config_name)
     if config != t_path:
diff --git a/vocoder/hifigan/inference.py b/vocoder/hifigan/inference.py
index edbcd38..8caf348 100644
--- a/vocoder/hifigan/inference.py
+++ b/vocoder/hifigan/inference.py
@@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import json
 import torch
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
 from vocoder.hifigan.models import Generator
 
 generator = None   # type: Generator
@@ -26,7 +26,11 @@ def load_model(weights_fpath, config_fpath=None, verbose=True):
         print("Building hifigan")
 
     if config_fpath == None:
-        config_fpath = "./vocoder/hifigan/config_16k_.json"
+        model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
+        if len(model_config_fpaths) > 0:
+            config_fpath = model_config_fpaths[0]
+        else:
+            config_fpath = "./vocoder/hifigan/config_16k_.json"
     with open(config_fpath) as f:
         data = f.read()
     json_config = json.loads(data)
diff --git a/vocoder/hifigan/train.py b/vocoder/hifigan/train.py
index 987bcca..8760274 100644
--- a/vocoder/hifigan/train.py
+++ b/vocoder/hifigan/train.py
@@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
 import torch.multiprocessing as mp
 from torch.distributed import init_process_group
 from torch.nn.parallel import DistributedDataParallel
-from vocoder.hifigan.env import AttrDict, build_env
 from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
 from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
     discriminator_loss
diff --git a/vocoder_train.py b/vocoder_train.py
index d3ad0f5..1ef0e30 100644
--- a/vocoder_train.py
+++ b/vocoder_train.py
@@ -1,7 +1,7 @@
 from utils.argutils import print_args
 from vocoder.wavernn.train import train
 from vocoder.hifigan.train import train as train_hifigan
-from vocoder.hifigan.env import AttrDict
+from utils.util import AttrDict
 from pathlib import Path
 import argparse
 import json
diff --git a/web.py b/web.py
index 56ac93c..d232530 100644
--- a/web.py
+++ b/web.py
@@ -1,11 +1,21 @@
-from web import webApp
-from gevent import pywsgi as wsgi
+import os
+import sys
+import typer
 
+cli = typer.Typer()
+
+@cli.command()
+def launch_ui(port: int = typer.Option(8080, "--port", "-p")) -> None:
+    """Start a graphical UI server for the opyrator.
+
+    The UI is auto-generated from the input- and output-schema of the given function.
+ """ + # Add the current working directory to the sys path + # This is required to resolve the opyrator path + sys.path.append(os.getcwd()) + + from mkgui.base.ui.streamlit_ui import launch_ui + launch_ui(port) if __name__ == "__main__": - app = webApp() - host = app.config.get("HOST") - port = app.config.get("PORT") - print(f"Web server: http://{host}:{port}") - server = wsgi.WSGIServer((host, port), app) - server.serve_forever() + cli() \ No newline at end of file