mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Make framework to support multiple pages
This commit is contained in:
parent
47cc597ad0
commit
c997dbdf66
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@ -4,7 +4,6 @@
|
|||||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
|
|
||||||
{
|
{
|
||||||
"name": "Python: Web",
|
"name": "Python: Web",
|
||||||
"type": "python",
|
"type": "python",
|
||||||
@ -68,8 +67,7 @@
|
|||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "mkgui\\base\\_cli.py",
|
"program": "mkgui\\base\\_cli.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"args": [ "mkgui.app:mocking_bird"
|
"args": []
|
||||||
]
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
13
mkgui/app.py
13
mkgui/app.py
@ -38,6 +38,9 @@ if os.path.isdir(VOC_MODELS_DIRT):
|
|||||||
|
|
||||||
|
|
||||||
class Input(BaseModel):
|
class Input(BaseModel):
|
||||||
|
message: str = Field(
|
||||||
|
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
|
||||||
|
)
|
||||||
local_audio_file: audio_input_selection = Field(
|
local_audio_file: audio_input_selection = Field(
|
||||||
..., alias="输入语音(本地wav)",
|
..., alias="输入语音(本地wav)",
|
||||||
description="选择本地语音文件."
|
description="选择本地语音文件."
|
||||||
@ -56,9 +59,6 @@ class Input(BaseModel):
|
|||||||
..., alias="语音编码模型",
|
..., alias="语音编码模型",
|
||||||
description="选择语音编码模型文件(目前只支持HifiGan类型)."
|
description="选择语音编码模型文件(目前只支持HifiGan类型)."
|
||||||
)
|
)
|
||||||
message: str = Field(
|
|
||||||
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="输出文本内容"
|
|
||||||
)
|
|
||||||
|
|
||||||
class AudioEntity(BaseModel):
|
class AudioEntity(BaseModel):
|
||||||
content: bytes
|
content: bytes
|
||||||
@ -72,7 +72,8 @@ class Output(BaseModel):
|
|||||||
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
If this method is implmeneted, it will be used instead of the default Output UI renderer.
|
||||||
"""
|
"""
|
||||||
src, result = self.__root__
|
src, result = self.__root__
|
||||||
streamlit_app.subheader("Result Audio")
|
|
||||||
|
streamlit_app.subheader("Synthesized Audio")
|
||||||
streamlit_app.audio(result.content, format="audio/wav")
|
streamlit_app.audio(result.content, format="audio/wav")
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
@ -85,8 +86,8 @@ class Output(BaseModel):
|
|||||||
streamlit_app.pyplot(fig)
|
streamlit_app.pyplot(fig)
|
||||||
|
|
||||||
|
|
||||||
def mocking_bird(input: Input) -> Output:
|
def main(input: Input) -> Output:
|
||||||
"""欢迎使用MockingBird Web 2"""
|
"""synthesize(合成)"""
|
||||||
# load models
|
# load models
|
||||||
encoder.load_model(Path(input.encoder.value))
|
encoder.load_model(Path(input.encoder.value))
|
||||||
current_synt = Synthesizer(Path(input.synthesizer.value))
|
current_synt = Synthesizer(Path(input.synthesizer.value))
|
||||||
|
@ -5,7 +5,7 @@ import typer
|
|||||||
cli = typer.Typer()
|
cli = typer.Typer()
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def launch_ui(opyrator: str, port: int = typer.Option(8051, "--port", "-p")) -> None:
|
def launch_ui(port: int = typer.Option(8051, "--port", "-p")) -> None:
|
||||||
"""Start a graphical UI server for the opyrator.
|
"""Start a graphical UI server for the opyrator.
|
||||||
|
|
||||||
The UI is auto-generated from the input- and output-schema of the given function.
|
The UI is auto-generated from the input- and output-schema of the given function.
|
||||||
@ -15,7 +15,7 @@ def launch_ui(opyrator: str, port: int = typer.Option(8051, "--port", "-p")) ->
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
from mkgui.base.ui.streamlit_ui import launch_ui
|
from mkgui.base.ui.streamlit_ui import launch_ui
|
||||||
launch_ui(opyrator, port)
|
launch_ui(port)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
@ -1,69 +0,0 @@
|
|||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from fastapi import FastAPI, status
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from starlette.responses import RedirectResponse
|
|
||||||
|
|
||||||
from mkgui.base import Opyrator
|
|
||||||
from mkgui.base.api.fastapi_utils import patch_fastapi
|
|
||||||
|
|
||||||
|
|
||||||
def launch_api(opyrator_path: str, port: int = 8501, host: str = "0.0.0.0") -> None:
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
from mkgui.base import Opyrator
|
|
||||||
from mkgui.base.api import create_api
|
|
||||||
|
|
||||||
app = create_api(Opyrator(opyrator_path))
|
|
||||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
|
||||||
|
|
||||||
|
|
||||||
def create_api(opyrator: Opyrator) -> FastAPI:
|
|
||||||
|
|
||||||
title = opyrator.name
|
|
||||||
if "opyrator" not in opyrator.name.lower():
|
|
||||||
title += " - Opyrator"
|
|
||||||
|
|
||||||
# TODO what about version?
|
|
||||||
app = FastAPI(title=title, description=opyrator.description)
|
|
||||||
|
|
||||||
patch_fastapi(app)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.post(
|
|
||||||
"/call",
|
|
||||||
operation_id="call",
|
|
||||||
response_model=opyrator.output_type,
|
|
||||||
# response_model_exclude_unset=True,
|
|
||||||
summary="Execute the opyrator.",
|
|
||||||
status_code=status.HTTP_200_OK,
|
|
||||||
)
|
|
||||||
def call(input: opyrator.input_type) -> Any: # type: ignore
|
|
||||||
"""Executes this opyrator."""
|
|
||||||
return opyrator(input)
|
|
||||||
|
|
||||||
@app.get(
|
|
||||||
"/info",
|
|
||||||
operation_id="info",
|
|
||||||
response_model=Dict,
|
|
||||||
# response_model_exclude_unset=True,
|
|
||||||
summary="Get info metadata.",
|
|
||||||
status_code=status.HTTP_200_OK,
|
|
||||||
)
|
|
||||||
def info() -> Any: # type: ignore
|
|
||||||
"""Returns informational metadata about this Opyrator."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Redirect to docs
|
|
||||||
@app.get("/", include_in_schema=False)
|
|
||||||
def root() -> Any:
|
|
||||||
return RedirectResponse("./docs")
|
|
||||||
|
|
||||||
return app
|
|
@ -113,7 +113,7 @@ class Opyrator:
|
|||||||
self.function = func
|
self.function = func
|
||||||
|
|
||||||
self._name = "Opyrator"
|
self._name = "Opyrator"
|
||||||
self._description = ""
|
self._action = "Execute"
|
||||||
self._input_type = None
|
self._input_type = None
|
||||||
self._output_type = None
|
self._output_type = None
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ class Opyrator:
|
|||||||
# Get description from function
|
# Get description from function
|
||||||
doc_string = inspect.getdoc(self.function)
|
doc_string = inspect.getdoc(self.function)
|
||||||
if doc_string:
|
if doc_string:
|
||||||
self._description = doc_string
|
self._action = doc_string
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
elif hasattr(self.function, "__call__"):
|
elif hasattr(self.function, "__call__"):
|
||||||
@ -155,19 +155,19 @@ class Opyrator:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get description from
|
# Get action from
|
||||||
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
doc_string = inspect.getdoc(self.function.__call__) # type: ignore
|
||||||
if doc_string:
|
if doc_string:
|
||||||
self._description = doc_string
|
self._action = doc_string
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not self._description
|
not self._action
|
||||||
or self._description == "Call self as a function."
|
or self._action == "Call"
|
||||||
):
|
):
|
||||||
# Get docstring from class instead of __call__ function
|
# Get docstring from class instead of __call__ function
|
||||||
doc_string = inspect.getdoc(self.function)
|
doc_string = inspect.getdoc(self.function)
|
||||||
if doc_string:
|
if doc_string:
|
||||||
self._description = doc_string
|
self._action = doc_string
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -178,8 +178,8 @@ class Opyrator:
|
|||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def action(self) -> str:
|
||||||
return self._description
|
return self._action
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_type(self) -> Any:
|
def input_type(self) -> Any:
|
||||||
|
@ -32,21 +32,20 @@ st.set_page_config(
|
|||||||
page_icon="🧊",
|
page_icon="🧊",
|
||||||
layout="wide")
|
layout="wide")
|
||||||
|
|
||||||
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
render_streamlit_ui()
|
||||||
opyrator = Opyrator("{opyrator_path}")
|
|
||||||
|
|
||||||
render_streamlit_ui(opyrator, action="{action}")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||||
|
# opyrator = Opyrator("{opyrator_path}")
|
||||||
|
|
||||||
def launch_ui(opyrator_path: str, port: int = 8501) -> None:
|
|
||||||
|
def launch_ui(port: int = 8501) -> None:
|
||||||
with NamedTemporaryFile(
|
with NamedTemporaryFile(
|
||||||
suffix=".py", mode="w", encoding="utf-8", delete=False
|
suffix=".py", mode="w", encoding="utf-8", delete=False
|
||||||
) as f:
|
) as f:
|
||||||
f.write(STREAMLIT_RUNNER_SNIPPET.format_map({'opyrator_path': opyrator_path, 'action': "Synthesize"}))
|
f.write(STREAMLIT_RUNNER_SNIPPET)
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
|
|
||||||
# TODO: PYTHONPATH="$PYTHONPATH:/workspace/opyrator/src"
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
|
||||||
@ -801,30 +800,42 @@ class OutputUI:
|
|||||||
streamlit.json(jsonable_encoder(output_data))
|
streamlit.json(jsonable_encoder(output_data))
|
||||||
|
|
||||||
|
|
||||||
def render_streamlit_ui(opyrator: Opyrator, action: str = "Execute") -> None:
|
def getOpyrator(mode: str) -> Opyrator:
|
||||||
title = opyrator.name
|
# if mode == None or mode.startswith('VC'):
|
||||||
|
# from mkgui.app_vc import vc
|
||||||
|
# return Opyrator(vc)
|
||||||
|
from mkgui.app import main
|
||||||
|
return Opyrator(main)
|
||||||
|
|
||||||
|
|
||||||
|
def render_streamlit_ui() -> None:
|
||||||
# init
|
# init
|
||||||
session_state = st.session_state
|
session_state = st.session_state
|
||||||
session_state.input_data = {}
|
session_state.input_data = {}
|
||||||
|
session_state.mode = None
|
||||||
|
|
||||||
|
with st.spinner("Loading MockingBird GUI. Please wait..."):
|
||||||
|
session_state.mode = st.sidebar.selectbox(
|
||||||
|
'模式选择',
|
||||||
|
("AI拟音", "VC拟音")
|
||||||
|
)
|
||||||
|
opyrator = getOpyrator(session_state.mode)
|
||||||
|
title = opyrator.name
|
||||||
|
|
||||||
col1, col2, _ = st.columns(3)
|
col1, col2, _ = st.columns(3)
|
||||||
col2.title(title)
|
col2.title(title)
|
||||||
|
col2.markdown("欢迎使用MockingBird Web 2")
|
||||||
|
|
||||||
image = Image.open('.\\mkgui\\static\\mb.png')
|
image = Image.open('.\\mkgui\\static\\mb.png')
|
||||||
col1.image(image)
|
col1.image(image)
|
||||||
|
|
||||||
# Add custom css settings
|
st.markdown("---")
|
||||||
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
left, right = st.columns([0.4, 0.6])
|
||||||
|
|
||||||
if opyrator.description:
|
|
||||||
st.markdown(opyrator.description)
|
|
||||||
|
|
||||||
left, right = st.columns([0.3, 0.8])
|
|
||||||
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(left)
|
|
||||||
|
|
||||||
|
|
||||||
with left:
|
with left:
|
||||||
execute_selected = st.button(action)
|
st.header("Control 控制")
|
||||||
|
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
||||||
|
execute_selected = st.button(opyrator.action)
|
||||||
if execute_selected:
|
if execute_selected:
|
||||||
with st.spinner("Executing operation. Please wait..."):
|
with st.spinner("Executing operation. Please wait..."):
|
||||||
try:
|
try:
|
||||||
@ -838,18 +849,23 @@ def render_streamlit_ui(opyrator: Opyrator, action: str = "Execute") -> None:
|
|||||||
else:
|
else:
|
||||||
# st.success("Operation executed successfully.")
|
# st.success("Operation executed successfully.")
|
||||||
pass
|
pass
|
||||||
if st.button("Clear"):
|
|
||||||
|
with right:
|
||||||
|
st.header("Result 结果")
|
||||||
|
if 'output_data' in session_state:
|
||||||
|
OutputUI(
|
||||||
|
session_state.output_data, session_state.latest_operation_input
|
||||||
|
).render_ui(st)
|
||||||
|
if st.button("Clear"):
|
||||||
# Clear all state
|
# Clear all state
|
||||||
for key in st.session_state.keys():
|
for key in st.session_state.keys():
|
||||||
del st.session_state[key]
|
del st.session_state[key]
|
||||||
session_state.input_data = {}
|
session_state.input_data = {}
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
else:
|
||||||
|
# placeholder
|
||||||
if 'output_data' in session_state:
|
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
||||||
OutputUI(
|
|
||||||
session_state.output_data, session_state.latest_operation_input
|
# Add custom css settings
|
||||||
).render_ui(right)
|
st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
|
||||||
|
|
||||||
# st.markdown("---")
|
|
||||||
|
|
||||||
|
@ -4,6 +4,9 @@ div[data-testid="stBlock"] button {
|
|||||||
margin-bottom: 20px !important;
|
margin-bottom: 20px !important;
|
||||||
border-color: #bfbfbf !important;
|
border-color: #bfbfbf !important;
|
||||||
}
|
}
|
||||||
|
section[data-testid="stSidebar"] div {
|
||||||
|
max-width: 10rem;
|
||||||
|
}
|
||||||
pre code {
|
pre code {
|
||||||
white-space: pre-wrap;
|
white-space: pre-wrap;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user