mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Upgrade to new web service (#529)
* Init new GUI * Remove unused codes * Reset layout * Add samples * Make framework to support multiple pages * Add vc mode * Add preprocessing mode * Add training mode * Remove text input in vc mode * Add entry for GUI and revise readme * Move requirement together * Add error raise when no model folder found * Add readme
This commit is contained in:
parent
7f799d322f
commit
c5d03fb3cb
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,7 +13,6 @@
|
|||
*.bbl
|
||||
*.bcf
|
||||
*.toc
|
||||
*.wav
|
||||
*.sh
|
||||
*/saved_models
|
||||
!vocoder/saved_models/pretrained/**
|
||||
|
|
8
.vscode/launch.json
vendored
8
.vscode/launch.json
vendored
|
@ -61,5 +61,13 @@
|
|||
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GUI",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "mkgui\\base\\_cli.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": []
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
16
README-CN.md
16
README-CN.md
|
@ -18,6 +18,15 @@
|
|||
|
||||
🌍 **Webserver Ready** 可伺服你的训练结果,供远程调用
|
||||
|
||||
### 进行中的工作
|
||||
* GUI/客户端大升级与合并
|
||||
[X] 初始化框架 `./mkgui` (基于streamlit + fastapi)和 [技术设计](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] 增加 Voice Cloning and Conversion的演示页面
|
||||
[X] 增加Voice Conversion的预处理preprocessing 和训练 training 页面
|
||||
[ ] 增加其他的的预处理preprocessing 和训练 training 页面
|
||||
* 模型后端基于ESPnet2升级
|
||||
|
||||
|
||||
## 开始
|
||||
### 1. 安装要求
|
||||
> 按照原始存储库测试您是否已准备好所有环境。
|
||||
|
@ -82,15 +91,10 @@
|
|||
### 3. 启动程序或工具箱
|
||||
您可以尝试使用以下命令:
|
||||
|
||||
### 3.1 启动Web程序:
|
||||
### 3.1 启动Web程序(v2):
|
||||
`python web.py`
|
||||
运行成功后在浏览器打开地址, 默认为 `http://localhost:8080`
|
||||
![123](https://user-images.githubusercontent.com/12797292/135494044-ae59181c-fe3a-406f-9c7d-d21d12fdb4cb.png)
|
||||
> 注:目前界面比较buggy,
|
||||
> * 第一次点击`录制`要等待几秒浏览器正常启动录音,否则会有重音
|
||||
> * 录制结束不要再点`录制`而是`停止`
|
||||
> * 仅支持手动新录音(16khz), 不支持超过4MB的录音,最佳长度在5~15秒
|
||||
> * 默认使用第一个找到的模型,有动手能力的可以看代码修改 `web\__init__.py`。
|
||||
|
||||
### 3.2 启动工具箱:
|
||||
`python demo_toolbox.py -d <datasets_root>`
|
||||
|
|
|
@ -18,6 +18,14 @@
|
|||
|
||||
### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/)
|
||||
|
||||
### Ongoing Works(Helps Needed)
|
||||
* Major upgrade on GUI/Client and unifying web and toolbox
|
||||
[X] Init framework `./mkgui` and [tech design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
|
||||
[X] Add demo part of Voice Cloning and Conversion
|
||||
[X] Add preprocessing and training for Voice Conversion
|
||||
[ ] Add preprocessing and training for Encoder/Synthesizer/Vocoder
|
||||
* Major upgrade on model backend based on ESPnet2(not yet started)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install Requirements
|
||||
|
|
|
@ -8,9 +8,11 @@ import librosa
|
|||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from opyrator.components.types import FileContent
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from synthesizer.inference import Synthesizer
|
||||
from typing import Any
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = 'samples\\'
|
||||
|
@ -27,20 +29,31 @@ if os.path.isdir(AUDIO_SAMPLES_DIR):
|
|||
if os.path.isdir(SYN_MODELS_DIRT):
|
||||
synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded synthesizer models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(synthesizers)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||
)
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(..., alias="或上传语音",
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
encoder: encoders = Field(
|
||||
..., alias="编码模型",
|
||||
|
@ -48,37 +61,48 @@ class Input(BaseModel):
|
|||
)
|
||||
synthesizer: synthesizers = Field(
|
||||
..., alias="合成模型",
|
||||
description="选择语音编码模型文件."
|
||||
description="选择语音合成模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="输出文本内容"
|
||||
..., alias="语音解码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
result_file: FileContent = Field(
|
||||
...,
|
||||
mime_type="audio/wav",
|
||||
description="输出音频",
|
||||
)
|
||||
source_file: FileContent = Field(
|
||||
...,
|
||||
mime_type="audio/wav",
|
||||
description="原始音频.",
|
||||
)
|
||||
__root__: tuple[AudioEntity, AudioEntity]
|
||||
|
||||
def mocking_bird(input: Input) -> Output:
|
||||
"""欢迎使用MockingBird Web 2"""
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
|
||||
def synthesize(input: Input) -> Output:
|
||||
"""synthesize(合成)"""
|
||||
# load models
|
||||
encoder.load_model(Path(input.encoder.value))
|
||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != NULL:
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
|
@ -87,6 +111,8 @@ def mocking_bird(input: Input) -> Output:
|
|||
wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
|
||||
source_spec = Synthesizer.make_spectrogram(wav)
|
||||
|
||||
# preprocess
|
||||
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
|
||||
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
||||
|
@ -114,4 +140,4 @@ def mocking_bird(input: Input) -> Output:
|
|||
source_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
return Output(source_file=source_file, result_file=result_file)
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
|
167
mkgui/app_vc.py
Normal file
167
mkgui/app_vc.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
from asyncio.windows_events import NULL
|
||||
from synthesizer.inference import Synthesizer
|
||||
from pydantic import BaseModel, Field
|
||||
from encoder import inference as speacker_encoder
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import ppg_extractor as Extractor
|
||||
import ppg2mel as Convertor
|
||||
import librosa
|
||||
from scipy.io.wavfile import write
|
||||
import re
|
||||
import numpy as np
|
||||
from mkgui.base.components.types import FileContent
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
from typing import Any
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
# Constants
|
||||
AUDIO_SAMPLES_DIR = 'samples\\'
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
|
||||
VOC_MODELS_DIRT = "vocoder\\saved_models"
|
||||
TEMP_SOURCE_AUDIO = "wavs/temp_source.wav"
|
||||
TEMP_TARGET_AUDIO = "wavs/temp_target.wav"
|
||||
TEMP_RESULT_AUDIO = "wavs/temp_result.wav"
|
||||
|
||||
# Load local sample audio as options TODO: load dataset
|
||||
if os.path.isdir(AUDIO_SAMPLES_DIR):
|
||||
audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
|
||||
# Pre-Load models
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(VOC_MODELS_DIRT):
|
||||
vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
|
||||
print("Loaded vocoders models: " + str(len(vocoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Input(BaseModel):
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
local_audio_file_target: audio_input_selection = Field(
|
||||
..., alias="目标语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
)
|
||||
upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
|
||||
description="拖拽或点击上传.", mime_type="audio/wav")
|
||||
extractor: extractors = Field(
|
||||
..., alias="编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
vocoder: vocoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音解码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[AudioEntity, AudioEntity, AudioEntity]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, target, result = self.__root__
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(src.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Source Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(target.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Target Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
fig, ax = plt.subplots()
|
||||
ax.imshow(result.mel, aspect="equal", interpolation="none")
|
||||
ax.set_title("mel spectrogram(Result Audio)")
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
def convert(input: Input) -> Output:
|
||||
"""convert(转换)"""
|
||||
# load models
|
||||
extractor = Extractor.load_model(Path(input.extractor.value))
|
||||
convertor = Convertor.load_model(Path(input.convertor.value))
|
||||
# current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
gan_vocoder.load_model(Path(input.vocoder.value))
|
||||
|
||||
# load file
|
||||
if input.upload_audio_file != None:
|
||||
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file.as_bytes())
|
||||
f.seek(0)
|
||||
src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
|
||||
else:
|
||||
src_wav, sample_rate = librosa.load(input.local_audio_file.value)
|
||||
write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
|
||||
|
||||
if input.upload_audio_file_target != None:
|
||||
with open(TEMP_TARGET_AUDIO, "w+b") as f:
|
||||
f.write(input.upload_audio_file_target.as_bytes())
|
||||
f.seek(0)
|
||||
ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
|
||||
else:
|
||||
ref_wav, _ = librosa.load(input.local_audio_file_target.value)
|
||||
write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
|
||||
|
||||
ppg = extractor.extract_from_wav(src_wav)
|
||||
# Import necessary dependency of Voice Conversion
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||||
embed = speacker_encoder.embed_utterance(ref_wav)
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_, mel_pred, att_ws = convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
|
||||
# synthesize and vocode
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
|
||||
|
||||
# write and output
|
||||
write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
|
||||
with open(TEMP_SOURCE_AUDIO, "rb") as f:
|
||||
source_file = f.read()
|
||||
with open(TEMP_TARGET_AUDIO, "rb") as f:
|
||||
target_file = f.read()
|
||||
with open(TEMP_RESULT_AUDIO, "rb") as f:
|
||||
result_file = f.read()
|
||||
|
||||
|
||||
return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
|
2
mkgui/base/__init__.py
Normal file
2
mkgui/base/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
|
||||
from .core import Opyrator
|
1
mkgui/base/api/__init__.py
Normal file
1
mkgui/base/api/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .fastapi_app import create_api
|
102
mkgui/base/api/fastapi_utils.py
Normal file
102
mkgui/base/api/fastapi_utils.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
"""Collection of utilities for FastAPI apps."""
|
||||
|
||||
import inspect
|
||||
from typing import Any, Type
|
||||
|
||||
from fastapi import FastAPI, Form
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def as_form(cls: Type[BaseModel]) -> Any:
|
||||
"""Adds an as_form class method to decorated models.
|
||||
|
||||
The as_form class method can be used with FastAPI endpoints
|
||||
"""
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
field.alias,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
default=(Form(field.default) if not field.required else Form(...)),
|
||||
)
|
||||
for field in cls.__fields__.values()
|
||||
]
|
||||
|
||||
async def _as_form(**data): # type: ignore
|
||||
return cls(**data)
|
||||
|
||||
sig = inspect.signature(_as_form)
|
||||
sig = sig.replace(parameters=new_params)
|
||||
_as_form.__signature__ = sig # type: ignore
|
||||
setattr(cls, "as_form", _as_form)
|
||||
return cls
|
||||
|
||||
|
||||
def patch_fastapi(app: FastAPI) -> None:
|
||||
"""Patch function to allow relative url resolution.
|
||||
|
||||
This patch is required to make fastapi fully functional with a relative url path.
|
||||
This code snippet can be copy-pasted to any Fastapi application.
|
||||
"""
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
async def redoc_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
redoc_ui = get_redoc_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Redoc UI",
|
||||
)
|
||||
|
||||
return HTMLResponse(redoc_ui.body.decode("utf-8"))
|
||||
|
||||
async def swagger_ui_html(req: Request) -> HTMLResponse:
|
||||
assert app.openapi_url is not None
|
||||
swagger_ui = get_swagger_ui_html(
|
||||
openapi_url="./" + app.openapi_url.lstrip("/"),
|
||||
title=app.title + " - Swagger UI",
|
||||
oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
|
||||
)
|
||||
|
||||
# insert request interceptor to have all request run on relativ path
|
||||
request_interceptor = (
|
||||
"requestInterceptor: (e) => {"
|
||||
"\n\t\t\tvar url = window.location.origin + window.location.pathname"
|
||||
'\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
|
||||
"\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
|
||||
"\n\t\t\te.contextUrl = url"
|
||||
"\n\t\t\te.url = url"
|
||||
"\n\t\t\treturn e;}"
|
||||
)
|
||||
|
||||
return HTMLResponse(
|
||||
swagger_ui.body.decode("utf-8").replace(
|
||||
"dom_id: '#swagger-ui',",
|
||||
"dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
|
||||
)
|
||||
)
|
||||
|
||||
# remove old docs route and add our patched route
|
||||
routes_new = []
|
||||
for app_route in app.routes:
|
||||
if app_route.path == "/docs": # type: ignore
|
||||
continue
|
||||
|
||||
if app_route.path == "/redoc": # type: ignore
|
||||
continue
|
||||
|
||||
routes_new.append(app_route)
|
||||
|
||||
app.router.routes = routes_new
|
||||
|
||||
assert app.docs_url is not None
|
||||
app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
|
||||
assert app.redoc_url is not None
|
||||
app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
|
||||
|
||||
# Make graphql realtive
|
||||
from starlette import graphql
|
||||
|
||||
graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
|
||||
"({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
|
||||
)
|
0
mkgui/base/components/__init__.py
Normal file
0
mkgui/base/components/__init__.py
Normal file
43
mkgui/base/components/outputs.py
Normal file
43
mkgui/base/components/outputs.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ScoredLabel(BaseModel):
|
||||
label: str
|
||||
score: float
|
||||
|
||||
|
||||
class ClassificationOutput(BaseModel):
|
||||
__root__: List[ScoredLabel]
|
||||
|
||||
def __iter__(self): # type: ignore
|
||||
return iter(self.__root__)
|
||||
|
||||
def __getitem__(self, item): # type: ignore
|
||||
return self.__root__[item]
|
||||
|
||||
def render_output_ui(self, streamlit) -> None: # type: ignore
|
||||
import plotly.express as px
|
||||
|
||||
sorted_predictions = sorted(
|
||||
[prediction.dict() for prediction in self.__root__],
|
||||
key=lambda k: k["score"],
|
||||
)
|
||||
|
||||
num_labels = len(sorted_predictions)
|
||||
if len(sorted_predictions) > 10:
|
||||
num_labels = streamlit.slider(
|
||||
"Maximum labels to show: ",
|
||||
min_value=1,
|
||||
max_value=len(sorted_predictions),
|
||||
value=len(sorted_predictions),
|
||||
)
|
||||
fig = px.bar(
|
||||
sorted_predictions[len(sorted_predictions) - num_labels :],
|
||||
x="score",
|
||||
y="label",
|
||||
orientation="h",
|
||||
)
|
||||
streamlit.plotly_chart(fig, use_container_width=True)
|
||||
# fig.show()
|
46
mkgui/base/components/types.py
Normal file
46
mkgui/base/components/types.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
import base64
|
||||
from typing import Any, Dict, overload
|
||||
|
||||
|
||||
class FileContent(str):
|
||||
def as_bytes(self) -> bytes:
|
||||
return base64.b64decode(self, validate=True)
|
||||
|
||||
def as_str(self) -> str:
|
||||
return self.as_bytes().decode()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(format="byte")
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Any: # type: ignore
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> "FileContent":
|
||||
if isinstance(value, FileContent):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
return FileContent(value)
|
||||
elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
return FileContent(base64.b64encode(value).decode())
|
||||
else:
|
||||
raise Exception("Wrong type")
|
||||
|
||||
# # 暂时无法使用,因为浏览器中没有考虑选择文件夹
|
||||
# class DirectoryContent(FileContent):
|
||||
# @classmethod
|
||||
# def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
# field_schema.update(format="path")
|
||||
|
||||
# @classmethod
|
||||
# def validate(cls, value: Any) -> "DirectoryContent":
|
||||
# if isinstance(value, DirectoryContent):
|
||||
# return value
|
||||
# elif isinstance(value, str):
|
||||
# return DirectoryContent(value)
|
||||
# elif isinstance(value, (bytes, bytearray, memoryview)):
|
||||
# return DirectoryContent(base64.b64encode(value).decode())
|
||||
# else:
|
||||
# raise Exception("Wrong type")
|
203
mkgui/base/core.py
Normal file
203
mkgui/base/core.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Type, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, parse_raw_as
|
||||
from pydantic.tools import parse_obj_as
|
||||
|
||||
|
||||
def name_to_title(name: str) -> str:
|
||||
"""Converts a camelCase or snake_case name to title case."""
|
||||
# If camelCase -> convert to snake case
|
||||
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
||||
# Convert to title case
|
||||
return name.replace("_", " ").strip().title()
|
||||
|
||||
|
||||
def is_compatible_type(type: Type) -> bool:
|
||||
"""Returns `True` if the type is opyrator-compatible."""
|
||||
try:
|
||||
if issubclass(type, BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# valid list type
|
||||
if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_input_type(func: Callable) -> Type:
|
||||
"""Returns the input type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the input type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid input type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
if "input" not in type_hints:
|
||||
raise ValueError(
|
||||
"The callable MUST have a parameter with the name `input` with typing annotation. "
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
input_type = type_hints["input"]
|
||||
|
||||
if not is_compatible_type(input_type):
|
||||
raise ValueError(
|
||||
"The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
# TODO: return warning if more than one input parameters
|
||||
|
||||
return input_type
|
||||
|
||||
|
||||
def get_output_type(func: Callable) -> Type:
|
||||
"""Returns the output type of a given function (callable).
|
||||
|
||||
Args:
|
||||
func: The function for which to get the output type.
|
||||
|
||||
Raises:
|
||||
ValueError: If the function does not have a valid output type annotation.
|
||||
"""
|
||||
type_hints = get_type_hints(func)
|
||||
if "return" not in type_hints:
|
||||
raise ValueError(
|
||||
"The return type of the callable MUST be annotated with type hints."
|
||||
"For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
|
||||
)
|
||||
|
||||
output_type = type_hints["return"]
|
||||
|
||||
if not is_compatible_type(output_type):
|
||||
raise ValueError(
|
||||
"The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
|
||||
)
|
||||
|
||||
return output_type
|
||||
|
||||
|
||||
def get_callable(import_string: str) -> Callable:
|
||||
"""Import a callable from an string."""
|
||||
callable_seperator = ":"
|
||||
if callable_seperator not in import_string:
|
||||
# Use dot as seperator
|
||||
callable_seperator = "."
|
||||
|
||||
if callable_seperator not in import_string:
|
||||
raise ValueError("The callable path MUST specify the function. ")
|
||||
|
||||
mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
|
||||
mod = importlib.import_module(mod_name)
|
||||
return getattr(mod, callable_name)
|
||||
|
||||
|
||||
class Opyrator:
|
||||
def __init__(self, func: Union[Callable, str]) -> None:
|
||||
if isinstance(func, str):
|
||||
# Try to load the function from a string notion
|
||||
self.function = get_callable(func)
|
||||
else:
|
||||
self.function = func
|
||||
|
||||
self._action = "Execute"
|
||||
self._input_type = None
|
||||
self._output_type = None
|
||||
|
||||
if not callable(self.function):
|
||||
raise ValueError("The provided function parameters is not a callable.")
|
||||
|
||||
if inspect.isclass(self.function):
|
||||
raise ValueError(
|
||||
"The provided callable is an uninitialized Class. This is not allowed."
|
||||
)
|
||||
|
||||
if inspect.isfunction(self.function):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function)
|
||||
self._output_type = get_output_type(self.function)
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(self.function.__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get description from function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(self.function, "__call__"):
|
||||
# The provided callable is a function
|
||||
self._input_type = get_input_type(self.function.__call__) # type: ignore
|
||||
self._output_type = get_output_type(self.function.__call__) # type: ignore
|
||||
|
||||
try:
|
||||
# Get name
|
||||
self._name = name_to_title(type(self.function).__name__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get action from
|
||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
|
||||
if (
|
||||
not self._action
|
||||
or self._action == "Call"
|
||||
):
|
||||
# Get docstring from class instead of __call__ function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown callable type.")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def action(self) -> str:
|
||||
return self._action
|
||||
|
||||
@property
|
||||
def input_type(self) -> Any:
|
||||
return self._input_type
|
||||
|
||||
@property
|
||||
def output_type(self) -> Any:
|
||||
return self._output_type
|
||||
|
||||
def __call__(self, input: Any, **kwargs: Any) -> Any:
|
||||
|
||||
input_obj = input
|
||||
|
||||
if isinstance(input, str):
|
||||
# Allow json input
|
||||
input_obj = parse_raw_as(self.input_type, input)
|
||||
|
||||
if isinstance(input, dict):
|
||||
# Allow dict input
|
||||
input_obj = parse_obj_as(self.input_type, input)
|
||||
|
||||
return self.function(input_obj, **kwargs)
|
1
mkgui/base/ui/__init__.py
Normal file
1
mkgui/base/ui/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .streamlit_ui import render_streamlit_ui
|
129
mkgui/base/ui/schema_utils.py
Normal file
129
mkgui/base/ui/schema_utils.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
from typing import Dict
|
||||
|
||||
|
||||
def resolve_reference(reference: str, references: Dict) -> Dict:
|
||||
return references[reference.split("/")[-1]]
|
||||
|
||||
|
||||
def get_single_reference_item(property: Dict, references: Dict) -> Dict:
|
||||
# Ref can either be directly in the properties or the first element of allOf
|
||||
reference = property.get("$ref")
|
||||
if reference is None:
|
||||
reference = property["allOf"][0]["$ref"]
|
||||
return resolve_reference(reference, references)
|
||||
|
||||
|
||||
def is_single_string_property(property: Dict) -> bool:
|
||||
return property.get("type") == "string"
|
||||
|
||||
|
||||
def is_single_datetime_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") in ["date-time", "time", "date"]
|
||||
|
||||
|
||||
def is_single_boolean_property(property: Dict) -> bool:
|
||||
return property.get("type") == "boolean"
|
||||
|
||||
|
||||
def is_single_number_property(property: Dict) -> bool:
|
||||
return property.get("type") in ["integer", "number"]
|
||||
|
||||
|
||||
def is_single_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
# TODO: binary?
|
||||
return property.get("format") == "byte"
|
||||
|
||||
|
||||
def is_single_directory_property(property: Dict) -> bool:
|
||||
if property.get("type") != "string":
|
||||
return False
|
||||
return property.get("format") == "path"
|
||||
|
||||
def is_multi_enum_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("uniqueItems") is not True:
|
||||
# Only relevant if it is a set or other datastructures with unique items
|
||||
return False
|
||||
|
||||
try:
|
||||
_ = resolve_reference(property["items"]["$ref"], references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_enum_property(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
_ = get_single_reference_item(property, references)["enum"]
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_dict_property(property: Dict) -> bool:
|
||||
if property.get("type") != "object":
|
||||
return False
|
||||
return "additionalProperties" in property
|
||||
|
||||
|
||||
def is_single_reference(property: Dict) -> bool:
|
||||
if property.get("type") is not None:
|
||||
return False
|
||||
|
||||
return bool(property.get("$ref"))
|
||||
|
||||
|
||||
def is_multi_file_property(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO: binary
|
||||
return property["items"]["format"] == "byte"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_single_object(property: Dict, references: Dict) -> bool:
|
||||
try:
|
||||
object_reference = get_single_reference_item(property, references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_property_list(property: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
if property.get("items") is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
return property["items"]["type"] in ["string", "number", "integer"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_object_list_property(property: Dict, references: Dict) -> bool:
|
||||
if property.get("type") != "array":
|
||||
return False
|
||||
|
||||
try:
|
||||
object_reference = resolve_reference(property["items"]["$ref"], references)
|
||||
if object_reference["type"] != "object":
|
||||
return False
|
||||
return "properties" in object_reference
|
||||
except Exception:
|
||||
return False
|
885
mkgui/base/ui/streamlit_ui.py
Normal file
885
mkgui/base/ui/streamlit_ui.py
Normal file
|
@ -0,0 +1,885 @@
|
|||
import datetime
|
||||
import inspect
|
||||
import mimetypes
|
||||
import sys
|
||||
from os import getcwd, unlink
|
||||
from platform import system
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
from PIL import Image
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError, parse_obj_as
|
||||
|
||||
from mkgui.base import Opyrator
|
||||
from mkgui.base.core import name_to_title
|
||||
from mkgui.base.ui import schema_utils
|
||||
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
|
||||
|
||||
STREAMLIT_RUNNER_SNIPPET = """
|
||||
from mkgui.base.ui import render_streamlit_ui
|
||||
from mkgui.base import Opyrator
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# TODO: Make it configurable
|
||||
# Page config can only be setup once
|
||||
st.set_page_config(
|
||||
page_title="MockingBird",
|
||||
page_icon="🧊",
|
||||
layout="wide")
|
||||
|
||||
render_streamlit_ui()
|
||||
"""
|
||||
|
||||
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
# opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
|
||||
def launch_ui(port: int = 8501) -> None:
|
||||
with NamedTemporaryFile(
|
||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||
) as f:
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||
f.seek(0)
|
||||
|
||||
import subprocess
|
||||
|
||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||
if system() == "Windows":
|
||||
python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
|
||||
subprocess.run(
|
||||
f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
subprocess.run(
|
||||
f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
|
||||
shell=True,
|
||||
)
|
||||
|
||||
f.close()
|
||||
unlink(f.name)
|
||||
|
||||
|
||||
def function_has_named_arg(func: Callable, parameter: str) -> bool:
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
for param in sig.parameters.values():
|
||||
if param.name == "input":
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def has_output_ui_renderer(data_item: BaseModel) -> bool:
|
||||
return hasattr(data_item, "render_output_ui")
|
||||
|
||||
|
||||
def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
|
||||
return hasattr(input_class, "render_input_ui")
|
||||
|
||||
|
||||
def is_compatible_audio(mime_type: str) -> bool:
|
||||
return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
|
||||
|
||||
|
||||
def is_compatible_image(mime_type: str) -> bool:
|
||||
return mime_type in ["image/png", "image/jpeg"]
|
||||
|
||||
|
||||
def is_compatible_video(mime_type: str) -> bool:
|
||||
return mime_type in ["video/mp4"]
|
||||
|
||||
|
||||
class InputUI:
|
||||
def __init__(self, session_state, input_class: Type[BaseModel]):
|
||||
self._session_state = session_state
|
||||
self._input_class = input_class
|
||||
|
||||
self._schema_properties = input_class.schema(by_alias=True).get(
|
||||
"properties", {}
|
||||
)
|
||||
self._schema_references = input_class.schema(by_alias=True).get(
|
||||
"definitions", {}
|
||||
)
|
||||
|
||||
def render_ui(self, streamlit_app_root) -> None:
|
||||
if has_input_ui_renderer(self._input_class):
|
||||
# The input model has a rendering function
|
||||
# The rendering also returns the current state of input data
|
||||
self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
|
||||
st, self._session_state.input_data
|
||||
)
|
||||
return
|
||||
|
||||
# print(self._schema_properties)
|
||||
for property_key in self._schema_properties.keys():
|
||||
property = self._schema_properties[property_key]
|
||||
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
|
||||
try:
|
||||
if "input_data" in self._session_state:
|
||||
self._store_value(
|
||||
property_key,
|
||||
self._render_property(streamlit_app_root, property_key, property),
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception!", e)
|
||||
pass
|
||||
|
||||
def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
|
||||
streamlit_kwargs = {
|
||||
"label": property.get("title"),
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if property.get("description"):
|
||||
streamlit_kwargs["help"] = property.get("description")
|
||||
return streamlit_kwargs
|
||||
|
||||
def _store_value(self, key: str, value: Any) -> None:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
data_element[key_element] = value
|
||||
return
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
|
||||
def _get_value(self, key: str) -> Any:
|
||||
data_element = self._session_state.input_data
|
||||
key_elements = key.split(".")
|
||||
for i, key_element in enumerate(key_elements):
|
||||
if i == len(key_elements) - 1:
|
||||
# add value to this element
|
||||
if key_element not in data_element:
|
||||
return None
|
||||
return data_element[key_element]
|
||||
if key_element not in data_element:
|
||||
data_element[key_element] = {}
|
||||
data_element = data_element[key_element]
|
||||
return None
|
||||
|
||||
def _render_single_datetime_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("format") == "time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.time_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return streamlit_app.date_input(**streamlit_kwargs)
|
||||
elif property.get("format") == "date-time":
|
||||
if property.get("default"):
|
||||
try:
|
||||
streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
with streamlit_app.container():
|
||||
streamlit_app.subheader(streamlit_kwargs.get("label"))
|
||||
if streamlit_kwargs.get("description"):
|
||||
streamlit_app.text(streamlit_kwargs.get("description"))
|
||||
selected_date = None
|
||||
selected_time = None
|
||||
date_col, time_col = streamlit_app.columns(2)
|
||||
with date_col:
|
||||
date_kwargs = {"label": "Date", "key": key + "-date-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).date()
|
||||
except Exception:
|
||||
pass
|
||||
selected_date = streamlit_app.date_input(**date_kwargs)
|
||||
|
||||
with time_col:
|
||||
time_kwargs = {"label": "Time", "key": key + "-time-input"}
|
||||
if streamlit_kwargs.get("value"):
|
||||
try:
|
||||
time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
|
||||
"value"
|
||||
).time()
|
||||
except Exception:
|
||||
pass
|
||||
selected_time = streamlit_app.time_input(**time_kwargs)
|
||||
return datetime.datetime.combine(selected_date, selected_time)
|
||||
else:
|
||||
streamlit_app.warning(
|
||||
"Date format is not supported: " + str(property.get("format"))
|
||||
)
|
||||
|
||||
def _render_single_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_file = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=False, type=file_extension
|
||||
)
|
||||
if uploaded_file is None:
|
||||
return None
|
||||
|
||||
bytes = uploaded_file.getvalue()
|
||||
if property.get("mime_type"):
|
||||
if is_compatible_audio(property["mime_type"]):
|
||||
# Show audio
|
||||
streamlit_app.audio(bytes, format=property.get("mime_type"))
|
||||
if is_compatible_image(property["mime_type"]):
|
||||
# Show image
|
||||
streamlit_app.image(bytes)
|
||||
if is_compatible_video(property["mime_type"]):
|
||||
# Show video
|
||||
streamlit_app.video(bytes, format=property.get("mime_type"))
|
||||
return bytes
|
||||
|
||||
def _render_single_string_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
elif property.get("example"):
|
||||
# TODO: also use example for other property types
|
||||
# Use example as value if it is provided
|
||||
streamlit_kwargs["value"] = property.get("example")
|
||||
|
||||
if property.get("maxLength") is not None:
|
||||
streamlit_kwargs["max_chars"] = property.get("maxLength")
|
||||
|
||||
if (
|
||||
property.get("format")
|
||||
or (
|
||||
property.get("maxLength") is not None
|
||||
and int(property.get("maxLength")) < 140 # type: ignore
|
||||
)
|
||||
or property.get("writeOnly")
|
||||
):
|
||||
# If any format is set, use single text input
|
||||
# If max chars is set to less than 140, use single text input
|
||||
# If write only -> password field
|
||||
if property.get("writeOnly"):
|
||||
streamlit_kwargs["type"] = "password"
|
||||
return streamlit_app.text_input(**streamlit_kwargs)
|
||||
else:
|
||||
# Otherwise use multiline text area
|
||||
return streamlit_app.text_area(**streamlit_kwargs)
|
||||
|
||||
def _render_multi_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
# TODO: how to select defaults
|
||||
return streamlit_app.multiselect(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_enum_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
try:
|
||||
streamlit_kwargs["index"] = reference_item["enum"].index(
|
||||
property.get("default")
|
||||
)
|
||||
except Exception:
|
||||
# Use default selection
|
||||
pass
|
||||
|
||||
return streamlit_app.selectbox(
|
||||
**streamlit_kwargs, options=reference_item["enum"]
|
||||
)
|
||||
|
||||
def _render_single_dict_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_dict = self._get_value(key)
|
||||
if not current_dict:
|
||||
current_dict = {}
|
||||
|
||||
key_col, value_col = streamlit_app.columns(2)
|
||||
|
||||
with key_col:
|
||||
updated_key = streamlit_app.text_input(
|
||||
"Key", value="", key=key + "-new-key"
|
||||
)
|
||||
|
||||
with value_col:
|
||||
# TODO: also add boolean?
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["additionalProperties"].get("type") == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["additionalProperties"].get("type") == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
updated_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
updated_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_dict = {}
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and updated_key
|
||||
):
|
||||
current_dict[updated_key] = updated_value
|
||||
|
||||
streamlit_app.write(current_dict)
|
||||
|
||||
return current_dict
|
||||
|
||||
def _render_single_reference(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
reference_item = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_property(streamlit_app, key, reference_item)
|
||||
|
||||
def _render_multi_file_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
file_extension = None
|
||||
if "mime_type" in property:
|
||||
file_extension = mimetypes.guess_extension(property["mime_type"])
|
||||
|
||||
uploaded_files = streamlit_app.file_uploader(
|
||||
**streamlit_kwargs, accept_multiple_files=True, type=file_extension
|
||||
)
|
||||
uploaded_files_bytes = []
|
||||
if uploaded_files:
|
||||
for uploaded_file in uploaded_files:
|
||||
uploaded_files_bytes.append(uploaded_file.read())
|
||||
return uploaded_files_bytes
|
||||
|
||||
def _render_single_boolean_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
if property.get("default"):
|
||||
streamlit_kwargs["value"] = property.get("default")
|
||||
return streamlit_app.checkbox(**streamlit_kwargs)
|
||||
|
||||
def _render_single_number_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
|
||||
|
||||
number_transform = int
|
||||
if property.get("type") == "number":
|
||||
number_transform = float # type: ignore
|
||||
streamlit_kwargs["format"] = "%f"
|
||||
|
||||
if "multipleOf" in property:
|
||||
# Set stepcount based on multiple of parameter
|
||||
streamlit_kwargs["step"] = number_transform(property["multipleOf"])
|
||||
elif number_transform == int:
|
||||
# Set step size to 1 as default
|
||||
streamlit_kwargs["step"] = 1
|
||||
elif number_transform == float:
|
||||
# Set step size to 0.01 as default
|
||||
# TODO: adapt to default value
|
||||
streamlit_kwargs["step"] = 0.01
|
||||
|
||||
if "minimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(property["minimum"])
|
||||
if "exclusiveMinimum" in property:
|
||||
streamlit_kwargs["min_value"] = number_transform(
|
||||
property["exclusiveMinimum"] + streamlit_kwargs["step"]
|
||||
)
|
||||
if "maximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(property["maximum"])
|
||||
|
||||
if "exclusiveMaximum" in property:
|
||||
streamlit_kwargs["max_value"] = number_transform(
|
||||
property["exclusiveMaximum"] - streamlit_kwargs["step"]
|
||||
)
|
||||
|
||||
if property.get("default") is not None:
|
||||
streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
|
||||
else:
|
||||
if "min_value" in streamlit_kwargs:
|
||||
streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
|
||||
elif number_transform == int:
|
||||
streamlit_kwargs["value"] = 0
|
||||
else:
|
||||
# Set default value to step
|
||||
streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
|
||||
|
||||
if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
|
||||
# TODO: Only if less than X steps
|
||||
return streamlit_app.slider(**streamlit_kwargs)
|
||||
else:
|
||||
return streamlit_app.number_input(**streamlit_kwargs)
|
||||
|
||||
def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
properties = property["properties"]
|
||||
object_inputs = {}
|
||||
for property_key in properties:
|
||||
property = properties[property_key]
|
||||
if not property.get("title"):
|
||||
# Set property key as fallback title
|
||||
property["title"] = name_to_title(property_key)
|
||||
# construct full key based on key parts -> required later to get the value
|
||||
full_key = key + "." + property_key
|
||||
object_inputs[property_key] = self._render_property(
|
||||
streamlit_app, full_key, property
|
||||
)
|
||||
return object_inputs
|
||||
|
||||
def _render_single_object_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
# Add title and subheader
|
||||
title = property.get("title")
|
||||
streamlit_app.subheader(title)
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
object_reference = schema_utils.get_single_reference_item(
|
||||
property, self._schema_references
|
||||
)
|
||||
return self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
def _render_property_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
value_kwargs = {"label": "Value", "key": key + "-new-value"}
|
||||
if property["items"]["type"] == "integer":
|
||||
value_kwargs["value"] = 0 # type: ignore
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
elif property["items"]["type"] == "number":
|
||||
value_kwargs["value"] = 0.0 # type: ignore
|
||||
value_kwargs["format"] = "%f"
|
||||
new_value = streamlit_app.number_input(**value_kwargs)
|
||||
else:
|
||||
value_kwargs["value"] = ""
|
||||
new_value = streamlit_app.text_input(**value_kwargs)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and new_value is not None
|
||||
):
|
||||
current_list.append(new_value)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
|
||||
return current_list
|
||||
|
||||
def _render_object_list_input(
|
||||
self, streamlit_app: st, key: str, property: Dict
|
||||
) -> Any:
|
||||
|
||||
# TODO: support max_items, and min_items properties
|
||||
|
||||
# Add title and subheader
|
||||
streamlit_app.subheader(property.get("title"))
|
||||
if property.get("description"):
|
||||
streamlit_app.markdown(property.get("description"))
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
current_list = self._get_value(key)
|
||||
if not current_list:
|
||||
current_list = []
|
||||
|
||||
object_reference = schema_utils.resolve_reference(
|
||||
property["items"]["$ref"], self._schema_references
|
||||
)
|
||||
input_data = self._render_object_input(streamlit_app, key, object_reference)
|
||||
|
||||
streamlit_app.markdown("---")
|
||||
|
||||
with streamlit_app.container():
|
||||
clear_col, add_col = streamlit_app.columns([1, 2])
|
||||
|
||||
with clear_col:
|
||||
if streamlit_app.button("Clear Items", key=key + "-clear-items"):
|
||||
current_list = []
|
||||
|
||||
with add_col:
|
||||
if (
|
||||
streamlit_app.button("Add Item", key=key + "-add-item")
|
||||
and input_data
|
||||
):
|
||||
current_list.append(input_data)
|
||||
|
||||
streamlit_app.write(current_list)
|
||||
return current_list
|
||||
|
||||
def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
|
||||
if schema_utils.is_single_enum_property(property, self._schema_references):
|
||||
return self._render_single_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_enum_property(property, self._schema_references):
|
||||
return self._render_multi_enum_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_file_property(property):
|
||||
return self._render_single_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_multi_file_property(property):
|
||||
return self._render_multi_file_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_datetime_property(property):
|
||||
return self._render_single_datetime_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_boolean_property(property):
|
||||
return self._render_single_boolean_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_dict_property(property):
|
||||
return self._render_single_dict_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_number_property(property):
|
||||
return self._render_single_number_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_string_property(property):
|
||||
return self._render_single_string_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_object(property, self._schema_references):
|
||||
return self._render_single_object_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_object_list_property(property, self._schema_references):
|
||||
return self._render_object_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_property_list(property):
|
||||
return self._render_property_list_input(streamlit_app, key, property)
|
||||
|
||||
if schema_utils.is_single_reference(property):
|
||||
return self._render_single_reference(streamlit_app, key, property)
|
||||
|
||||
streamlit_app.warning(
|
||||
"The type of the following property is currently not supported: "
|
||||
+ str(property.get("title"))
|
||||
)
|
||||
raise Exception("Unsupported property")
|
||||
|
||||
|
||||
class OutputUI:
|
||||
def __init__(self, output_data: Any, input_data: Any):
|
||||
self._output_data = output_data
|
||||
self._input_data = input_data
|
||||
|
||||
def render_ui(self, streamlit_app) -> None:
|
||||
try:
|
||||
if isinstance(self._output_data, BaseModel):
|
||||
self._render_single_output(streamlit_app, self._output_data)
|
||||
return
|
||||
if type(self._output_data) == list:
|
||||
self._render_list_output(streamlit_app, self._output_data)
|
||||
return
|
||||
except Exception as ex:
|
||||
streamlit_app.exception(ex)
|
||||
# Fallback to
|
||||
streamlit_app.json(jsonable_encoder(self._output_data))
|
||||
|
||||
def _render_single_text_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
streamlit.code(str(value), language="plain")
|
||||
|
||||
def _render_single_file_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
if value is None or value == "":
|
||||
streamlit.info("No value returned!")
|
||||
else:
|
||||
# TODO: Detect if it is a FileContent instance
|
||||
# TODO: detect if it is base64
|
||||
file_extension = ""
|
||||
if "mime_type" in property_schema:
|
||||
mime_type = property_schema["mime_type"]
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ""
|
||||
|
||||
if is_compatible_audio(mime_type):
|
||||
streamlit.audio(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
if is_compatible_image(mime_type):
|
||||
streamlit.image(value.as_bytes())
|
||||
return
|
||||
|
||||
if is_compatible_video(mime_type):
|
||||
streamlit.video(value.as_bytes(), format=mime_type)
|
||||
return
|
||||
|
||||
filename = (
|
||||
(property_schema["title"] + file_extension)
|
||||
.lower()
|
||||
.strip()
|
||||
.replace(" ", "-")
|
||||
)
|
||||
streamlit.markdown(
|
||||
f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def _render_single_complex_property(
|
||||
self, streamlit: st, property_schema: Dict, value: Any
|
||||
) -> None:
|
||||
# Add title and subheader
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
|
||||
streamlit.json(jsonable_encoder(value))
|
||||
|
||||
def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
|
||||
try:
|
||||
if has_output_ui_renderer(output_data):
|
||||
if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
|
||||
# render method also requests the input data
|
||||
output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
|
||||
else:
|
||||
output_data.render_output_ui(streamlit) # type: ignore
|
||||
return
|
||||
except Exception:
|
||||
# Use default auto-generation methods if the custom rendering throws an exception
|
||||
logger.exception(
|
||||
"Failed to execute custom render_output_ui function. Using auto-generation instead"
|
||||
)
|
||||
|
||||
model_schema = output_data.schema(by_alias=False)
|
||||
model_properties = model_schema.get("properties")
|
||||
definitions = model_schema.get("definitions")
|
||||
|
||||
if model_properties:
|
||||
for property_key in output_data.__dict__:
|
||||
property_schema = model_properties.get(property_key)
|
||||
if not property_schema.get("title"):
|
||||
# Set property key as fallback title
|
||||
property_schema["title"] = property_key
|
||||
|
||||
output_property_value = output_data.__dict__[property_key]
|
||||
|
||||
if has_output_ui_renderer(output_property_value):
|
||||
output_property_value.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
|
||||
if isinstance(output_property_value, BaseModel):
|
||||
# Render output recursivly
|
||||
streamlit.subheader(property_schema.get("title"))
|
||||
if property_schema.get("description"):
|
||||
streamlit.markdown(property_schema.get("description"))
|
||||
self._render_single_output(streamlit, output_property_value)
|
||||
continue
|
||||
|
||||
if property_schema:
|
||||
if schema_utils.is_single_file_property(property_schema):
|
||||
self._render_single_file_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
schema_utils.is_single_string_property(property_schema)
|
||||
or schema_utils.is_single_number_property(property_schema)
|
||||
or schema_utils.is_single_datetime_property(property_schema)
|
||||
or schema_utils.is_single_boolean_property(property_schema)
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
continue
|
||||
if definitions and schema_utils.is_single_enum_property(
|
||||
property_schema, definitions
|
||||
):
|
||||
self._render_single_text_property(
|
||||
streamlit, property_schema, output_property_value.value
|
||||
)
|
||||
continue
|
||||
|
||||
# TODO: render dict as table
|
||||
|
||||
self._render_single_complex_property(
|
||||
streamlit, property_schema, output_property_value
|
||||
)
|
||||
return
|
||||
|
||||
def _render_list_output(self, streamlit: st, output_data: List) -> None:
|
||||
try:
|
||||
data_items: List = []
|
||||
for data_item in output_data:
|
||||
if has_output_ui_renderer(data_item):
|
||||
# Render using the render function
|
||||
data_item.render_output_ui(streamlit) # type: ignore
|
||||
continue
|
||||
data_items.append(data_item.dict())
|
||||
# Try to show as dataframe
|
||||
streamlit.table(pd.DataFrame(data_items))
|
||||
except Exception:
|
||||
# Fallback to
|
||||
streamlit.json(jsonable_encoder(output_data))
|
||||
|
||||
|
||||
def getOpyrator(mode: str) -> Opyrator:
|
||||
if mode == None or mode.startswith('VC'):
|
||||
from mkgui.app_vc import convert
|
||||
return Opyrator(convert)
|
||||
if mode == None or mode.startswith('预处理'):
|
||||
from mkgui.preprocess import preprocess
|
||||
return Opyrator(preprocess)
|
||||
if mode == None or mode.startswith('模型训练'):
|
||||
from mkgui.train import train
|
||||
return Opyrator(train)
|
||||
from mkgui.app import synthesize
|
||||
return Opyrator(synthesize)
|
||||
|
||||
|
||||
def render_streamlit_ui() -> None:
|
||||
# init
|
||||
session_state = st.session_state
|
||||
session_state.input_data = {}
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
session_state.mode = st.sidebar.selectbox(
|
||||
'模式选择',
|
||||
( "AI拟音", "VC拟音", "预处理", "模型训练")
|
||||
)
|
||||
if "mode" in session_state:
|
||||
mode = session_state.mode
|
||||
else:
|
||||
mode = ""
|
||||
opyrator = getOpyrator(mode)
|
||||
title = opyrator.name + mode
|
||||
|
||||
col1, col2, _ = st.columns(3)
|
||||
col2.title(title)
|
||||
col2.markdown("欢迎使用MockingBird Web 2")
|
||||
|
||||
image = Image.open('.\\mkgui\\static\\mb.png')
|
||||
col1.image(image)
|
||||
|
||||
st.markdown("---")
|
||||
left, right = st.columns([0.4, 0.6])
|
||||
|
||||
with left:
|
||||
st.header("Control 控制")
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||
execute_selected = st.button(opyrator.action)
|
||||
if execute_selected:
|
||||
with st.spinner("Executing operation. Please wait..."):
|
||||
try:
|
||||
input_data_obj = parse_obj_as(
|
||||
opyrator.input_type, session_state.input_data
|
||||
)
|
||||
session_state.output_data = opyrator(input=input_data_obj)
|
||||
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
|
||||
except ValidationError as ex:
|
||||
st.error(ex)
|
||||
else:
|
||||
# st.success("Operation executed successfully.")
|
||||
pass
|
||||
|
||||
with right:
|
||||
st.header("Result 结果")
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(st)
|
||||
if st.button("Clear"):
|
||||
# Clear all state
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
# placeholder
|
||||
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||
|
||||
|
13
mkgui/base/ui/streamlit_utils.py
Normal file
13
mkgui/base/ui/streamlit_utils.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
CUSTOM_STREAMLIT_CSS = """
|
||||
div[data-testid="stBlock"] button {
|
||||
width: 100% !important;
|
||||
margin-bottom: 20px !important;
|
||||
border-color: #bfbfbf !important;
|
||||
}
|
||||
section[data-testid="stSidebar"] div {
|
||||
max-width: 10rem;
|
||||
}
|
||||
pre code {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
"""
|
96
mkgui/preprocess.py
Normal file
96
mkgui/preprocess.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
ENC_MODELS_DIRT = "encoder\\saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="目标模型",
|
||||
)
|
||||
dataset: Dataset = Field(
|
||||
Dataset.AIDATATANG_200ZH, title="数据集选择",
|
||||
)
|
||||
datasets_root: str = Field(
|
||||
..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
output_root: str = Field(
|
||||
..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
|
||||
format=True,
|
||||
example="..\\trainning_data\\"
|
||||
)
|
||||
n_processes: int = Field(
|
||||
2, alias="处理线程数", description="根据CPU线程数来设置",
|
||||
le=32, ge=1
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def preprocess(input: Input) -> Output:
|
||||
"""Preprocess(预处理)"""
|
||||
finished = 0
|
||||
if input.model == Model.VC_PPG2MEL:
|
||||
from ppg2mel.preprocess import preprocess_dataset
|
||||
finished = preprocess_dataset(
|
||||
datasets_root=Path(input.datasets_root),
|
||||
dataset=input.dataset,
|
||||
out_dir=Path(input.output_root),
|
||||
n_processes=input.n_processes,
|
||||
ppg_encoder_model_fpath=Path(input.extractor.value),
|
||||
speaker_encoder_model=Path(input.encoder.value)
|
||||
)
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, finished))
|
BIN
mkgui/static/mb.png
Normal file
BIN
mkgui/static/mb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 5.6 KiB |
156
mkgui/train.py
Normal file
156
mkgui/train.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
from pydantic import BaseModel, Field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from utils.util import AttrDict
|
||||
import torch
|
||||
|
||||
# TODO: seperator for *unix systems
|
||||
# Constants
|
||||
EXT_MODELS_DIRT = "ppg_extractor\\saved_models"
|
||||
CONV_MODELS_DIRT = "ppg2mel\\saved_models"
|
||||
ENC_MODELS_DIRT = "encoder\\saved_models"
|
||||
|
||||
|
||||
if os.path.isdir(EXT_MODELS_DIRT):
|
||||
extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded extractor models: " + str(len(extractors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(CONV_MODELS_DIRT):
|
||||
convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
|
||||
print("Loaded convertor models: " + str(len(convertors)))
|
||||
else:
|
||||
raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
if os.path.isdir(ENC_MODELS_DIRT):
|
||||
encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
|
||||
print("Loaded encoders models: " + str(len(encoders)))
|
||||
else:
|
||||
raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
|
||||
|
||||
class Model(str, Enum):
|
||||
VC_PPG2MEL = "ppg2mel"
|
||||
|
||||
class Dataset(str, Enum):
|
||||
AIDATATANG_200ZH = "aidatatang_200zh"
|
||||
AIDATATANG_200ZH_S = "aidatatang_200zh_s"
|
||||
|
||||
class Input(BaseModel):
|
||||
# def render_input_ui(st, input) -> Dict:
|
||||
# input["selected_dataset"] = st.selectbox(
|
||||
# '选择数据集',
|
||||
# ("aidatatang_200zh", "aidatatang_200zh_s")
|
||||
# )
|
||||
# return input
|
||||
model: Model = Field(
|
||||
Model.VC_PPG2MEL, title="模型类型",
|
||||
)
|
||||
# datasets_root: str = Field(
|
||||
# ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
|
||||
# format=True,
|
||||
# example="..\\trainning_data\\"
|
||||
# )
|
||||
output_root: str = Field(
|
||||
..., alias="输出目录(可选)", description="建议不填,保持默认",
|
||||
format=True,
|
||||
example=""
|
||||
)
|
||||
continue_mode: bool = Field(
|
||||
True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
|
||||
)
|
||||
gpu: bool = Field(
|
||||
True, alias="GPU训练", description="选择“是”,则使用GPU训练",
|
||||
)
|
||||
verbose: bool = Field(
|
||||
True, alias="打印详情", description="选择“是”,输出更多详情",
|
||||
)
|
||||
# TODO: Move to hiden fields by default
|
||||
convertor: convertors = Field(
|
||||
..., alias="转换模型",
|
||||
description="选择语音转换模型文件."
|
||||
)
|
||||
extractor: extractors = Field(
|
||||
..., alias="特征提取模型",
|
||||
description="选择PPG特征提取模型文件."
|
||||
)
|
||||
encoder: encoders = Field(
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件."
|
||||
)
|
||||
njobs: int = Field(
|
||||
8, alias="进程数", description="适用于ppg2mel",
|
||||
)
|
||||
seed: int = Field(
|
||||
default=0, alias="初始随机数", description="适用于ppg2mel",
|
||||
)
|
||||
model_name: str = Field(
|
||||
..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example="test"
|
||||
)
|
||||
model_config: str = Field(
|
||||
..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
|
||||
example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
mel: Any
|
||||
|
||||
class Output(BaseModel):
|
||||
__root__: tuple[str, int]
|
||||
|
||||
def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
|
||||
"""Custom output UI.
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
sr, count = self.__root__
|
||||
streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
|
||||
|
||||
def train(input: Input) -> Output:
|
||||
"""Train(训练)"""
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
params = AttrDict()
|
||||
params.update({
|
||||
"gpu": input.gpu,
|
||||
"cpu": not input.gpu,
|
||||
"njobs": input.njobs,
|
||||
"seed": input.seed,
|
||||
"verbose": input.verbose,
|
||||
"load": input.convertor.value,
|
||||
"warm_start": False,
|
||||
})
|
||||
if input.continue_mode:
|
||||
# trace old model and config
|
||||
p = Path(input.convertor.value)
|
||||
params.name = p.parent.name
|
||||
# search a config file
|
||||
model_config_fpaths = list(p.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
config = HpsYaml(model_config_fpaths[0])
|
||||
params.ckpdir = p.parent.parent
|
||||
params.config = model_config_fpaths[0]
|
||||
params.logdir = os.path.join(p.parent, "log")
|
||||
else:
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(input.config)
|
||||
np.random.seed(input.seed)
|
||||
torch.manual_seed(input.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(input.seed)
|
||||
mode = "train"
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
solver = Solver(config, params, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
|
||||
# TODO: pass useful return code
|
||||
return Output(__root__=(input.dataset, 0))
|
|
@ -191,12 +191,15 @@ class MelDecoderMOLv2(AbsMelDecoder):
|
|||
|
||||
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
|
||||
|
||||
def load_model(train_config, model_file, device=None):
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
if len(model_config_fpaths) == 0:
|
||||
raise "No model yaml config found for convertor"
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model_config = HpsYaml(train_config)
|
||||
model_config = HpsYaml(model_config_fpaths[0])
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
|
|
|
@ -110,3 +110,4 @@ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder
|
|||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
||||
return len(wav_file_list)
|
||||
|
|
|
@ -31,15 +31,10 @@ def main():
|
|||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
# parser.add_argument('--no-pin', action='store_true',
|
||||
# help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
|
|
|
@ -93,6 +93,7 @@ class BaseSolver():
|
|||
|
||||
def load_ckpt(self):
|
||||
''' Load ckpt if --load option is specified '''
|
||||
print(self.paras)
|
||||
if self.paras.load is not None:
|
||||
if self.paras.warm_start:
|
||||
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
|
||||
|
@ -100,7 +101,7 @@ class BaseSolver():
|
|||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
model_dict = ckpt['model']
|
||||
if len(self.config.model.ignore_layers) > 0:
|
||||
if "ignore_layers" in self.config.model and len(self.config.model.ignore_layers) > 0:
|
||||
model_dict = {k:v for k, v in model_dict.items()
|
||||
if k not in self.config.model.ignore_layers}
|
||||
dummy_dict = self.model.state_dict()
|
||||
|
|
|
@ -21,5 +21,7 @@ flask_cors==3.0.10
|
|||
gevent==21.8.0
|
||||
flask_restx
|
||||
tensorboard
|
||||
opyrator
|
||||
streamlit==1.3.1
|
||||
streamlit==1.8.0
|
||||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
|
@ -1,3 +0,0 @@
|
|||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
BIN
samples/T0055G0013S0005.wav
Normal file
BIN
samples/T0055G0013S0005.wav
Normal file
Binary file not shown.
|
@ -405,16 +405,11 @@ class Toolbox:
|
|||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_convertor_fpath
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
self.ui.log("Loading the convertor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
import ppg2mel as convertor
|
||||
self.convertor = convertor.load_model(model_config_fpath, model_fpath)
|
||||
self.convertor = convertor.load_model( model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
|
|
|
@ -42,3 +42,9 @@ def human_format(num):
|
|||
# add more suffixes if you need them
|
||||
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
|
||||
|
||||
|
||||
# provide easy access of attribute from dict, such abc.key
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
|
|
@ -1,13 +1,6 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from utils.util import AttrDict
|
||||
from vocoder.hifigan.models import Generator
|
||||
|
||||
generator = None # type: Generator
|
||||
|
@ -26,7 +26,11 @@ def load_model(weights_fpath, config_fpath=None, verbose=True):
|
|||
print("Building hifigan")
|
||||
|
||||
if config_fpath == None:
|
||||
config_fpath = "./vocoder/hifigan/config_16k_.json"
|
||||
model_config_fpaths = list(weights_fpath.parent.rglob("*.json"))
|
||||
if len(model_config_fpaths) > 0:
|
||||
config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
config_fpath = "./vocoder/hifigan/config_16k_.json"
|
||||
with open(config_fpath) as f:
|
||||
data = f.read()
|
||||
json_config = json.loads(data)
|
||||
|
|
|
@ -12,7 +12,6 @@ from torch.utils.data import DistributedSampler, DataLoader
|
|||
import torch.multiprocessing as mp
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from vocoder.hifigan.env import AttrDict, build_env
|
||||
from vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
|
||||
from vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\
|
||||
discriminator_loss
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from utils.argutils import print_args
|
||||
from vocoder.wavernn.train import train
|
||||
from vocoder.hifigan.train import train as train_hifigan
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from utils.util import AttrDict
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import json
|
||||
|
|
26
web.py
26
web.py
|
@ -1,11 +1,21 @@
|
|||
from web import webApp
|
||||
from gevent import pywsgi as wsgi
|
||||
import os
|
||||
import sys
|
||||
import typer
|
||||
|
||||
cli = typer.Typer()
|
||||
|
||||
@cli.command()
|
||||
def launch_ui(port: int = typer.Option(8080, "--port", "-p")) -> None:
|
||||
"""Start a graphical UI server for the opyrator.
|
||||
|
||||
The UI is auto-generated from the input- and output-schema of the given function.
|
||||
"""
|
||||
# Add the current working directory to the sys path
|
||||
# This is required to resolve the opyrator path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from mkgui.base.ui.streamlit_ui import launch_ui
|
||||
launch_ui(port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = webApp()
|
||||
host = app.config.get("HOST")
|
||||
port = app.config.get("PORT")
|
||||
print(f"Web server: http://{host}:{port}")
|
||||
server = wsgi.WSGIServer((host, port), app)
|
||||
server.serve_forever()
|
||||
cli()
|
Loading…
Reference in New Issue
Block a user