akhaliq (HF Staff) committed
Commit 6a6a2f0 · verified · 1 parent: aab58a2

Upload app.py with huggingface_hub

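For reference, an upload like this one is typically performed with huggingface_hub's upload_file API. A minimal sketch; the Space repo id below is assumed for illustration, and authentication comes from a cached login or the HF_TOKEN environment variable:

    # Sketch of the upload named in the commit message (not part of the commit itself).
    from huggingface_hub import HfApi

    api = HfApi()  # picks up HF_TOKEN or a cached `huggingface-cli login`
    api.upload_file(
        path_or_fileobj="app.py",                    # local file to push
        path_in_repo="app.py",                       # destination path inside the repo
        repo_id="akhaliq/fastvlm-video-captioning",  # hypothetical Space id
        repo_type="space",                           # Spaces host Gradio apps
        commit_message="Upload app.py with huggingface_hub",
    )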
Files changed (1)
  1. app.py +264 -0
app.py ADDED
@@ -0,0 +1,264 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import cv2
+ import numpy as np
+
+ MID = "apple/FastVLM-7B"
+ IMAGE_TOKEN_INDEX = -200  # placeholder id spliced where the image features go (LLaVA convention)
+
+ # Load model and tokenizer
+ print("Loading FastVLM model...")
+ tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     MID,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ print("Model loaded successfully!")
+
+ def extract_frames(video_path: str, num_frames: int = 8, sampling_method: str = "uniform"):
+     """Extract frames from a video file"""
+     cap = cv2.VideoCapture(video_path)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+     if total_frames <= 0:
+         cap.release()
+         return []
+
+     frames = []
+
+     if sampling_method == "uniform":
+         # Evenly spaced frames across the whole video
+         indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+     elif sampling_method == "first":
+         # Take the first N frames
+         indices = list(range(min(num_frames, total_frames)))
+     elif sampling_method == "last":
+         # Take the last N frames
+         start = max(0, total_frames - num_frames)
+         indices = list(range(start, total_frames))
+     else:  # middle
+         # Take frames from the middle of the video
+         start = max(0, (total_frames - num_frames) // 2)
+         indices = list(range(start, min(start + num_frames, total_frames)))
+
+     for idx in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+         ret, frame = cap.read()
+         if ret:
+             # OpenCV returns BGR; PIL expects RGB
+             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frames.append(Image.fromarray(frame_rgb))
+
+     cap.release()
+     return frames
+
+ def caption_frame(image: Image.Image, prompt: str) -> str:
+     """Generate a caption for a single frame"""
+     # Build chat with custom prompt
+     messages = [
+         {"role": "user", "content": f"<image>\n{prompt}"}
+     ]
+     rendered = tok.apply_chat_template(
+         messages, add_generation_prompt=True, tokenize=False
+     )
+     pre, post = rendered.split("<image>", 1)
+
+     # Tokenize the text around the image token
+     pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
+     post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
+
+     # Splice in the IMAGE token id
+     img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
+     input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
+     attention_mask = torch.ones_like(input_ids)
+
+     # Preprocess image with the model's own vision-tower processor
+     px = model.get_vision_tower().image_processor(images=image, return_tensors="pt")["pixel_values"]
+     px = px.to(model.device, dtype=model.dtype)
+
+     # Generate
+     with torch.no_grad():
+         out = model.generate(
+             inputs=input_ids,
+             attention_mask=attention_mask,
+             images=px,
+             max_new_tokens=256,
+             temperature=0.7,
+             do_sample=True,
+         )
+
+     caption = tok.decode(out[0], skip_special_tokens=True)
+     # Extract only the generated part
+     if prompt in caption:
+         caption = caption.split(prompt)[-1].strip()
+
+     return caption
+
+ def process_video(
+     video_path: str,
+     num_frames: int,
+     sampling_method: str,
+     caption_mode: str,
+     custom_prompt: str,
+     progress=gr.Progress()
+ ) -> tuple:
+     """Process a video and generate captions"""
+
+     if not video_path:
+         return "Please upload a video first.", None, None
+
+     progress(0, desc="Extracting frames...")
+     frames = extract_frames(video_path, num_frames, sampling_method)
+
+     if not frames:
+         return "Failed to extract frames from video.", None, None
+
+     # Prepare prompt based on mode
+     if caption_mode == "Detailed Description":
+         prompt = "Describe this image in detail, including all visible objects, actions, and the overall scene."
+     elif caption_mode == "Brief Summary":
+         prompt = "Provide a brief one-sentence description of what's happening in this image."
+     elif caption_mode == "Action Recognition":
+         prompt = "What action or activity is taking place in this image? Focus on the main action."
+     else:  # Custom
+         prompt = custom_prompt if custom_prompt else "Describe this image."
+
+     captions = []
+     frame_previews = []
+
+     for i, frame in enumerate(frames):
+         progress((i + 1) / (len(frames) + 1), desc=f"Analyzing frame {i + 1}/{len(frames)}...")
+         caption = caption_frame(frame, prompt)
+         captions.append(f"**Frame {i + 1}:** {caption}")
+         frame_previews.append(frame)
+
+     progress(1.0, desc="Compiling results...")
+
+     # Concatenate the per-frame captions; a second generation pass could
+     # fuse them into one narrative, but simple concatenation keeps latency down.
+     full_caption = "\n\n".join(captions)
+
+     if len(frames) > 1:
+         video_summary = f"## Video Analysis ({len(frames)} frames analyzed)\n\n{full_caption}"
+     else:
+         video_summary = f"## Video Analysis\n\n{full_caption}"
+
+     return video_summary, frame_previews, video_path
+
+ # Create the Gradio interface
+ with gr.Blocks(css="""
+     .video-container {
+         height: calc(100vh - 100px) !important;
+     }
+     .sidebar {
+         height: calc(100vh - 100px) !important;
+         overflow-y: auto;
+     }
+ """) as demo:
+     gr.Markdown("# 🎬 FastVLM Video Captioning")
+
+     with gr.Row():
+         # Main video display
+         with gr.Column(scale=7):
+             video_display = gr.Video(
+                 label="Video Input",
+                 height=600,
+                 elem_classes=["video-container"],
+                 autoplay=True,
+                 loop=True
+             )
+
+     # Sidebar with controls
+     with gr.Sidebar(width=400, elem_classes=["sidebar"]):
+         gr.Markdown("## ⚙️ Settings")
+
+         with gr.Group():
+             gr.Markdown("### Frame Sampling")
+             num_frames = gr.Slider(
+                 minimum=1,
+                 maximum=16,
+                 value=8,
+                 step=1,
+                 label="Number of Frames to Analyze",
+                 info="More frames = better understanding but slower processing"
+             )
+
+             sampling_method = gr.Radio(
+                 choices=["uniform", "first", "last", "middle"],
+                 value="uniform",
+                 label="Sampling Method",
+                 info="How to select frames from the video"
+             )
+
+         with gr.Group():
+             gr.Markdown("### Caption Settings")
+             caption_mode = gr.Radio(
+                 choices=["Detailed Description", "Brief Summary", "Action Recognition", "Custom"],
+                 value="Detailed Description",
+                 label="Caption Mode"
+             )
+
+             custom_prompt = gr.Textbox(
+                 label="Custom Prompt",
+                 placeholder="Enter your custom prompt here...",
+                 visible=False,
+                 lines=3
+             )
+
+         process_btn = gr.Button("🎯 Analyze Video", variant="primary", size="lg")
+
+         gr.Markdown("### 📝 Results")
+         output_text = gr.Markdown(
+             value="Upload a video and click 'Analyze Video' to begin.",
+             elem_classes=["output-text"]
+         )
+
+         with gr.Accordion("🖼️ Analyzed Frames", open=False):
+             frame_gallery = gr.Gallery(
+                 label="Extracted Frames",
+                 show_label=False,
+                 columns=2,
+                 rows=4,
+                 object_fit="contain",
+                 height="auto"
+             )
+
+     # Show/hide custom prompt based on mode selection
+     def toggle_custom_prompt(mode):
+         return gr.Textbox(visible=(mode == "Custom"))
+
+     caption_mode.change(
+         toggle_custom_prompt,
+         inputs=[caption_mode],
+         outputs=[custom_prompt]
+     )
+
+     # Upload handler
+     def handle_upload(video):
+         if video:
+             return video, "Video loaded! Click 'Analyze Video' to generate captions."
+         return None, "Upload a video to begin."
+
+     video_display.upload(
+         handle_upload,
+         inputs=[video_display],
+         outputs=[video_display, output_text]
+     )
+
+     # Process button
+     process_btn.click(
+         process_video,
+         inputs=[video_display, num_frames, sampling_method, caption_mode, custom_prompt],
+         outputs=[output_text, frame_gallery, video_display]
+     )
+
+ demo.launch()
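As a quick sanity check, the helpers above can also be driven directly, without launching the UI. A minimal sketch run from the same module after the model loads; the video path is a placeholder, not part of the commit:

    # Caption a local clip directly (no Gradio): extract a few frames,
    # then run the per-frame captioner on each one.
    frames = extract_frames("clip.mp4", num_frames=4, sampling_method="uniform")
    for i, frame in enumerate(frames, start=1):
        print(f"Frame {i}: {caption_frame(frame, 'Describe this image.')}")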