ferferefer committed on
Commit
5d9018e
·
verified ·
1 Parent(s): f343d07

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +10 -7
  2. app.py +238 -0
  3. requirements.txt +253 -0
  4. style.css +11 -0
  5. uv.lock +0 -0
README.md CHANGED
@@ -1,12 +1,15 @@
1
  ---
2
- title: X
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: indigo
 
 
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
- pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: "XX"
3
+ models:
4
+ - XX
5
+ emoji: 🩻
6
+ colorFrom: blue
7
+ colorTo: blue
8
  sdk: gradio
9
+ sdk_version: 5.31.0
10
  app_file: app.py
11
+ pinned: true
12
+
13
  ---
14
 
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import re
5
+ import tempfile
6
+ from collections.abc import Iterator
7
+ from threading import Thread
8
+
9
+ import cv2
10
+ import gradio as gr
11
+ import spaces
12
+ import torch
13
+ from loguru import logger
14
+ from PIL import Image
15
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer, Qwen2_5_VLForConditionalGeneration
16
+ from qwen_vl_utils import process_vision_info
17
+
18
+ #import subprocess
19
+ #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
20
+
21
# Resolve the checkpoint location once so the model and its processor are
# guaranteed to come from the same place, and fail with a clear message when
# the variable is missing instead of passing None to from_pretrained().
MODEL_PATH = os.getenv("MODEL_PATH")
if not MODEL_PATH:
    raise RuntimeError(
        "The MODEL_PATH environment variable must be set to a model id or local path."
    )

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)


# Upper bound on images per conversation; also the number of frames sampled
# from an uploaded video. Overridable through the environment.
MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
32
+
33
+
34
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count the images and videos attached to a new message.

    A path is treated as a video when it ends with ``.mp4``; the comparison is
    case-insensitive so that ``clip.MP4`` is not miscounted as an image. Every
    other path is counted as an image.

    Args:
        paths: File paths attached to the incoming message.

    Returns:
        An ``(image_count, video_count)`` tuple.
    """
    video_count = sum(1 for path in paths if path.lower().endswith(".mp4"))
    return len(paths) - video_count, video_count
43
+
44
+
45
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Count the images and videos already uploaded in earlier user turns.

    Only items whose role is ``"user"`` and whose content is not a plain
    string are considered; their first element is read as the file path
    (assumes file entries are sequences of paths -- TODO confirm against the
    Gradio version in use). ``.mp4`` is matched case-insensitively, and empty
    file entries are skipped instead of raising ``IndexError``.

    Args:
        history: Chat history as a list of ``{"role", "content"}`` dicts.

    Returns:
        An ``(image_count, video_count)`` tuple.
    """
    image_count = 0
    video_count = 0
    for item in history:
        if item["role"] != "user" or isinstance(item["content"], str):
            continue
        files = item["content"]
        if not files:
            # Nothing to classify in an empty file entry.
            continue
        if files[0].lower().endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count
56
+
57
+
58
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check the upload rules for the conversation including the new message.

    The rules, evaluated in order, are: at most one video overall; a video may
    not be combined with images or with ``<image>`` tags; without a video at
    most ``MAX_NUM_IMAGES`` images overall; and, when ``<image>`` tags are
    used, their number must equal the number of newly attached images.

    Args:
        message: Incoming message with ``"text"`` and ``"files"`` entries.
        history: Previous turns of the conversation.

    Returns:
        ``True`` when every rule holds; otherwise ``False`` after a
        ``gr.Warning`` describing the first violated rule has been issued.
    """
    images_now, videos_now = count_files_in_new_message(message["files"])
    images_before, videos_before = count_files_in_history(history)
    total_images = images_before + images_now
    total_videos = videos_before + videos_now

    # The first matching rule wins; every failure leaves through one exit.
    if total_videos > 1:
        problem = "Only one video is supported."
    elif total_videos == 1 and total_images > 0:
        problem = "Mixing images and videos is not allowed."
    elif total_videos == 1 and "<image>" in message["text"]:
        problem = "Using <image> tags with video files is not supported."
    elif total_videos == 0 and total_images > MAX_NUM_IMAGES:
        problem = f"You can upload up to {MAX_NUM_IMAGES} images."
    elif "<image>" in message["text"] and message["text"].count("<image>") != images_now:
        problem = "The number of <image> tags in the text does not match the number of images."
    else:
        return True

    gr.Warning(problem)
    return False
80
+
81
+
82
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """Sample up to ``MAX_NUM_IMAGES`` evenly spaced frames from a video.

    Frames are taken every ``total_frames // MAX_NUM_IMAGES`` frames (at least
    every frame), converted from OpenCV's BGR order to RGB and wrapped as PIL
    images. Frames that cannot be read are skipped.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        A list of ``(image, timestamp)`` pairs. ``timestamp`` is the frame
        position in seconds rounded to two decimals; when the container does
        not report a usable frame rate, the frame index is returned instead so
        that the call does not fail with a division by zero.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames: list[tuple[Image.Image, float]] = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)

        # The stop bound limits the range to at most MAX_NUM_IMAGES indices,
        # so no extra length check is needed inside the loop.
        stop = min(total_frames, MAX_NUM_IMAGES * frame_interval)
        for i in range(0, stop, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(image)
            # fps is 0 (or NaN) when the metadata is missing; `fps > 0` is
            # False in both cases.
            timestamp = round(i / fps, 2) if fps > 0 else float(i)
            frames.append((pil_image, timestamp))
    finally:
        # Release the capture even if decoding or conversion raises.
        vidcap.release()
    return frames
104
+
105
+
106
def process_video(video_path: str) -> list[dict]:
    """Turn a video into alternating timestamp-text and frame-image entries.

    Each sampled frame is written to its own PNG temp file, and the content
    list receives a ``"Frame <timestamp>:"`` text entry followed by an image
    entry pointing at that file. The temp files are created with
    ``delete=False`` and are not removed here, because the returned paths are
    consumed later by the caller.

    Args:
        video_path: Path of the video to sample.

    Returns:
        A list of ``{"type": "text", ...}`` / ``{"type": "image", ...}`` dicts.
    """
    content: list[dict] = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            # Write through the already-open handle: re-opening the file by
            # name while it is still open is not portable across platforms.
            pil_image.save(temp_file, format="PNG")
        content.append({"type": "text", "text": f"Frame {timestamp}:"})
        content.append({"type": "image", "image": temp_file.name})
    logger.debug(f"{content=}")
    return content
117
+
118
+
119
def process_interleaved_images(message: dict) -> list[dict]:
    """Split the message text on ``<image>`` tags and interleave the files.

    Every ``<image>`` tag is replaced, in order, by the next entry of
    ``message["files"]``. Text between tags is emitted stripped; segments that
    are empty or whitespace-only are emitted unchanged.

    Args:
        message: Incoming message with ``"text"`` and ``"files"`` entries.

    Returns:
        The ordered list of text and image content dicts.
    """
    logger.debug(f"{message['files']=}")
    parts = re.split(r"(<image>)", message["text"])
    logger.debug(f"{parts=}")

    attachments = message["files"]
    content = []
    cursor = 0
    for part in parts:
        logger.debug(f"{part=}")
        if part != "<image>":
            # Stripped text when anything remains, the raw segment otherwise.
            content.append({"type": "text", "text": part.strip() or part})
            continue
        path = attachments[cursor]
        content.append({"type": "image", "image": path})
        logger.debug(f"file: {path}")
        cursor += 1
    logger.debug(f"{content=}")
    return content
138
+
139
+
140
def process_new_user_message(message: dict) -> list[dict]:
    """Build the content list for the incoming user message.

    Dispatch order: no files -> text only; first file is an ``.mp4``
    (case-insensitive, so ``clip.MP4`` is handled as a video too) -> text
    followed by the sampled video frames; text containing ``<image>`` tags ->
    images interleaved at the tag positions; otherwise text followed by every
    attached image.

    Args:
        message: Incoming message with ``"text"`` and ``"files"`` entries.

    Returns:
        A list of ``{"type": "text" | "image", ...}`` content dicts.
    """
    files = message["files"]
    if not files:
        return [{"type": "text", "text": message["text"]}]
    if files[0].lower().endswith(".mp4"):
        return [{"type": "text", "text": message["text"]}, *process_video(files[0])]
    if "<image>" in message["text"]:
        return process_interleaved_images(message)
    return [
        {"type": "text", "text": message["text"]},
        *[{"type": "image", "image": path} for path in files],
    ]
151
+
152
+
153
def process_history(history: list[dict]) -> list[dict]:
    """Convert the chat history into chat-template messages.

    Consecutive user items are merged into one user message, which is emitted
    when the next assistant item is reached. String content becomes a text
    entry. Non-string content is read as a file entry whose first element is
    the path (assumes a sequence of paths -- TODO confirm against the Gradio
    version in use): an ``.mp4`` is expanded into frames with
    ``process_video``, exactly as a newly uploaded video is, instead of being
    passed on as if it were an image; any other file becomes an image entry.

    User items that are not followed by an assistant item are not emitted.

    Args:
        history: Chat history as a list of ``{"role", "content"}`` dicts.

    Returns:
        A list of ``{"role", "content"}`` messages with list-valued content.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
            continue

        content = item["content"]
        if isinstance(content, str):
            current_user_content.append({"type": "text", "text": content})
        elif content[0].lower().endswith(".mp4"):
            # A video path cannot be loaded as an image; replay it as frames.
            current_user_content.extend(process_video(content[0]))
        else:
            current_user_content.append({"type": "image", "image": content[0]})
    return messages
169
+
170
+
171
@spaces.GPU(duration=120)
def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 2048) -> Iterator[str]:
    """Stream the model's reply to a multimodal chat message.

    The message is validated first; on failure a single empty string is
    yielded and nothing is generated. Otherwise the optional system prompt,
    the converted history and the new user message are rendered with the
    processor's chat template, and generation runs on a worker thread while
    this generator yields the text accumulated so far.

    Args:
        message: Incoming message with ``"text"`` and ``"files"`` entries.
        history: Previous turns of the conversation.
        system_prompt: Optional system instruction; skipped when empty.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply accumulated so far, growing with every streamed chunk.
    """
    if not validate_media_constraints(message, history):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # skip_prompt=True keeps the rendered prompt out of the streamed text.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        temperature=0.7,
        top_p=1,
        repetition_penalty=1,
    )
    # generate() blocks until completion, so it runs on a worker thread while
    # this generator drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    output = ""
    for delta in streamer:
        output += delta
        yield output

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not outlive this call.
    t.join()
213
+
214
+
215
# Text shown under the title of the chat UI.
DESCRIPTION = "AI X-Ray"

# Components are created in the same order as they are passed below.
_chatbot = gr.Chatbot(type="messages", scale=1, allow_tags=["image"])
_textbox = gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True)
_system_prompt = gr.Textbox(label="System Prompt", value="You are a helpful AI expert.")
_max_new_tokens = gr.Slider(label="Max New Tokens", minimum=100, maximum=8192, step=10, value=2048)

# Chat UI wired to `run`; the two additional inputs are passed to it after the
# message and the history, in the order listed.
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=_chatbot,
    textbox=_textbox,
    multimodal=True,
    additional_inputs=[_system_prompt, _max_new_tokens],
    stop_btn=False,
    title="AI Expert",
    description=DESCRIPTION,
    run_examples_on_click=False,
    cache_examples=False,
    css_paths="style.css",
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.4.0
4
+ # via gemma-3-12b-it (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.8.0
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ certifi==2025.1.31
15
+ # via
16
+ # httpcore
17
+ # httpx
18
+ # requests
19
+ charset-normalizer==3.4.1
20
+ # via requests
21
+ click==8.1.8
22
+ # via
23
+ # typer
24
+ # uvicorn
25
+ exceptiongroup==1.2.2
26
+ # via anyio
27
+ fastapi==0.115.11
28
+ # via gradio
29
+ ffmpy==0.5.0
30
+ # via gradio
31
+ filelock==3.17.0
32
+ # via
33
+ # huggingface-hub
34
+ # torch
35
+ # transformers
36
+ # triton
37
+ fsspec==2025.3.0
38
+ # via
39
+ # gradio-client
40
+ # huggingface-hub
41
+ # torch
42
+ gradio==5.21.0
43
+ # via
44
+ # gemma-3-12b-it (pyproject.toml)
45
+ # spaces
46
+ gradio-client==1.7.2
47
+ # via gradio
48
+ groovy==0.1.2
49
+ # via gradio
50
+ h11==0.14.0
51
+ # via
52
+ # httpcore
53
+ # uvicorn
54
+ hf-transfer==0.1.9
55
+ # via gemma-3-12b-it (pyproject.toml)
56
+ httpcore==1.0.7
57
+ # via httpx
58
+ httpx==0.28.1
59
+ # via
60
+ # gradio
61
+ # gradio-client
62
+ # safehttpx
63
+ # spaces
64
+ huggingface-hub==0.29.2
65
+ # via
66
+ # accelerate
67
+ # gradio
68
+ # gradio-client
69
+ # tokenizers
70
+ # transformers
71
+ idna==3.10
72
+ # via
73
+ # anyio
74
+ # httpx
75
+ # requests
76
+ jinja2==3.1.6
77
+ # via
78
+ # gradio
79
+ # torch
80
+ loguru==0.7.3
81
+ # via gemma-3-12b-it (pyproject.toml)
82
+ markdown-it-py==3.0.0
83
+ # via rich
84
+ markupsafe==2.1.5
85
+ # via
86
+ # gradio
87
+ # jinja2
88
+ mdurl==0.1.2
89
+ # via markdown-it-py
90
+ mpmath==1.3.0
91
+ # via sympy
92
+ networkx==3.4.2
93
+ # via torch
94
+ numpy==2.2.3
95
+ # via
96
+ # accelerate
97
+ # gradio
98
+ # opencv-python-headless
99
+ # pandas
100
+ # transformers
101
+ nvidia-cublas-cu12==12.1.3.1
102
+ # via
103
+ # nvidia-cudnn-cu12
104
+ # nvidia-cusolver-cu12
105
+ # torch
106
+ nvidia-cuda-cupti-cu12==12.1.105
107
+ # via torch
108
+ nvidia-cuda-nvrtc-cu12==12.1.105
109
+ # via torch
110
+ nvidia-cuda-runtime-cu12==12.1.105
111
+ # via torch
112
+ nvidia-cudnn-cu12==9.1.0.70
113
+ # via torch
114
+ nvidia-cufft-cu12==11.0.2.54
115
+ # via torch
116
+ nvidia-curand-cu12==10.3.2.106
117
+ # via torch
118
+ nvidia-cusolver-cu12==11.4.5.107
119
+ # via torch
120
+ nvidia-cusparse-cu12==12.1.0.106
121
+ # via
122
+ # nvidia-cusolver-cu12
123
+ # torch
124
+ nvidia-nccl-cu12==2.20.5
125
+ # via torch
126
+ nvidia-nvjitlink-cu12==12.8.93
127
+ # via
128
+ # nvidia-cusolver-cu12
129
+ # nvidia-cusparse-cu12
130
+ nvidia-nvtx-cu12==12.1.105
131
+ # via torch
132
+ opencv-python-headless==4.11.0.86
133
+ # via gemma-3-12b-it (pyproject.toml)
134
+ orjson==3.10.15
135
+ # via gradio
136
+ packaging==24.2
137
+ # via
138
+ # accelerate
139
+ # gradio
140
+ # gradio-client
141
+ # huggingface-hub
142
+ # spaces
143
+ # transformers
144
+ pandas==2.2.3
145
+ # via gradio
146
+ pillow==11.1.0
147
+ # via gradio
148
+ protobuf==6.30.0
149
+ # via gemma-3-12b-it (pyproject.toml)
150
+ psutil==5.9.8
151
+ # via
152
+ # accelerate
153
+ # spaces
154
+ pydantic==2.10.6
155
+ # via
156
+ # fastapi
157
+ # gradio
158
+ # spaces
159
+ pydantic-core==2.27.2
160
+ # via pydantic
161
+ pydub==0.25.1
162
+ # via gradio
163
+ pygments==2.19.1
164
+ # via rich
165
+ python-dateutil==2.9.0.post0
166
+ # via pandas
167
+ python-multipart==0.0.20
168
+ # via gradio
169
+ pytz==2025.1
170
+ # via pandas
171
+ pyyaml==6.0.2
172
+ # via
173
+ # accelerate
174
+ # gradio
175
+ # huggingface-hub
176
+ # transformers
177
+ regex==2024.11.6
178
+ # via transformers
179
+ requests==2.32.3
180
+ # via
181
+ # huggingface-hub
182
+ # spaces
183
+ # transformers
184
+ rich==13.9.4
185
+ # via typer
186
+ ruff==0.9.10
187
+ # via gradio
188
+ safehttpx==0.1.6
189
+ # via gradio
190
+ safetensors==0.5.3
191
+ # via
192
+ # accelerate
193
+ # transformers
194
+ semantic-version==2.10.0
195
+ # via gradio
196
+ sentencepiece==0.2.0
197
+ # via gemma-3-12b-it (pyproject.toml)
198
+ shellingham==1.5.4
199
+ # via typer
200
+ six==1.17.0
201
+ # via python-dateutil
202
+ sniffio==1.3.1
203
+ # via anyio
204
+ spaces==0.32.0
205
+ # via gemma-3-12b-it (pyproject.toml)
206
+ starlette==0.46.1
207
+ # via
208
+ # fastapi
209
+ # gradio
210
+ sympy==1.13.3
211
+ # via torch
212
+ tokenizers==0.21.0
213
+ # via transformers
214
+ tomlkit==0.13.2
215
+ # via gradio
216
+ torch==2.4.0
217
+ # via
218
+ # gemma-3-12b-it (pyproject.toml)
219
+ # accelerate
220
+ torchvision
221
+ tqdm==4.67.1
222
+ # via
223
+ # huggingface-hub
224
+ # transformers
225
+ transformers @ git+https://github.com/huggingface/transformers@2829013d2d00e63d75a1f6f7a3f003bc60cc69af
226
+ # via gemma-3-12b-it (pyproject.toml)
227
+ triton==3.0.0
228
+ # via torch
229
+ typer==0.15.2
230
+ # via gradio
231
+ typing-extensions==4.12.2
232
+ # via
233
+ # anyio
234
+ # fastapi
235
+ # gradio
236
+ # gradio-client
237
+ # huggingface-hub
238
+ # pydantic
239
+ # pydantic-core
240
+ # rich
241
+ # spaces
242
+ # torch
243
+ # typer
244
+ # uvicorn
245
+ tzdata==2025.1
246
+ # via pandas
247
+ urllib3==2.3.0
248
+ # via requests
249
+ uvicorn==0.34.0
250
+ # via gradio
251
+ websockets==15.0.1
252
+ # via gradio-client
253
+ qwen_vl_utils
style.css ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
5
+
6
+ #logo {
7
+ display: block;
8
+ margin: 0 auto;
9
+ width: 40%;
10
+ object-fit: contain;
11
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff