mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Make framework to support multiple pages
This commit is contained in:
parent
47cc597ad0
commit
c997dbdf66
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@ -4,7 +4,6 @@
|
||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
|
||||
{
|
||||
"name": "Python: Web",
|
||||
"type": "python",
|
||||
@ -68,8 +67,7 @@
|
||||
"request": "launch",
|
||||
"program": "mkgui\\base\\_cli.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": [ "mkgui.app:mocking_bird"
|
||||
]
|
||||
"args": []
|
||||
},
|
||||
]
|
||||
}
|
||||
|
13
mkgui/app.py
13
mkgui/app.py
@ -38,6 +38,9 @@ if os.path.isdir(VOC_MODELS_DIRT):
|
||||
|
||||
|
||||
class Input(BaseModel):
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||
)
|
||||
local_audio_file: audio_input_selection = Field(
|
||||
..., alias="输入语音(本地wav)",
|
||||
description="选择本地语音文件."
|
||||
@ -56,9 +59,6 @@ class Input(BaseModel):
|
||||
..., alias="语音编码模型",
|
||||
description="选择语音编码模型文件(目前只支持HifiGan类型)."
|
||||
)
|
||||
message: str = Field(
|
||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="输出文本内容"
|
||||
)
|
||||
|
||||
class AudioEntity(BaseModel):
|
||||
content: bytes
|
||||
@ -72,7 +72,8 @@ class Output(BaseModel):
|
||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||
"""
|
||||
src, result = self.__root__
|
||||
streamlit_app.subheader("Result Audio")
|
||||
|
||||
streamlit_app.subheader("Synthesized Audio")
|
||||
streamlit_app.audio(result.content, format="audio/wav")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
@ -85,8 +86,8 @@ class Output(BaseModel):
|
||||
streamlit_app.pyplot(fig)
|
||||
|
||||
|
||||
def mocking_bird(input: Input) -> Output:
|
||||
"""欢迎使用MockingBird Web 2"""
|
||||
def main(input: Input) -> Output:
|
||||
"""synthesize(合成)"""
|
||||
# load models
|
||||
encoder.load_model(Path(input.encoder.value))
|
||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||
|
@ -5,7 +5,7 @@ import typer
|
||||
cli = typer.Typer()
|
||||
|
||||
@cli.command()
|
||||
def launch_ui(opyrator: str, port: int = typer.Option(8051, "--port", "-p")) -> None:
|
||||
def launch_ui(port: int = typer.Option(8051, "--port", "-p")) -> None:
|
||||
"""Start a graphical UI server for the opyrator.
|
||||
|
||||
The UI is auto-generated from the input- and output-schema of the given function.
|
||||
@ -15,7 +15,7 @@ def launch_ui(opyrator: str, port: int = typer.Option(8051, "--port", "-p")) ->
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from mkgui.base.ui.streamlit_ui import launch_ui
|
||||
launch_ui(opyrator, port)
|
||||
launch_ui(port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
@ -1,69 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from mkgui.base import Opyrator
|
||||
from mkgui.base.api.fastapi_utils import patch_fastapi
|
||||
|
||||
|
||||
def launch_api(opyrator_path: str, port: int = 8501, host: str = "0.0.0.0") -> None:
|
||||
import uvicorn
|
||||
|
||||
from mkgui.base import Opyrator
|
||||
from mkgui.base.api import create_api
|
||||
|
||||
app = create_api(Opyrator(opyrator_path))
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
def create_api(opyrator: Opyrator) -> FastAPI:
|
||||
|
||||
title = opyrator.name
|
||||
if "opyrator" not in opyrator.name.lower():
|
||||
title += " - Opyrator"
|
||||
|
||||
# TODO what about version?
|
||||
app = FastAPI(title=title, description=opyrator.description)
|
||||
|
||||
patch_fastapi(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/call",
|
||||
operation_id="call",
|
||||
response_model=opyrator.output_type,
|
||||
# response_model_exclude_unset=True,
|
||||
summary="Execute the opyrator.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
def call(input: opyrator.input_type) -> Any: # type: ignore
|
||||
"""Executes this opyrator."""
|
||||
return opyrator(input)
|
||||
|
||||
@app.get(
|
||||
"/info",
|
||||
operation_id="info",
|
||||
response_model=Dict,
|
||||
# response_model_exclude_unset=True,
|
||||
summary="Get info metadata.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
def info() -> Any: # type: ignore
|
||||
"""Returns informational metadata about this Opyrator."""
|
||||
return {}
|
||||
|
||||
# Redirect to docs
|
||||
@app.get("/", include_in_schema=False)
|
||||
def root() -> Any:
|
||||
return RedirectResponse("./docs")
|
||||
|
||||
return app
|
@ -113,7 +113,7 @@ class Opyrator:
|
||||
self.function = func
|
||||
|
||||
self._name = "Opyrator"
|
||||
self._description = ""
|
||||
self._action = "Execute"
|
||||
self._input_type = None
|
||||
self._output_type = None
|
||||
|
||||
@ -140,7 +140,7 @@ class Opyrator:
|
||||
# Get description from function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._description = doc_string
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
elif hasattr(self.function, "__call__"):
|
||||
@ -155,19 +155,19 @@ class Opyrator:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Get description from
|
||||
# Get action from
|
||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||
if doc_string:
|
||||
self._description = doc_string
|
||||
self._action = doc_string
|
||||
|
||||
if (
|
||||
not self._description
|
||||
or self._description == "Call self as a function."
|
||||
not self._action
|
||||
or self._action == "Call"
|
||||
):
|
||||
# Get docstring from class instead of __call__ function
|
||||
doc_string = inspect.getdoc(self.function)
|
||||
if doc_string:
|
||||
self._description = doc_string
|
||||
self._action = doc_string
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
@ -178,8 +178,8 @@ class Opyrator:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
def action(self) -> str:
|
||||
return self._action
|
||||
|
||||
@property
|
||||
def input_type(self) -> Any:
|
||||
|
@ -32,21 +32,20 @@ st.set_page_config(
|
||||
page_icon="🧊",
|
||||
layout="wide")
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
render_streamlit_ui(opyrator, action="{action}")
|
||||
render_streamlit_ui()
|
||||
"""
|
||||
|
||||
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
# opyrator = Opyrator("{opyrator_path}")
|
||||
|
||||
def launch_ui(opyrator_path: str, port: int = 8501) -> None:
|
||||
|
||||
def launch_ui(port: int = 8501) -> None:
|
||||
with NamedTemporaryFile(
|
||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||
) as f:
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET.format_map({'opyrator_path': opyrator_path, 'action': "Synthesize"}))
|
||||
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||
f.seek(0)
|
||||
|
||||
# TODO: PYTHONPATH="$PYTHONPATH:/workspace/opyrator/src"
|
||||
import subprocess
|
||||
|
||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||
@ -801,30 +800,42 @@ class OutputUI:
|
||||
streamlit.json(jsonable_encoder(output_data))
|
||||
|
||||
|
||||
def render_streamlit_ui(opyrator: Opyrator, action: str = "Execute") -> None:
|
||||
title = opyrator.name
|
||||
def getOpyrator(mode: str) -> Opyrator:
|
||||
# if mode == None or mode.startswith('VC'):
|
||||
# from mkgui.app_vc import vc
|
||||
# return Opyrator(vc)
|
||||
from mkgui.app import main
|
||||
return Opyrator(main)
|
||||
|
||||
|
||||
def render_streamlit_ui() -> None:
|
||||
# init
|
||||
session_state = st.session_state
|
||||
session_state.input_data = {}
|
||||
session_state.mode = None
|
||||
|
||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||
session_state.mode = st.sidebar.selectbox(
|
||||
'模式选择',
|
||||
("AI拟音", "VC拟音")
|
||||
)
|
||||
opyrator = getOpyrator(session_state.mode)
|
||||
title = opyrator.name
|
||||
|
||||
col1, col2, _ = st.columns(3)
|
||||
col2.title(title)
|
||||
col2.markdown("欢迎使用MockingBird Web 2")
|
||||
|
||||
image = Image.open('.\\mkgui\\static\\mb.png')
|
||||
col1.image(image)
|
||||
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
if opyrator.description:
|
||||
st.markdown(opyrator.description)
|
||||
|
||||
left, right = st.columns([0.3, 0.8])
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(left)
|
||||
|
||||
st.markdown("---")
|
||||
left, right = st.columns([0.4, 0.6])
|
||||
|
||||
with left:
|
||||
execute_selected = st.button(action)
|
||||
st.header("Control 控制")
|
||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||
execute_selected = st.button(opyrator.action)
|
||||
if execute_selected:
|
||||
with st.spinner("Executing operation. Please wait..."):
|
||||
try:
|
||||
@ -838,18 +849,23 @@ def render_streamlit_ui(opyrator: Opyrator, action: str = "Execute") -> None:
|
||||
else:
|
||||
# st.success("Operation executed successfully.")
|
||||
pass
|
||||
if st.button("Clear"):
|
||||
|
||||
with right:
|
||||
st.header("Result 结果")
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(st)
|
||||
if st.button("Clear"):
|
||||
# Clear all state
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
|
||||
|
||||
if 'output_data' in session_state:
|
||||
OutputUI(
|
||||
session_state.output_data, session_state.latest_operation_input
|
||||
).render_ui(right)
|
||||
|
||||
# st.markdown("---")
|
||||
for key in st.session_state.keys():
|
||||
del st.session_state[key]
|
||||
session_state.input_data = {}
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
# placeholder
|
||||
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||
|
||||
# Add custom css settings
|
||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||
|
||||
|
@ -4,6 +4,9 @@ div[data-testid="stBlock"] button {
|
||||
margin-bottom: 20px !important;
|
||||
border-color: #bfbfbf !important;
|
||||
}
|
||||
section[data-testid="stSidebar"] div {
|
||||
max-width: 10rem;
|
||||
}
|
||||
pre code {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user