ruff
app.py
CHANGED
@@ -1,12 +1,15 @@
 #!/usr/bin/env python
 
+import re
+import tempfile
 from collections.abc import Iterator
 from threading import Thread
 
+import cv2
 import gradio as gr
 import spaces
 import torch
-import re
+from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
 model_id = "google/gemma-3-12b-it"
@@ -15,17 +18,13 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
 
-import cv2
-from PIL import Image
-import numpy as np
-import tempfile
 
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
+
+    frame_interval = int(fps / 3)
     frames = []
 
     for i in range(0, total_frames, frame_interval):
@@ -34,7 +33,7 @@ def downsample_video(video_path):
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
 
     vidcap.release()
@@ -46,8 +45,8 @@ def process_new_user_message(message: dict) -> list[dict]:
         if "<image>" in message["text"]:
             content = []
             print("message[files]", message["files"])
-            parts = re.split(r
-            image_index = 0
+            parts = re.split(r"(<image>)", message["text"])
+            image_index = 0
             print("parts", parts)
             for part in parts:
                 print("part", part)
@@ -55,29 +54,30 @@ def process_new_user_message(message: dict) -> list[dict]:
                     content.append({"type": "image", "url": message["files"][image_index]})
                     print("file", message["files"][image_index])
                     image_index += 1
-                elif part.strip():
+                elif part.strip():
                     content.append({"type": "text", "text": part.strip()})
                 elif isinstance(part, str) and not part == "<image>":
                     content.append({"type": "text", "text": part})
             print(content)
             return content
-
+        if message["files"][0].endswith(".mp4"):
             content = []
             video = message["files"].pop(0)
             frames = downsample_video(video)
             for frame in frames:
                 pil_image, timestamp = frame
-                with tempfile.NamedTemporaryFile(delete=False, suffix=
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                     pil_image.save(temp_file.name)
                     content.append({"type": "text", "text": f"Frame {timestamp}:"})
                     content.append({"type": "image", "url": temp_file.name})
             print(content)
             return content
-
-
-
-
-
+        # non interleaved images
+        return [
+            {"type": "text", "text": message["text"]},
+            *[{"type": "image", "url": path} for path in message["files"]],
+        ]
+    return [{"type": "text", "text": message["text"]}]
 
 
 def process_history(history: list[dict]) -> list[dict]:
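For context on the sampling logic touched above: with frame_interval = int(fps / 3), downsample_video visits roughly three frames per second and tags each frame with its timestamp i / fps. The sketch below is a minimal standalone version of that loop; the CAP_PROP_POS_FRAMES seek and the read() call fall outside the hunks shown, so those lines, the target_fps parameter, and the max() guard are assumptions rather than part of the commit.

import cv2
from PIL import Image


def sample_frames(video_path: str, target_fps: float = 3.0) -> list[tuple[Image.Image, float]]:
    """Grab roughly target_fps frames per second, each paired with its timestamp in seconds."""
    vidcap = cv2.VideoCapture(video_path)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_interval = max(int(fps / target_fps), 1)  # avoid a zero step on very low-fps clips

    frames = []
    for i in range(0, total_frames, frame_interval):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)  # seek to frame i before reading (assumed, not shown in the hunks)
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR; PIL expects RGB
            frames.append((Image.fromarray(image), round(i / fps, 2)))

    vidcap.release()
    return frames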
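The <image> branch in process_new_user_message relies on re.split with a capturing group: because the pattern is parenthesized, the "<image>" delimiters themselves stay in the result, so text chunks and image placeholders come back in their original order and can be paired with the uploaded files one by one. A condensed sketch of that mapping follows; the file paths are placeholders and the redundant isinstance branch from the app is dropped.

import re


def interleave(text: str, files: list[str]) -> list[dict]:
    """Split on <image> tags, keeping them, and pair each tag with the next uploaded file."""
    parts = re.split(r"(<image>)", text)  # the capturing group keeps the "<image>" tokens
    content, image_index = [], 0
    for part in parts:
        if part == "<image>":
            content.append({"type": "image", "url": files[image_index]})
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
    return content


# Two images referenced inline (hypothetical paths):
print(interleave("Compare <image> with <image> please.", ["a.png", "b.png"]))
# [{'type': 'text', 'text': 'Compare'}, {'type': 'image', 'url': 'a.png'},
#  {'type': 'text', 'text': 'with'}, {'type': 'image', 'url': 'b.png'},
#  {'type': 'text', 'text': 'please.'}]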