import torch
import torchaudio
import gradio as gr
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from datetime import date
# device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model + processor
model_name = "ibm-granite/granite-speech-3.3-8b"
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = processor.tokenizer
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)
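# Note: bfloat16 keeps GPU memory use low; on a CPU-only machine you may prefer
# torch_dtype=torch.float32 instead (an assumption, not part of the original Space).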

def transcribe(audio_file, user_prompt):
    # load wav file
    wav, sr = torchaudio.load(audio_file, normalize=True)
    if wav.shape[0] != 1 or sr != 16000:
        # resample + convert to mono if needed
        wav = torch.mean(wav, dim=0, keepdim=True)  # mono
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    today_str = date.today().strftime("%B %d, %Y")
    system_prompt = (
        "Knowledge Cutoff Date: April 2024.\n"
        f"Today's Date: {today_str}.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant."
    )
    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=f"<|audio|>{user_prompt}"),
    ]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # run model
    model_inputs = processor(
        prompt,
        wav,
        device=device,
        return_tensors="pt",
    ).to(device)
    model_outputs = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=False,
        num_beams=1,
    )

    # strip prompt tokens
    num_input_tokens = model_inputs["input_ids"].shape[-1]
    new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
    output_text = tokenizer.batch_decode(
        new_tokens, add_special_tokens=False, skip_special_tokens=True
    )
    return output_text[0].strip()
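
# Optional local sanity check, kept commented out so the Space still launches cleanly.
# "sample_16khz.wav" is a hypothetical path; substitute any local audio file.
#   text = transcribe("sample_16khz.wav",
#                     "Can you transcribe the speech into a written format?")
#   print(text)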

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Granite 3.3 Speech-to-Text")
    gr.Markdown(
        "Upload an audio file and Granite Speech 3.3 8B will transcribe it into text. "
        "You can also edit the prompt below to customize what Granite should do "
        "with the audio, such as translation."
    )
    with gr.Row():
        audio_input = gr.Audio(
            type="filepath", label="Upload Audio (16 kHz mono preferred)"
        )
        output_text = gr.Textbox(label="Transcription", lines=5)
    user_prompt = gr.Textbox(
        label="User Prompt",
        value="Can you transcribe the speech into a written format?",
        lines=2,
    )
    transcribe_btn = gr.Button("Transcribe")
    transcribe_btn.click(
        fn=transcribe,
        inputs=[audio_input, user_prompt],
        outputs=output_text,
    )

demo.launch()
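
# To run this Space locally, something like the following should work
# (a sketch; the exact package list and versions are assumptions):
#   pip install torch torchaudio transformers peft soundfile gradio
#   python app.py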