|
|
|
|
|
|
|
from tempfile import NamedTemporaryFile |
|
from typing import Any |
|
|
|
import streamlit as st |
|
|
|
from conette import CoNeTTEModel, conette |
|
|
|
|
|
@st.cache_resource
def load_conette(*args, **kwargs) -> CoNeTTEModel:
    """Build a CoNeTTE model and cache it across Streamlit reruns.

    All positional and keyword arguments are forwarded unchanged to
    :func:`conette`. ``st.cache_resource`` ensures the (expensive) model
    construction happens only once per set of arguments.
    """
    model: CoNeTTEModel = conette(*args, **kwargs)
    return model
|
|
|
|
|
def main() -> None:
    """Render the Streamlit page: upload audio files and caption them with CoNeTTE.

    Loads the model on CPU (cached across reruns), exposes the generation
    options as widgets, and memoizes each produced caption in
    ``st.session_state`` so a rerun does not re-invoke the model.
    """
    st.header("Describe audio content with CoNeTTE")

    model = load_conette(model_kwds=dict(device="cpu"))

    task = st.selectbox("Task embedding input", model.tasks, 0)
    beam_size: int = st.select_slider(
        "Beam size",
        list(range(1, 20)),
        model.config.beam_size,
    )
    min_pred_size: int = st.select_slider(
        "Minimal number of words",
        list(range(1, 31)),
        model.config.min_pred_size,
    )
    max_pred_size: int = st.select_slider(
        "Maximal number of words",
        list(range(1, 31)),
        model.config.max_pred_size,
    )

    # Typo fixed: "Recommanded" -> "Recommended".
    st.write("Recommended audio: lasting from 1s to 30s, sampled at 32 kHz.")
    audios = st.file_uploader(
        "Upload an audio file",
        type=["wav", "flac", "mp3", "ogg", "avi"],
        accept_multiple_files=True,
    )

    if not audios:
        return

    # Generation options are loop-invariant; build them once.
    kwargs: dict[str, Any] = dict(
        task=task,
        beam_size=beam_size,
        min_pred_size=min_pred_size,
        max_pred_size=max_pred_size,
    )

    for audio in audios:
        # The key includes both the file name and the generation options,
        # so changing any option invalidates the memoized caption.
        cand_key = f"{audio.name}-{kwargs}"

        if cand_key in st.session_state:
            cand = st.session_state[cand_key]
        else:
            # Only materialize the temp file on a cache miss. Flush the
            # buffered bytes before the model opens the file by path,
            # otherwise it may read an incomplete file; keep the model call
            # inside the `with` so the file still exists when it is read.
            with NamedTemporaryFile() as temp:
                temp.write(audio.getvalue())
                temp.flush()
                outputs = model(temp.name, **kwargs)
            cand = outputs["cands"][0]
            st.session_state[cand_key] = cand

        st.write(f"Output for {audio.name}:")
        st.write(" - ", cand)
|
|
|
|
|
# Entry point: allows running this page directly (e.g. `streamlit run <file>`).
if __name__ == "__main__":
    main()
|
|