fdaudens committed
Commit 64001ff · verified · 1 Parent(s): a6fcf14

Create app.py

Files changed (1)
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
+ import gradio as gr
+ import spaces
+ import torch
+ from pydub import AudioSegment
+ import numpy as np
+ import io
+ from scipy.io import wavfile
+ from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
+ from transformers.utils.import_utils import is_flash_attn_2_available
+ import base64
+ from scipy.io.wavfile import write
+ import os
+
+ # Global model variables
+ model = None
+ processor = None
+
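+ # Lazy singleton loader: on ZeroGPU Spaces a GPU is only attached inside
+ # @spaces.GPU-decorated calls, so the model stays on CPU until then.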
+ def load_model():
+     """Load model and processor once"""
+     global model, processor
+     if model is None:
+         model = ColQwen2_5Omni.from_pretrained(
+             "vidore/colqwen-omni-v0.1",
+             torch_dtype=torch.bfloat16,
+             device_map="cpu",  # Start on CPU for ZeroGPU
+             attn_implementation="eager"  # ZeroGPU compatible
+         ).eval()
+         processor = ColQwen2_5OmniProcessor.from_pretrained("manu/colqwen-omni-v0.1")
+     return model, processor
+
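+ # Each chunk (mono, 16 kHz WAV) becomes one retrievable "document" for the search step.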
+ def chunk_audio(audio_file, chunk_length=30):
+     """Split audio into chunks"""
+     # gr.Audio(type="filepath") passes a plain path string; fall back to .name for file-like objects
+     audio_path = audio_file if isinstance(audio_file, str) else audio_file.name
+     audio = AudioSegment.from_file(audio_path)
+
+     audios = []
+     target_rate = 16000
+     chunk_length_ms = chunk_length * 1000
+
+     for i in range(0, len(audio), chunk_length_ms):
+         chunk = audio[i:i + chunk_length_ms]
+         chunk = chunk.set_channels(1).set_frame_rate(target_rate)
+
+         buf = io.BytesIO()
+         chunk.export(buf, format="wav")
+         buf.seek(0)
+
+         rate, data = wavfile.read(buf)
+         audios.append(data)
+
+     return audios
+
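+ # Embedding runs on the ZeroGPU-allocated GPU; each chunk gets a multi-vector
+ # (ColPali-style late-interaction) embedding rather than a single pooled vector.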
+ @spaces.GPU(duration=120)
+ def embed_audio_chunks(audios):
+     """Embed audio chunks using GPU"""
+     model, processor = load_model()
+     model = model.to('cuda')
+
+     # Process in batches
+     from torch.utils.data import DataLoader
+
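+     # collate_fn turns each batch of raw waveform arrays into model-ready input tensors via the processor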
+     dataloader = DataLoader(
+         dataset=audios,
+         batch_size=4,
+         shuffle=False,
+         collate_fn=lambda x: processor.process_audios(x)
+     )
+
+     embeddings = []
+     for batch_doc in dataloader:
+         with torch.no_grad():
+             batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+             embeddings_doc = model(**batch_doc)
+         embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+
+     # Move model back to CPU to free GPU memory
+     model = model.to('cpu')
+     torch.cuda.empty_cache()
+
+     return embeddings
+
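+ # Retrieval: the query is embedded the same way and scored against every chunk
+ # with the processor's multi-vector (MaxSim-style) scoring.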
+ @spaces.GPU(duration=60)
+ def search_audio(query, embeddings, audios, top_k=5):
+     """Search for relevant audio chunks"""
+     model, processor = load_model()
+     model = model.to('cuda')
+
+     # Process query
+     batch_queries = processor.process_queries([query]).to(model.device)
+
+     with torch.no_grad():
+         query_embeddings = model(**batch_queries)
+
+     # Score against all embeddings (cap top_k at the number of chunks)
+     scores = processor.score_multi_vector(query_embeddings, embeddings)
+     top_indices = scores[0].topk(min(top_k, len(embeddings))).indices.tolist()
+
+     # Move model back to CPU
+     model = model.to('cpu')
+     torch.cuda.empty_cache()
+
+     return top_indices
+
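+ # Base64-encode a chunk as WAV so it can be passed as `input_audio` content to the OpenAI API.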
+ def audio_to_base64(data, rate=16000):
+     """Convert audio data to base64"""
+     buf = io.BytesIO()
+     write(buf, rate, data)
+     buf.seek(0)
+     encoded_string = base64.b64encode(buf.read()).decode("utf-8")
+     return encoded_string
+
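+ # End-to-end pipeline: chunk -> embed -> retrieve the top chunks, then optionally ask
+ # gpt-4o-audio-preview to answer over the retrieved audio. Note the whole file is
+ # re-chunked and re-embedded on every search; embeddings are not cached between queries.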
+ def process_audio_rag(audio_file, query, chunk_length=30, use_openai=False, openai_key=None):
+     """Main processing function"""
+     if not audio_file:
+         return "Please upload an audio file", None, None
+
+     # Chunk audio
+     audios = chunk_audio(audio_file, chunk_length)
+
+     # Embed chunks
+     embeddings = embed_audio_chunks(audios)
+
+     # Search for relevant chunks
+     top_indices = search_audio(query, embeddings, audios)
+
+     # Prepare results
+     result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
+     result_text += f"Chunk indices: {top_indices}\n\n"
+
+     # Save first result as audio file
+     first_chunk_path = "result_chunk.wav"
+     wavfile.write(first_chunk_path, 16000, audios[top_indices[0]])
+
+     # Optional: Use OpenAI for answer generation
+     if use_openai and openai_key:
+         from openai import OpenAI
+         client = OpenAI(api_key=openai_key)
+
+         content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
+
+         for idx in top_indices[:3]:  # Use top 3 chunks
+             content.extend([
+                 {"type": "text", "text": f"Audio chunk #{idx}:"},
+                 {
+                     "type": "input_audio",
+                     "input_audio": {
+                         "data": audio_to_base64(audios[idx]),
+                         "format": "wav"
+                     }
+                 }
+             ])
+
+         try:
+             completion = client.chat.completions.create(
+                 model="gpt-4o-audio-preview",
+                 messages=[{"role": "user", "content": content}]
+             )
+             result_text += f"\nOpenAI Answer: {completion.choices[0].message.content}"
+         except Exception as e:
+             result_text += f"\nOpenAI Error: {str(e)}"
+
+     # Create audio visualization
+     import matplotlib.pyplot as plt
+     fig, ax = plt.subplots(figsize=(10, 4))
+     ax.plot(audios[top_indices[0]])
+     ax.set_title(f"Waveform of top matching chunk (#{top_indices[0]})")
+     ax.set_xlabel("Samples")
+     ax.set_ylabel("Amplitude")
+     plt.tight_layout()
+
+     return result_text, first_chunk_path, fig
+
+ # Create Gradio interface
+ with gr.Blocks(title="AudioRAG Demo") as demo:
+     gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
+     gr.Markdown("Upload an audio file and search through it using natural language queries!")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(label="Upload Audio File", type="filepath")
+             query_input = gr.Textbox(label="Search Query", placeholder="What are you looking for in the audio?")
+             chunk_length = gr.Slider(minimum=10, maximum=60, value=30, step=5, label="Chunk Length (seconds)")
+
+             with gr.Accordion("OpenAI Integration (Optional)", open=False):
+                 use_openai = gr.Checkbox(label="Use OpenAI for answer generation")
+                 openai_key = gr.Textbox(label="OpenAI API Key", type="password")
+
+             search_btn = gr.Button("Search Audio", variant="primary")
+
+         with gr.Column():
+             output_text = gr.Textbox(label="Results", lines=10)
+             output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
+             output_plot = gr.Plot(label="Audio Waveform")
+
+     search_btn.click(
+         fn=process_audio_rag,
+         inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
+         outputs=[output_text, output_audio, output_plot]
+     )
+
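+     # The example rows below assume example_audio.wav and podcast.mp3 are present in the Space repo.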
+     gr.Examples(
+         examples=[
+             ["example_audio.wav", "Was Hannibal well liked by his men?", 30],
+             ["podcast.mp3", "What did they say about climate change?", 20],
+         ],
+         inputs=[audio_input, query_input, chunk_length]
+     )
+
+ if __name__ == "__main__":
+     # Load model on startup
+     load_model()
+     demo.launch()