randomblock1 commited on
Commit
f0fc394
·
verified ·
1 Parent(s): f680d7e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import gradio as gr
4
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
5
+
6
+ # device setup
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ # load model + processor
10
+ model_name = "ibm-granite/granite-speech-3.3-8b"
11
+ processor = AutoProcessor.from_pretrained(model_name)
12
+ tokenizer = processor.tokenizer
13
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
14
+ model_name, device_map=device, torch_dtype=torch.bfloat16
15
+ )
16
+
17
+ today_str = date.today().strftime("%B %d, %Y")
18
+
19
+ system_prompt = (
20
+ "Knowledge Cutoff Date: April 2024.\n"
21
+ f"Today's Date: {today_str}.\n"
22
+ "You are Granite, developed by IBM. You are a helpful AI assistant."
23
+ )
24
+
25
+ def transcribe(audio_file):
26
+ # load wav file
27
+ wav, sr = torchaudio.load(audio_file, normalize=True)
28
+ if wav.shape[0] != 1 or sr != 16000:
29
+ # resample + convert to mono if needed
30
+ wav = torch.mean(wav, dim=0, keepdim=True) # mono
31
+ wav = torchaudio.functional.resample(wav, sr, 16000)
32
+ sr = 16000
33
+
34
+ # user prompt
35
+ user_prompt = "<|audio|>can you transcribe the speech into a written format?"
36
+ chat = [
37
+ dict(role="system", content=system_prompt),
38
+ dict(role="user", content=user_prompt),
39
+ ]
40
+ prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
41
+
42
+ # run model
43
+ model_inputs = processor(prompt, wav, sampling_rate=sr, device=device, return_tensors="pt").to(device)
44
+ model_outputs = model.generate(
45
+ **model_inputs,
46
+ max_new_tokens=200,
47
+ do_sample=False,
48
+ num_beams=1
49
+ )
50
+
51
+ # strip prompt tokens
52
+ num_input_tokens = model_inputs["input_ids"].shape[-1]
53
+ new_tokens = torch.unsqueeze(model_outputs[0, num_input_tokens:], dim=0)
54
+ output_text = tokenizer.batch_decode(
55
+ new_tokens, add_special_tokens=False, skip_special_tokens=True
56
+ )
57
+
58
+ return output_text[0].strip()
59
+
60
+ # Gradio UI
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("## Granite 3.3 Speech-to-Text Demo")
63
+
64
+ with gr.Row():
65
+ audio_input = gr.Audio(type="filepath", label="Upload Audio (16kHz mono preferred)")
66
+ output_text = gr.Textbox(label="Transcription", lines=5)
67
+
68
+ transcribe_btn = gr.Button("Transcribe")
69
+ transcribe_btn.click(fn=transcribe, inputs=audio_input, outputs=output_text)
70
+
71
+ demo.launch()