randomblock1's picture
Update app.py
26082df verified
import torch
import torchaudio
import gradio as gr
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from datetime import date
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face Hub id of the speech model served by this app.
model_name = "ibm-granite/granite-speech-3.3-8b"

# The processor bundles the audio feature pipeline and the text tokenizer;
# the tokenizer is kept separately for chat templating and decoding.
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer

# Weights are loaded in bfloat16 and placed directly on the selected device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
def _load_mono_16k(audio_file, target_sr=16000):
    """Load *audio_file* and return a single-channel waveform at *target_sr* Hz.

    Multi-channel audio is averaged down to one channel, and audio recorded at a
    different sample rate is resampled. Each step runs only when it is needed.
    """
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1:
        # Average across channels; keepdim=True keeps a leading axis of size 1.
        wav = torch.mean(wav, dim=0, keepdim=True)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    return wav


def _build_chat_prompt(user_prompt):
    """Render the Granite chat template for one audio clip followed by *user_prompt*."""
    if user_prompt is None:
        # Avoid the literal text "None" ending up in the prompt.
        user_prompt = ""
    today_str = date.today().strftime("%B %d, %Y")
    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {today_str}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant."
    )
    chat = [
        dict(role="system", content=system_prompt),
        # "<|audio|>" is the audio placeholder used in the user turn
        # (presumably expanded by the processor; see the model card).
        dict(role="user", content=f"<|audio|>{user_prompt}"),
    ]
    return tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True)


def transcribe(audio_file, user_prompt):
    """Run Granite Speech on an uploaded audio file and return the model's reply.

    Args:
        audio_file: Path to the audio file. Gradio's ``gr.Audio(type="filepath")``
            passes a path string, or ``None`` when nothing was uploaded.
        user_prompt: Instruction text placed after the audio placeholder,
            e.g. a transcription or translation request.

    Returns:
        The text generated after the prompt, with special tokens removed and
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no audio file was provided, so the UI shows a readable
            message instead of a loader traceback.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file first.")

    wav = _load_mono_16k(audio_file)
    prompt = _build_chat_prompt(user_prompt)

    # run model
    model_inputs = processor(
        prompt,
        wav,
        device=device,
        return_tensors="pt",
    ).to(device)
    model_outputs = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,  # greedy decoding (with num_beams=1)
        num_beams=1,
    )

    # The output row starts with the prompt tokens; keep only the new ones.
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = model_outputs[0, num_input_tokens:]
    output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return output_text.strip()
# Gradio UI: audio upload + editable prompt -> transcribe() -> text output.
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3 Speech-to-Text")
    gr.Markdown(
        # Adjacent literals are concatenated, so the trailing space is required.
        "Upload an audio file and Granite Speech 3.3 8b will transcribe it into text. "
        "You can also edit the prompt below to customize what Granite should do with the audio, like translation."
    )
    with gr.Row():
        # type="filepath" hands transcribe() a path string for torchaudio.load.
        audio_input = gr.Audio(type="filepath",
                               label="Upload Audio (16kHz mono preferred)")
        output_text = gr.Textbox(label="Transcription", lines=5)
    user_prompt = gr.Textbox(
        label="User Prompt",
        value="Can you transcribe the speech into a written format?",
        lines=2,
    )
    transcribe_btn = gr.Button("Transcribe")
    # Input components are passed positionally: (audio_file, user_prompt).
    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )

demo.launch()