Spaces:
Running
Running
import streamlit as st | |
import torch | |
import tempfile | |
import os | |
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline | |
from audiorecorder import audiorecorder | |
from pydub import AudioSegment | |
# Model configuration: run on the first CUDA GPU in half precision when one
# is available, otherwise fall back to the CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32
model_id = "KBLab/kb-whisper-tiny"
@st.cache_resource
def load_model():
    """Build the automatic-speech-recognition pipeline for ``model_id``.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps a single pipeline per server process so the
    model weights are downloaded/loaded once instead of on every rerun.

    Returns:
        A ``transformers`` "automatic-speech-recognition" pipeline whose model
        lives on ``device`` with weights in ``torch_dtype``.
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, use_safetensors=True, cache_dir="cache"
    )
    model.to(device)
    # Use the same local cache directory as the model so every downloaded
    # artifact (weights, tokenizer, feature extractor) ends up in one place.
    processor = AutoProcessor.from_pretrained(model_id, cache_dir="cache")
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


pipe = load_model()
def transcribe_audio(audio_path, *, language="sv", chunk_length_s=30):
    """Transcribe the audio file at ``audio_path`` with the global ``pipe``.

    Args:
        audio_path: Path to an audio file the pipeline can read.
        language: Language code passed to generation (``"sv"``, Swedish, by
            default — the value this app has always used).
        chunk_length_s: Chunk length in seconds forwarded to the pipeline for
            long-form audio (30 by default, as before).

    Returns:
        The pipeline's output; the transcription is under its ``"text"`` key.
    """
    return pipe(
        audio_path,
        chunk_length_s=chunk_length_s,
        generate_kwargs={"task": "transcribe", "language": language},
    )
def _recording_to_wav_bytes(recording):
    """Return the bytes of an audio file for a value returned by ``audiorecorder``.

    A pydub ``AudioSegment`` has no ``tobytes`` method, so it is encoded to WAV
    explicitly. Any other value falls back to ``.tobytes()``, which is what this
    app originally relied on (NOTE(review): presumably an array-like result from
    older streamlit-audiorecorder releases — confirm against the pinned version).
    """
    if isinstance(recording, AudioSegment):
        with recording.export(format="wav") as wav_file:
            return wav_file.read()
    return recording.tobytes()


@st.cache_data(show_spinner="Transcribing...")
def _transcribe_bytes(audio_bytes, suffix):
    """Transcribe in-memory audio and return the recognised text.

    ``transcribe_audio`` takes a path, so the bytes are written to a temporary
    file with extension ``suffix``. The file is created with ``delete=False`` and
    closed before use so it can be reopened by name on every platform, and it is
    removed in ``finally`` even when transcription raises. ``st.cache_data``
    memoises the result per (audio_bytes, suffix), so a script rerun does not
    run the model again on audio that was already transcribed.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(audio_bytes)
        temp_file_path = temp_file.name
    try:
        return transcribe_audio(temp_file_path)["text"]
    finally:
        os.remove(temp_file_path)


def _show_transcription(audio_bytes, suffix, audio_format):
    """Render an audio player for ``audio_bytes`` followed by its transcription."""
    st.audio(audio_bytes, format=audio_format)
    st.write("### Transcription:")
    st.write(_transcribe_bytes(audio_bytes, suffix))


st.title("Speech-to-Text Transcription")

# Audio recording
st.subheader("Record Audio")
recorded_audio = audiorecorder("Start Recording", "Stop Recording")
# Use len() rather than truthiness: it is 0 for an empty AudioSegment and,
# unlike ``if array:``, is also well-defined for a multi-element array.
if len(recorded_audio) > 0:
    _show_transcription(_recording_to_wav_bytes(recorded_audio), ".wav", "audio/wav")

# File upload
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg", "flac"])
if uploaded_file is not None:
    _show_transcription(
        # getvalue() returns the whole upload regardless of the current read
        # position, which read() would not if the object was already consumed.
        uploaded_file.getvalue(),
        os.path.splitext(uploaded_file.name)[-1],
        # Use the browser-reported MIME type so non-WAV uploads play correctly.
        uploaded_file.type or "audio/wav",
    )