mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
|
from typing import List
|
||
|
|
||
|
from pydantic import BaseModel
|
||
|
|
||
|
|
||
|
class ScoredLabel(BaseModel):
|
||
|
label: str
|
||
|
score: float
|
||
|
|
||
|
|
||
|
class ClassificationOutput(BaseModel):
|
||
|
__root__: List[ScoredLabel]
|
||
|
|
||
|
def __iter__(self): # type: ignore
|
||
|
return iter(self.__root__)
|
||
|
|
||
|
def __getitem__(self, item): # type: ignore
|
||
|
return self.__root__[item]
|
||
|
|
||
|
def render_output_ui(self, streamlit) -> None: # type: ignore
|
||
|
import plotly.express as px
|
||
|
|
||
|
sorted_predictions = sorted(
|
||
|
[prediction.dict() for prediction in self.__root__],
|
||
|
key=lambda k: k["score"],
|
||
|
)
|
||
|
|
||
|
num_labels = len(sorted_predictions)
|
||
|
if len(sorted_predictions) > 10:
|
||
|
num_labels = streamlit.slider(
|
||
|
"Maximum labels to show: ",
|
||
|
min_value=1,
|
||
|
max_value=len(sorted_predictions),
|
||
|
value=len(sorted_predictions),
|
||
|
)
|
||
|
fig = px.bar(
|
||
|
sorted_predictions[len(sorted_predictions) - num_labels :],
|
||
|
x="score",
|
||
|
y="label",
|
||
|
orientation="h",
|
||
|
)
|
||
|
streamlit.plotly_chart(fig, use_container_width=True)
|
||
|
# fig.show()
|