Spaces: Paused
local #1 by raksama19 - opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .DS_Store +0 -0
- .gitattributes +0 -5
- __pycache__/gradio_alt_text.cpython-312.pyc +0 -0
- __pycache__/gradio_final_app.cpython-312.pyc +0 -0
- __pycache__/gradio_gemma.cpython-312.pyc +0 -0
- __pycache__/gradio_gemma_alt_text.cpython-312.pyc +0 -0
- app.py +0 -1084
- assets/demo.gif +0 -3
- assets/dolphin.png +0 -3
- assets/framework.png +0 -3
- chat.py +0 -198
- config/Dolphin.yaml +0 -17
- demo/.DS_Store +0 -0
- demo/element_imgs/.DS_Store +0 -0
- demo/element_imgs/block_formula.jpeg +0 -3
- demo/element_imgs/line_formula.jpeg +0 -3
- demo/element_imgs/markdown/.DS_Store +0 -0
- demo/element_imgs/markdown/table_1.md +0 -2
- demo/element_imgs/para_1.jpg +0 -3
- demo/element_imgs/para_2.jpg +0 -3
- demo/element_imgs/para_3.jpeg +0 -3
- demo/element_imgs/recognition_json/table_1.json +0 -6
- demo/element_imgs/table_1.jpeg +0 -3
- demo/element_imgs/table_2.jpeg +0 -3
- demo/page_imgs/.DS_Store +0 -0
- demo/page_imgs/markdown/.DS_Store +0 -0
- demo/page_imgs/markdown/figures/.DS_Store +0 -0
- demo/page_imgs/markdown/figures/test_page3_figure_000.png +0 -3
- demo/page_imgs/markdown/test_page3.md +0 -22
- demo/page_imgs/page_1.jpeg +0 -3
- demo/page_imgs/page_2.jpeg +0 -3
- demo/page_imgs/page_3.jpeg +0 -3
- demo/page_imgs/page_4.png +0 -3
- demo/page_imgs/page_5.jpg +0 -3
- demo/page_imgs/page_6.pdf +0 -0
- demo/page_imgs/page_7.jpeg +0 -3
- demo/page_imgs/recognition_json/page_1.json +0 -178
- demo/page_imgs/recognition_json/test_page.json +0 -47
- demo/page_imgs/recognition_json/test_page2.json +0 -102
- demo/page_imgs/recognition_json/test_page3.json +0 -124
- demo/page_imgs/test_page2.jpeg +0 -3
- demo/page_imgs/test_page3.jpeg +0 -3
- demo_element.py +0 -129
- demo_element_hf.py +0 -195
- demo_page.py +0 -247
- demo_page_hf.py +0 -365
- deployment/ReadMe.md +0 -12
- deployment/tensorrt_llm/ReadMe.md +0 -89
- deployment/tensorrt_llm/api_client.py +0 -100
- deployment/tensorrt_llm/api_server.py +0 -112
.DS_Store
DELETED
Binary file (10.2 kB)
.gitattributes
CHANGED
@@ -33,8 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-*.png filter=lfs diff=lfs merge=lfs -text
-*.jpeg filter=lfs diff=lfs merge=lfs -text
-*.jpg filter=lfs diff=lfs merge=lfs -text
-*.wav filter=lfs diff=lfs merge=lfs -text
-*.gif filter=lfs diff=lfs merge=lfs -text
__pycache__/gradio_alt_text.cpython-312.pyc
DELETED
Binary file (33.9 kB)
__pycache__/gradio_final_app.cpython-312.pyc
DELETED
Binary file (30.4 kB)
__pycache__/gradio_gemma.cpython-312.pyc
DELETED
Binary file (14.3 kB)
__pycache__/gradio_gemma_alt_text.cpython-312.pyc
DELETED
Binary file (8.1 kB)
app.py
DELETED
@@ -1,1084 +0,0 @@
import gradio as gr
import json
import markdown
import cv2
import numpy as np
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel, AutoModelForImageTextToText
import torch
try:
    from sentence_transformers import SentenceTransformer
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    RAG_DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    print(f"RAG dependencies not available: {e}")
    print("Please install: pip install sentence-transformers scikit-learn")
    RAG_DEPENDENCIES_AVAILABLE = False
    SentenceTransformer = None
import os
import tempfile
import uuid
import base64
import io
from utils.utils import *
from utils.markdown_utils import MarkdownConverter

# Voice functionality imports
import time
import librosa
from dataclasses import dataclass, field
from pydub import AudioSegment
try:
    from voice_chat.utils.vad import get_speech_timestamps, collect_chunks, VadOptions
    from voice_chat.gemma3n_inference import Gemma3nInference
    VOICE_DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    print(f"Voice dependencies not available: {e}")
    VOICE_DEPENDENCIES_AVAILABLE = False

# Math extension is optional for enhanced math rendering
MATH_EXTENSION_AVAILABLE = False
try:
    from mdx_math import MathExtension
    MATH_EXTENSION_AVAILABLE = True
except ImportError:
    pass

# Initialize voice model early to avoid NameError
voice_model = None
if VOICE_DEPENDENCIES_AVAILABLE:
    try:
        print("Loading voice model...")
        voice_model = Gemma3nInference(device='cuda' if torch.cuda.is_available() else 'cpu')
        print("Warming up voice model...")
        voice_model.warm_up()
        print("✅ Voice model loaded and warmed up successfully")
    except Exception as e:
        print(f"⚠️ Voice model initialization failed: {e}")
        voice_model = None


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model optimized for powerful GPU"""
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(
            model_id_or_path,
            torch_dtype=torch.float16,
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.model.eval()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if not torch.cuda.is_available():
            self.model = self.model.float()

        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image or batch of images with the given prompt(s)"""
        is_batch = isinstance(image, list)

        if not is_batch:
            images = [image]
            prompts = [prompt]
        else:
            images = image
            prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values

        if torch.cuda.is_available():
            batch_pixel_values = batch_pixel_values.half().to(self.device)
        else:
            batch_pixel_values = batch_pixel_values.to(self.device)

        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(
            prompts,
            add_special_tokens=False,
            return_tensors="pt"
        )

        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                pixel_values=batch_pixel_values,
                decoder_input_ids=batch_prompt_ids,
                decoder_attention_mask=batch_attention_mask,
                min_length=1,
                max_length=2048,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[self.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
                do_sample=False,
                num_beams=1,
                repetition_penalty=1.1,
                temperature=1.0
            )

        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)

        results = []
        for i, sequence in enumerate(sequences):
            cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
            results.append(cleaned)

        if not is_batch:
            return results[0]
        return results


class Gemma3nModel:
    def __init__(self, model_id="google/gemma-3n-E4B-it"):
        """Initialize the Gemma 3n model for text generation and image description"""
        self.model_id = model_id
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype="auto",
            device_map="auto"
        )
        self.model.eval()
        print(f"✅ Gemma 3n loaded (Device: {self.model.device}, DType: {self.model.dtype})")

    def generate_alt_text(self, pil_image):
        """Generate alt text for an image using local Gemma 3n"""
        try:
            # Ensure image is in RGB mode
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')

            # Create a detailed prompt for alt text generation
            prompt = """You are an accessibility expert creating alt text for images to help visually impaired users understand visual content. Analyze this image and provide a clear, concise description that captures the essential visual information.

Focus on:
- Main subject or content of the image
- Important details, text, or data shown
- Layout and structure if relevant (charts, diagrams, tables)
- Context that would help someone understand the image's purpose

Provide a descriptive alt text in 1-2 sentences that is informative but not overly verbose. Start directly with the description without saying "This image shows" or similar phrases."""

            # Prepare the message format
            message = {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_image},
                    {"type": "text", "text": prompt}
                ]
            }

            # Apply chat template and generate
            input_ids = self.processor.apply_chat_template(
                [message],
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            input_len = input_ids["input_ids"].shape[-1]

            input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=256,
                disable_compile=True,
                do_sample=False,
                temperature=0.1
            )

            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            alt_text = text[0].strip()

            # Clean up the alt text
            alt_text = alt_text.replace('\n', ' ').replace('\r', ' ')
            # Remove common prefixes if they appear
            prefixes_to_remove = ["This image shows", "The image shows", "This shows", "The figure shows"]
            for prefix in prefixes_to_remove:
                if alt_text.startswith(prefix):
                    alt_text = alt_text[len(prefix):].strip()
                    break

            return alt_text if alt_text else "Image description unavailable"

        except Exception as e:
            print(f"❌ Error generating alt text: {e}")
            import traceback
            traceback.print_exc()
            return "Image description unavailable"

    def chat(self, prompt, history=None):
        """Chat functionality using Gemma 3n for text-only conversations"""
        try:
            # Create message format
            message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt}
                ]
            }

            # If history exists, include it
            conversation = history if history else []
            conversation.append(message)

            # Apply chat template and generate
            input_ids = self.processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            input_len = input_ids["input_ids"].shape[-1]

            input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
            outputs = self.model.generate(
                **input_ids,
                max_new_tokens=1024,
                disable_compile=True,
                do_sample=False,
                pad_token_id=self.processor.tokenizer.pad_token_id
            )

            text = self.processor.batch_decode(
                outputs[:, input_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            return text[0].strip()

        except Exception as e:
            print(f"❌ Error in chat: {e}")
            import traceback
            traceback.print_exc()
            return f"Error generating response: {str(e)}"


def convert_pdf_to_images_gradio(pdf_file):
    """Convert uploaded PDF file to list of PIL Images"""
    try:
        import pymupdf

        if isinstance(pdf_file, str):
            pdf_document = pymupdf.open(pdf_file)
        else:
            pdf_bytes = pdf_file.read()
            pdf_document = pymupdf.open(stream=pdf_bytes, filetype="pdf")

        images = []
        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            mat = pymupdf.Matrix(2.0, 2.0)
            pix = page.get_pixmap(matrix=mat)
            img_data = pix.tobytes("png")
            pil_image = Image.open(io.BytesIO(img_data)).convert("RGB")
            images.append(pil_image)

        pdf_document.close()
        return images

    except Exception as e:
        raise Exception(f"Error converting PDF: {str(e)}")


def process_pdf_document(pdf_file, model, progress=gr.Progress()):
    """Process uploaded PDF file page by page"""
    if pdf_file is None:
        return "No PDF file uploaded", ""

    try:
        progress(0.1, desc="Converting PDF to images...")
        images = convert_pdf_to_images_gradio(pdf_file)

        if not images:
            return "Failed to convert PDF to images", ""

        all_results = []

        for page_idx, pil_image in enumerate(images):
            progress((page_idx + 1) / len(images) * 0.8 + 0.1,
                     desc=f"Processing page {page_idx + 1}/{len(images)}...")

            layout_output = model.chat("Parse the reading order of this document.", pil_image)

            padded_image, dims = prepare_image(pil_image)
            recognition_results = process_elements_optimized(
                layout_output,
                padded_image,
                dims,
                model,
                max_batch_size=4
            )

            try:
                markdown_converter = MarkdownConverter()
                markdown_content = markdown_converter.convert(recognition_results)
            except:
                markdown_content = generate_fallback_markdown(recognition_results)

            page_result = {
                "page_number": page_idx + 1,
                "markdown": markdown_content
            }
            all_results.append(page_result)

        progress(1.0, desc="Processing complete!")

        combined_markdown = "\n\n---\n\n".join([
            f"# Page {result['page_number']}\n\n{result['markdown']}"
            for result in all_results
        ])

        return combined_markdown, "processing_complete"

    except Exception as e:
        error_msg = f"Error processing PDF: {str(e)}"
        return error_msg, "error"


def process_elements_optimized(layout_results, padded_image, dims, model, max_batch_size=4):
    """Optimized element processing for powerful GPU"""
    layout_results = parse_layout_string(layout_results)

    text_elements = []
    table_elements = []
    figure_results = []
    previous_box = None
    reading_order = 0

    for bbox, label in layout_results:
        try:
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            cropped = padded_image[y1:y2, x1:x2]
            if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
                if label == "fig":
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    pil_crop = crop_margin(pil_crop)

                    # Generate alt text for accessibility using local Gemma 3n
                    alt_text = gemma_model.generate_alt_text(pil_crop)

                    buffered = io.BytesIO()
                    pil_crop.save(buffered, format="PNG")
                    img_base64 = base64.b64encode(buffered.getvalue()).decode()
                    data_uri = f"data:image/png;base64,{img_base64}"

                    figure_results.append({
                        "label": label,
                        "text": f"\n\n*{alt_text}*",
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": reading_order,
                        "alt_text": alt_text,
                    })
                else:
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    element_info = {
                        "crop": pil_crop,
                        "label": label,
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": reading_order,
                    }

                    if label == "tab":
                        table_elements.append(element_info)
                    else:
                        text_elements.append(element_info)

            reading_order += 1

        except Exception as e:
            print(f"Error processing element {label}: {str(e)}")
            continue

    recognition_results = figure_results.copy()

    if text_elements:
        text_results = process_element_batch_optimized(
            text_elements, model, "Read text in the image.", max_batch_size
        )
        recognition_results.extend(text_results)

    if table_elements:
        table_results = process_element_batch_optimized(
            table_elements, model, "Parse the table in the image.", max_batch_size
        )
        recognition_results.extend(table_results)

    recognition_results.sort(key=lambda x: x.get("reading_order", 0))
    return recognition_results


def process_element_batch_optimized(elements, model, prompt, max_batch_size=4):
    """Process elements in batches for powerful GPU"""
    results = []
    batch_size = min(len(elements), max_batch_size)

    for i in range(0, len(elements), batch_size):
        batch_elements = elements[i:i+batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]
        prompts_list = [prompt] * len(crops_list)

        batch_results = model.chat(prompts_list, crops_list)

        for j, result in enumerate(batch_results):
            elem = batch_elements[j]
            results.append({
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": result.strip(),
                "reading_order": elem["reading_order"],
            })

        del crops_list, batch_elements
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results


def generate_fallback_markdown(recognition_results):
    """Generate basic markdown if converter fails"""
    markdown_content = ""
    for element in recognition_results:
        if element["label"] == "tab":
            markdown_content += f"\n\n{element['text']}\n\n"
        elif element["label"] in ["para", "title", "sec", "sub_sec"]:
            markdown_content += f"{element['text']}\n\n"
        elif element["label"] == "fig":
            # Image should already have alt text from processing
            markdown_content += f"{element['text']}\n\n"
    return markdown_content


# Initialize models
model_path = "./hf_model"
if not os.path.exists(model_path):
    model_path = "ByteDance/DOLPHIN"

# Model paths and configuration
model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
hf_token = os.getenv('HF_TOKEN')
gemma_model_id = "google/gemma-3n-E4B-it"

# Initialize models
print("Loading DOLPHIN model...")
dolphin_model = DOLPHIN(model_path)
print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")

print("Loading Gemma 3n model...")
gemma_model = Gemma3nModel(gemma_model_id)

model_status = "✅ Both models loaded successfully"

# Initialize embedding model
if RAG_DEPENDENCIES_AVAILABLE:
    try:
        print("Loading embedding model for RAG...")
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        print("✅ Embedding model loaded successfully (CPU)")
    except Exception as e:
        print(f"❌ Error loading embedding model: {e}")
        embedding_model = None
else:
    print("❌ RAG dependencies not available")
    embedding_model = None


# Global state for managing tabs
processed_markdown = ""
show_results_tab = False
document_chunks = []
document_embeddings = None

# Voice chat parameters and state
IN_CHANNELS = 1
IN_RATE = 24000
IN_CHUNK = 1024
IN_SAMPLE_WIDTH = 2
VAD_STRIDE = 0.5
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 20 * 4096

# Voice model already initialized earlier in the file

@dataclass
class VoiceAppState:
    stream: np.ndarray | None = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    stopped: bool = False
    conversation: list = field(default_factory=list)


# Voice functionality
def run_vad(ori_audio, sr):
    """Voice Activity Detection"""
    _st = time.time()
    try:
        audio = ori_audio
        if isinstance(audio, bytes):
            audio = np.frombuffer(audio, dtype=np.int16)
            audio = audio.astype(np.float32) / 32768.0
        sampling_rate = 16000
        if sr != sampling_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)

        vad_parameters = {}
        vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        if sr != sampling_rate:
            vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
        else:
            vad_audio = audio
        vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
        vad_audio_bytes = vad_audio.tobytes()

        return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
    except Exception as e:
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {e}"
        print(msg)
        return -1, ori_audio, round(time.time() - _st, 4)

def determine_pause(audio: np.ndarray, sampling_rate: int, state: VoiceAppState) -> bool:
    """Determine if a pause happened in the audio stream"""
    temp_audio = audio
    dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate)
    duration = len(audio) / sampling_rate

    if dur_vad > 0.5 and not state.started_talking:
        print("started talking")
        state.started_talking = True
        return False

    print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
    return (duration - dur_vad) > 1

def process_voice_audio(audio: tuple, state: VoiceAppState):
    """Process streaming audio input"""
    if not VOICE_DEPENDENCIES_AVAILABLE or voice_model is None:
        return None, state

    if state.stream is None:
        state.stream = audio[1]
        state.sampling_rate = audio[0]
    else:
        state.stream = np.concatenate((state.stream, audio[1]))

    pause_detected = determine_pause(state.stream, state.sampling_rate, state)
    state.pause_detected = pause_detected

    if state.pause_detected and state.started_talking:
        return gr.Audio(recording=False), state
    return None, state

def generate_voice_response(state: VoiceAppState):
    """Generate voice response from audio input"""
    if not VOICE_DEPENDENCIES_AVAILABLE or voice_model is None:
        return None, VoiceAppState()

    if not state.pause_detected and not state.started_talking:
        return None, VoiceAppState()

    try:
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=state.stream.dtype.itemsize,
            channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
        )
        segment.export(audio_buffer, format="wav")

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(audio_buffer.getvalue())
            temp_audio_path = f.name

        try:
            # Generate text response from audio
            text_response = voice_model.generate_response(temp_audio_path)
            print(f"Generated voice response: {text_response}")

            # Convert text to speech
            audio_response = voice_model.text_to_speech_simple(text_response)

            # Convert to format expected by Gradio
            audio_segment = AudioSegment.from_file(io.BytesIO(audio_response), format="wav")
            audio_array = np.array(audio_segment.get_array_of_samples())

            if audio_segment.channels == 2:
                audio_array = audio_array.reshape((-1, 2))

            # Update conversation history
            state.conversation.append({"role": "user", "content": f"[Audio message]"})
            state.conversation.append({"role": "assistant", "content": text_response})

            return (audio_segment.frame_rate, audio_array), VoiceAppState(conversation=state.conversation)

        finally:
            if os.path.exists(temp_audio_path):
                os.unlink(temp_audio_path)

    except Exception as e:
        print(f"Error generating voice response: {e}")
        return None, VoiceAppState()

def start_voice_recording(state: VoiceAppState):
    """Start recording user voice input"""
    if not state.stopped:
        return gr.Audio(recording=True)
    return gr.Audio(recording=False)

def chunk_document(text, chunk_size=1024, overlap=100):
    """Split document into overlapping chunks for RAG"""
    words = text.split()
    chunks = []

    for i in range(0, len(words), chunk_size - overlap):
        chunk = ' '.join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)

    return chunks

def create_embeddings(chunks):
    """Create embeddings for document chunks"""
    if embedding_model is None:
        return None

    try:
        # Process in smaller batches on CPU
        batch_size = 32
        embeddings = []

        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
            embeddings.extend(batch_embeddings)

        return np.array(embeddings)
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        return None

def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
    """Retrieve most relevant chunks for a question"""
    if embedding_model is None or embeddings is None:
        return chunks[:3]  # Fallback to first 3 chunks

    try:
        question_embedding = embedding_model.encode([question], show_progress_bar=False)
        similarities = cosine_similarity(question_embedding, embeddings)[0]

        # Get top-k most similar chunks
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        relevant_chunks = [chunks[i] for i in top_indices]

        return relevant_chunks
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return chunks[:3]  # Fallback

def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
    """Main processing function for uploaded PDF"""
    global processed_markdown, show_results_tab, document_chunks, document_embeddings

    if pdf_file is None:
        return "❌ No PDF uploaded", gr.Tabs(visible=False)

    try:
        # Process PDF
        progress(0.1, desc="Processing PDF...")
        combined_markdown, status = process_pdf_document(pdf_file, dolphin_model, progress)

        if status == "processing_complete":
            processed_markdown = combined_markdown

            # Create chunks and embeddings for RAG
            progress(0.9, desc="Creating document chunks for RAG...")
            document_chunks = chunk_document(processed_markdown)
            document_embeddings = create_embeddings(document_chunks)
            print(f"Created {len(document_chunks)} chunks")

            show_results_tab = True
            progress(1.0, desc="PDF processed successfully!")
            return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
        else:
            show_results_tab = False
            return combined_markdown, gr.Tabs(visible=False)

    except Exception as e:
        show_results_tab = False
        error_msg = f"❌ Error processing PDF: {str(e)}"
        return error_msg, gr.Tabs(visible=False)


def get_processed_markdown():
    """Return the processed markdown content"""
    global processed_markdown
    return processed_markdown if processed_markdown else "No document processed yet."


def clear_all():
    """Clear all data and hide results tab"""
    global processed_markdown, show_results_tab, document_chunks, document_embeddings
    processed_markdown = ""
    show_results_tab = False
    document_chunks = []
    document_embeddings = None

    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return None, "", gr.Tabs(visible=False)


# Create Gradio interface
with gr.Blocks(
    title="DOLPHIN PDF AI - Local Gemma 3n",
    theme=gr.themes.Soft(),
    css="""
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

    * {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
    }

    .main-container {
        max-width: 1000px;
        margin: 0 auto;
    }
    .upload-container {
        text-align: center;
        padding: 40px 20px;
        border: 2px dashed #e0e0e0;
        border-radius: 15px;
        margin: 20px 0;
    }
    .upload-button {
        font-size: 18px !important;
        padding: 15px 30px !important;
        margin: 20px 0 !important;
        font-weight: 600 !important;
    }
    .status-message {
        text-align: center;
        padding: 15px;
        margin: 10px 0;
        border-radius: 8px;
        font-weight: 500;
    }
    .chatbot-container {
        max-height: 600px;
    }
    h1, h2, h3 {
        font-weight: 700 !important;
    }
    #progress-container {
        margin: 10px 0;
        min-height: 20px;
    }
    """
) as demo:

    with gr.Tabs() as main_tabs:
        # Home Tab
        with gr.TabItem("🏠 Home", id="home"):
            embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
            voice_status = "✅ Voice chat ready" if VOICE_DEPENDENCIES_AVAILABLE and voice_model else "❌ Voice chat not available"
            gr.Markdown(
                "# Scholar Express - Local Gemma 3n Version with Voice\n"
                "### Upload a research paper to get a web-friendly version with AI-generated alt text for accessibility. Includes an AI chatbot and voice chat powered by local Gemma 3n.\n"
                f"**System:** {model_status}\n"
                f"**RAG System:** {embedding_status}\n"
                f"**Voice Chat:** {voice_status}\n"
                f"**DOLPHIN:** Local model for PDF processing\n"
                f"**Gemma 3n:** Local model for alt text generation, chat, and voice\n"
                f"**Alt Text:** Gemma 3n generates descriptive alt text for images\n"
                f"**GPU:** {'CUDA available' if torch.cuda.is_available() else 'CPU only'}\n\n"
                "**Features:**\n"
                "- 📄 PDF processing with OCR and layout analysis\n"
                "- 💬 Text-based chat about your documents\n"
                "- 🎙️ Voice chat with Gemma 3n (new!)\n"
                "- ♿ AI-generated alt text for accessibility"
            )

            with gr.Column(elem_classes="upload-container"):
                gr.Markdown("## 📄 Upload Your PDF Document")

                pdf_input = gr.File(
                    file_types=[".pdf"],
                    label="",
                    height=150,
                    elem_id="pdf_upload"
                )

                process_btn = gr.Button(
                    "🚀 Process PDF",
                    variant="primary",
                    size="lg",
                    elem_classes="upload-button"
                )

                clear_btn = gr.Button(
                    "🗑️ Clear",
                    variant="secondary"
                )

                # Dedicated progress space
                progress_space = gr.HTML(
                    value="",
                    visible=False,
                    elem_id="progress-container"
                )

                # Status output (hidden during processing)
                status_output = gr.Markdown(
                    "",
                    elem_classes="status-message"
                )

        # Results Tab (initially hidden)
        with gr.TabItem("📖 Document", id="results", visible=False) as results_tab:
            gr.Markdown("## Processed Document")

            markdown_display = gr.Markdown(
                value="",
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False}
                ],
                height=700
            )

        # Chatbot Tab (initially hidden)
        with gr.TabItem("💬 Chat", id="chat", visible=False) as chat_tab:
            gr.Markdown("## Ask Questions About Your Document")

            chatbot = gr.Chatbot(
                value=[],
                height=500,
                type='messages',
                elem_classes="chatbot-container",
                placeholder="Your conversation will appear here once you process a document..."
            )

            with gr.Row():
                msg_input = gr.Textbox(
                    placeholder="Ask a question about the processed document...",
                    scale=4,
                    container=False
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

            gr.Markdown(
                "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with local Gemma 3n to find relevant sections and provide accurate answers.*",
                elem_id="chat-notice"
            )

        # Voice Chat Tab
        with gr.TabItem("🎙️ Talk with Gemma", id="voice") as voice_tab:
            voice_status = "✅ Voice chat ready" if VOICE_DEPENDENCIES_AVAILABLE and voice_model else "❌ Voice chat not available"
            gr.Markdown(f"## Voice Chat with Gemma 3n\n{voice_status}")

            if VOICE_DEPENDENCIES_AVAILABLE and voice_model:
                with gr.Row():
                    with gr.Column():
                        voice_input_audio = gr.Audio(
                            label="Speak to Gemma",
                            sources=["microphone"],
                            type="numpy",
                            streaming=True
                        )
                    with gr.Column():
                        voice_output_audio = gr.Audio(
                            label="Gemma's Response",
                            streaming=True,
                            autoplay=True
                        )
                voice_chatbot = gr.Chatbot(
                    label="Voice Conversation",
                    type="messages",
                    height=300
                )

                with gr.Row():
                    voice_stop_btn = gr.Button("⏹️ Stop Conversation", variant="stop")

                gr.Markdown(
                    "*Speak naturally to Gemma 3n. The AI will listen to your voice, process your speech, and respond with both text and voice. You can have conversations before or after processing PDFs.*"
                )

                # Voice state
                voice_state = gr.State(value=VoiceAppState())
            else:
                gr.Markdown(
                    "### Voice chat is not available\n"
                    "To enable voice chat, please install the required dependencies:\n"
                    "```bash\n"
                    "pip install librosa pydub onnxruntime\n"
                    "```\n"
                    "And ensure the voice_chat directory is properly set up."
                )

    # Event handlers
    process_btn.click(
        fn=process_uploaded_pdf,
        inputs=[pdf_input],
        outputs=[status_output, results_tab],
        show_progress=True
    ).then(
        fn=get_processed_markdown,
        outputs=[markdown_display]
    ).then(
        fn=lambda: gr.TabItem(visible=True),
        outputs=[chat_tab]
    )

    clear_btn.click(
        fn=clear_all,
        outputs=[pdf_input, status_output, results_tab]
    ).then(
        fn=lambda: gr.HTML(visible=False),
        outputs=[progress_space]
    ).then(
        fn=lambda: gr.TabItem(visible=False),
        outputs=[chat_tab]
    )

    # Chatbot functionality with local Gemma 3n
    def chatbot_response(message, history):
        if not message.strip():
            return history

        if not processed_markdown:
            return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please process a PDF document first before asking questions."}]

        try:
            # Use RAG to get relevant chunks from markdown
            if document_chunks and len(document_chunks) > 0:
                relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
                context = "\n\n".join(relevant_chunks)
                # Smart truncation: aim for ~6000 chars for local model
                if len(context) > 6000:
                    # Try to cut at sentence boundaries
                    sentences = context[:6000].split('.')
                    context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
            else:
                # Fallback to truncated document if RAG fails
                context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown

            # Create prompt for Gemma 3n
            prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.

Context from the document:
{context}

Question: {message}

Please provide a clear and helpful answer based on the context provided."""

            # Generate response using local Gemma 3n
            response_text = gemma_model.chat(prompt)
            return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]

        except Exception as e:
            error_msg = f"❌ Error generating response: {str(e)}"
            print(f"Full error: {e}")
            import traceback
            traceback.print_exc()
            return history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_msg}]

    send_btn.click(
        fn=chatbot_response,
        inputs=[msg_input, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg_input]
    )

    # Also allow Enter key to send message
    msg_input.submit(
        fn=chatbot_response,
        inputs=[msg_input, chatbot],
        outputs=[chatbot]
    ).then(
        lambda: "",
        outputs=[msg_input]
    )

    # Voice chat event handlers
    if VOICE_DEPENDENCIES_AVAILABLE and voice_model:
        # Stream processing
        voice_stream = voice_input_audio.stream(
            process_voice_audio,
            [voice_input_audio, voice_state],
            [voice_input_audio, voice_state],
            stream_every=0.50,
            time_limit=30,
        )

        # Response generation
        voice_respond = voice_input_audio.stop_recording(
            generate_voice_response,
            [voice_state],
            [voice_output_audio, voice_state]
        )

        # Update chatbot display
        voice_respond.then(
            lambda s: s.conversation,
            [voice_state],
            [voice_chatbot]
        )

        # Restart recording
        voice_restart = voice_output_audio.stop(
            start_voice_recording,
            [voice_state],
            [voice_input_audio]
        )

        # Stop conversation
        voice_stop_btn.click(
            lambda: (VoiceAppState(stopped=True), gr.Audio(recording=False)),
            None,
            [voice_state, voice_input_audio],
            cancels=[voice_respond, voice_restart]
        )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        max_threads=4,
        inbrowser=False,
        quiet=True
    )
assets/demo.gif
DELETED
Git LFS Details
assets/dolphin.png
DELETED
Git LFS Details
assets/framework.png
DELETED
Git LFS Details
chat.py
DELETED
@@ -1,198 +0,0 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import os
import warnings
from collections import OrderedDict

from omegaconf import ListConfig

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")

import torch
from PIL import Image
from transformers import PreTrainedTokenizerFast

from utils.model import DonutConfig, DonutModel, SwinEncoder
from utils.processor import DolphinProcessor


def try_rename_lagacy_weights(ckpt, output_path=""):
    if "state_dict" in ckpt.keys():
        ckpt = ckpt["state_dict"]
    if "module" in ckpt.keys():
        ckpt = ckpt["module"]
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        if k.startswith("model."):
            k = k[len("model.") :]
        if k.startswith("encoder"):
            new_ckpt["vpm" + k[len("encoder") :]] = v
        elif k.startswith("decoder"):
            new_ckpt["llm" + k[len("encoder") :]] = v
        else:
            new_ckpt[k] = v
    if output_path:
        torch.save(new_ckpt, output_path)
    return new_ckpt


def convert_listconfig_to_list(config):
    new_config = {}
    for k, v in config.items():
        if isinstance(v, ListConfig):
            new_config[k] = list(v)
        else:
            new_config[k] = v
    return new_config


class DOLPHIN:
    def __init__(self, config, ckpt_path="") -> None:
        self.model_args = config.model
        self.swin_args = config.model.pop("swin_args")
        self.swin_args = convert_listconfig_to_list(self.swin_args)

        vision_tower = SwinEncoder(
            input_size=self.swin_args["img_size"],
            patch_size=self.swin_args["patch_size"],
            embed_dim=self.swin_args["embed_dim"],
            window_size=self.swin_args["window_size"],
            encoder_layer=self.swin_args["encoder_layer"],
            num_heads=self.swin_args["num_heads"],
            align_long_axis=self.swin_args["align_long_axis"],
        )

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
        self.tokenizer.pad_token = "<pad>"
        self.tokenizer.bos_token = "<s>"
        self.tokenizer.eos_token = "</s>"
        self.tokenizer.unk_token = "<unk>"

        if self.model_args.get("extra_answer_tokens", False):
            # print("Allowing multitask training: adding <Answer/> to the tokenizer.")
            prompt_end_token = " <Answer/>"
            self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
            self.tokenizer._prompt_end_token = prompt_end_token
            self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)

        donut_config = DonutConfig(
            decoder_layer=self.model_args.decoder_layer,
            max_length=self.model_args.max_length,
            max_position_embeddings=self.model_args.max_position_embeddings,
            hidden_dimension=self.model_args.hidden_dimension,
        )

        self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
        if self.model_args.model_name_or_path:
            ckpt = torch.load(self.model_args.model_name_or_path)
            ckpt = try_rename_lagacy_weights(ckpt)
            self.model.load_state_dict(ckpt, strict=True)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        self.model.eval()
        transform_args = {
            "input_size": self.swin_args["img_size"],
            "max_length": self.model_args.max_length,
        }
        self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)

    def chat(
        self,
        question,
        image,
        return_raw=False,
        return_score=False,
        return_img_size=False,
        only_return_img_size=False,
        max_batch_size=16,
    ):

        def _preprocess_image(image):
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            if return_img_size or only_return_img_size:
                image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
            else:
                image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
                ori_size = None
            return image_tensor, ori_size

        def _preprocess_prompt(question):
            if self.model_args.get("extra_answer_tokens", False):
                if self.tokenizer._prompt_end_token not in question:
                    question = question + self.tokenizer._prompt_end_token
            prompt_ids = self.processor.process_prompt_for_inference(question)
            return prompt_ids

        def _preprocess_prompt_batch(question):
            if self.model_args.get("extra_answer_tokens", False):
                for i in range(len(question)):
                    if self.tokenizer._prompt_end_token not in question[i]:
                        question[i] = question[i] + self.tokenizer._prompt_end_token
                    if not question[i].startswith("<s>"):
                        question[i] = "<s>" + question[i]
            return question

        def _postprocess(output, question):
            output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
            if self.model_args.get("extra_answer_tokens", False):
                output = output.split(self.tokenizer._prompt_end_token)[-1]
            return output

        if isinstance(question, list):
            image_tensor_list = []
            for i in image:
                image_tensor, ori_size = _preprocess_image(i)
                image_tensor_list.append(image_tensor)
            image_tensor = torch.cat(image_tensor_list, dim=0)

            question = _preprocess_prompt_batch(question)
            self.processor.tokenizer.padding_side = "left"
            prompt_ids = self.processor.tokenizer(
                question, add_special_tokens=False, return_tensors="pt", padding=True
            ).input_ids
        else:
            image_tensor, ori_size = _preprocess_image(image)
            prompt_ids = _preprocess_prompt(question)

        if only_return_img_size:
            return ori_size

        model_output_batch = []
        for i in range(0, image_tensor.shape[0], max_batch_size):
            image_tensor_batch = image_tensor[i : i + max_batch_size]
            prompt_ids_batch = prompt_ids[i : i + max_batch_size]
            model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
            model_output_batch.append(model_output)
        model_output = {}
        for k, v in model_output_batch[0].items():
            if isinstance(v, torch.Tensor):
                model_output[k] = sum(
                    [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
                    [],
                )
            else:
                model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])

        if return_raw:
            if return_img_size:
                return model_output, ori_size
            return model_output
        else:
            if isinstance(question, list):
                output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
                score = model_output["scores"]
            else:
                output = _postprocess(model_output["repetitions"][0], question)
                score = model_output["scores"][0]
            if return_score:
                return output, score
            if return_img_size:
                return output, ori_size
            return output
config/Dolphin.yaml
DELETED
@@ -1,17 +0,0 @@
model:
  model_name_or_path: "./checkpoints/dolphin_model.bin"
  tokenizer_path: "./checkpoints/dolphin_tokenizer.json"
  extra_answer_tokens: True  # add <Answer/> token
  max_length: 4096
  decoder_layer: 10
  max_position_embeddings: 4096
  hidden_dimension: 1024
  swin_args:
    name: 'swin'
    img_size: [896, 896]
    patch_size: 4
    embed_dim: 128
    align_long_axis: False
    window_size: 7
    encoder_layer: [2, 2, 14, 2]
    num_heads: [4, 8, 16, 32]
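The demo scripts below load this file with OmegaConf and hand the resulting config to the DOLPHIN wrapper. A minimal sketch of how these fields are then read (standard OmegaConf attribute access, shown for illustration only):

from omegaconf import OmegaConf

config = OmegaConf.load("./config/Dolphin.yaml")
print(config.model.max_length)                 # 4096
print(list(config.model.swin_args.img_size))   # [896, 896]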
demo/.DS_Store
DELETED
Binary file (6.15 kB)

demo/element_imgs/.DS_Store
DELETED
Binary file (6.15 kB)

demo/element_imgs/block_formula.jpeg
DELETED
Git LFS Details

demo/element_imgs/line_formula.jpeg
DELETED
Git LFS Details

demo/element_imgs/markdown/.DS_Store
DELETED
Binary file (6.15 kB)
demo/element_imgs/markdown/table_1.md
DELETED
@@ -1,2 +0,0 @@
<table><tr><td></td><td></td><td>100-class (top-1 acc.)</td><td>1000-class (top-1 acc.)</td></tr><tr><td colspan="2">4096-d (float)</td><td>77.1 ± 1.5</td><td>65.0</td></tr><tr><td rowspan="3">1024 bits</td><td>BP</td><td>72.9 ± 1.3</td><td>58.1</td></tr><tr><td>CBE</td><td>73.0 ± 1.3</td><td>59.2</td></tr><tr><td>SP</td><td>73.8 ± 1.3</td><td>60.1</td></tr><tr><td rowspan="4">4096 bits</td><td>threshold [1]</td><td>73.5 ± 1.4</td><td>59.1</td></tr><tr><td>BP</td><td>76.0 ± 1.5</td><td>63.2</td></tr><tr><td>CBE</td><td>75.9 ± 1.4</td><td>63.0</td></tr><tr><td>SP</td><td>76.3 ± 1.5</td><td>63.3</td></tr><tr><td>8192 bits</td><td>SP</td><td>76.8 ± 1.4</td><td>64.2</td></tr><tr><td>16384 bits</td><td>SP</td><td>77.1 ± 1.6</td><td>64.5</td></tr></table>
demo/element_imgs/para_1.jpg
DELETED
Git LFS Details

demo/element_imgs/para_2.jpg
DELETED
Git LFS Details

demo/element_imgs/para_3.jpeg
DELETED
Git LFS Details
demo/element_imgs/recognition_json/table_1.json
DELETED
@@ -1,6 +0,0 @@
[
  {
    "label": "tab",
    "text": "<table><tr><td></td><td></td><td>100-class (top-1 acc.)</td><td>1000-class (top-1 acc.)</td></tr><tr><td colspan=\"2\">4096-d (float)</td><td>77.1 ± 1.5</td><td>65.0</td></tr><tr><td rowspan=\"3\">1024 bits</td><td>BP</td><td>72.9 ± 1.3</td><td>58.1</td></tr><tr><td>CBE</td><td>73.0 ± 1.3</td><td>59.2</td></tr><tr><td>SP</td><td>73.8 ± 1.3</td><td>60.1</td></tr><tr><td rowspan=\"4\">4096 bits</td><td>threshold [1]</td><td>73.5 ± 1.4</td><td>59.1</td></tr><tr><td>BP</td><td>76.0 ± 1.5</td><td>63.2</td></tr><tr><td>CBE</td><td>75.9 ± 1.4</td><td>63.0</td></tr><tr><td>SP</td><td>76.3 ± 1.5</td><td>63.3</td></tr><tr><td>8192 bits</td><td>SP</td><td>76.8 ± 1.4</td><td>64.2</td></tr><tr><td>16384 bits</td><td>SP</td><td>77.1 ± 1.6</td><td>64.5</td></tr></table>"
  }
]
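A minimal sketch, not part of the deleted code, of how such an element-level recognition JSON maps onto the neighbouring markdown file: each entry carries a label and the parsed text, and the markdown is simply the text fields written out.

import json

with open("demo/element_imgs/recognition_json/table_1.json") as f:
    elements = json.load(f)

# Each entry has a "label" (e.g. "tab", "text", "formula") and the parsed "text".
with open("demo/element_imgs/markdown/table_1.md", "w") as f:
    for element in elements:
        f.write(element["text"] + "\n")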
demo/element_imgs/table_1.jpeg
DELETED
Git LFS Details

demo/element_imgs/table_2.jpeg
DELETED
Git LFS Details

demo/page_imgs/.DS_Store
DELETED
Binary file (8.2 kB)

demo/page_imgs/markdown/.DS_Store
DELETED
Binary file (6.15 kB)

demo/page_imgs/markdown/figures/.DS_Store
DELETED
Binary file (6.15 kB)

demo/page_imgs/markdown/figures/test_page3_figure_000.png
DELETED
Git LFS Details
demo/page_imgs/markdown/test_page3.md
DELETED
@@ -1,22 +0,0 @@
![](figures/test_page3_figure_000.png)

Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.

query with all keys, divide each by $\sqrt{d_k}$ , and apply a softmax function to obtain the weights on the values.

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$ . The keys and values are also packed together into matrices $K$ and $V$ . We compute the matrix of outputs as: $$ \\ \text{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V \\ $$

The two most commonly used attention functions are additive attention [2] , and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$ . Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.

While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ [ 3 ] . We suspect that for large values of $d_k$ , the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4 To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$ .

3.2.2 Multi-Head Attention

Instead of performing a single attention function with $d_{\text{model}}$ -dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$ , $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$ -dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2 .

Multihead attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

${ }^{4}$ To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean 0 and variance 1 . Then their dot product, $q \cdot k=\sum_{i=1}^{d_{k}} q_{i} k_{i}$, has mean 0 and variance $d_{k}$.

4
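The deleted ground-truth markdown above reproduces the scaled dot-product attention formula; a short NumPy sketch of that formula, added here purely for illustration and not part of the repository:

import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

Q = np.random.randn(4, 8)   # 4 queries of dimension d_k = 8
K = np.random.randn(6, 8)   # 6 keys
V = np.random.randn(6, 16)  # 6 values of dimension d_v = 16
print(scaled_dot_product_attention(Q, K, V).shape)  # (4, 16)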
demo/page_imgs/page_1.jpeg
DELETED
Git LFS Details

demo/page_imgs/page_2.jpeg
DELETED
Git LFS Details

demo/page_imgs/page_3.jpeg
DELETED
Git LFS Details

demo/page_imgs/page_4.png
DELETED
Git LFS Details

demo/page_imgs/page_5.jpg
DELETED
Git LFS Details

demo/page_imgs/page_6.pdf
DELETED
The diff for this file is too large to render.

demo/page_imgs/page_7.jpeg
DELETED
Git LFS Details
demo/page_imgs/recognition_json/page_1.json
DELETED
@@ -1,178 +0,0 @@
[
  {
    "label": "title",
    "bbox": [271, 188, 1194, 221],
    "text": "LLaMA: Open and Efficient Foundation Language Models",
    "reading_order": 0
  },
  {
    "label": "author",
    "bbox": [313, 272, 1154, 317],
    "text": "Hugo Touvron; Thibaut Lavril*, Gautier Izacard*, Xavier Martinet",
    "reading_order": 1
  },
  {
    "label": "para",
    "bbox": [269, 317, 1201, 425],
    "text": "Marie-Anne Lachaux, Timothee Lacroix, Baptiste Rozière, Naman Goyal\nEric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin\nEdouard Grave*Guillaume Lample*",
    "reading_order": 2
  },
  {
    "label": "para",
    "bbox": [685, 440, 795, 482],
    "text": "Meta AI",
    "reading_order": 3
  },
  {
    "label": "sec",
    "bbox": [376, 524, 502, 565],
    "text": "\\begin{abstract}",
    "reading_order": 4
  },
  {
    "label": "para",
    "bbox": [209, 586, 675, 946],
    "text": "We introduce LLaMA, a collection of founda-\ntion language models ranging from 7B to 65B\nparameters. We train our models on trillions\nof tokens, and show that it is possible to train\nstate-of-the-art models using publicly avail-\nable datasets exclusively, without resorting\nto proprietary and inaccessible datasets. In\nparticular, LLaMA-13B outperforms GPT-3\n(175B) on most benchmarks, and LLaMA-\n65B is competitive with the best models,\nChinchilla-70B and PaLM-540B. We release\nall our models to the research community $^1$ .",
    "reading_order": 5
  },
  {
    "label": "sec",
    "bbox": [167, 964, 376, 1006],
    "text": "1 Introduction",
    "reading_order": 6
  },
  {
    "label": "para",
    "bbox": [167, 1027, 718, 1498],
    "text": "Large Languages Models (LLMs) trained on mas-\nsive corpora of texts have shown their ability to per-\nform new tasks from textual instructions or from a\nfew examples ( Brown et al. , 2020 ) . These few-shot\nproperties first appeared when scaling models to a\nsufficient size ( Kaplan et al. , 2020 ) , resulting in a\nline of work that focuses on further scaling these\nmodels ( Chowdhery et al. , 2022 ; Rae et al. , 2021 ) .\nThese efforts are based on the assumption that\nmore parameters will lead to better performance.\nHowever, recent work from Hoffmann et al. ( 2022 )\nshows that, for a given compute budget, the best\nperformances are not achieved by the largest mod-\nels, but by smaller models trained on more data.",
    "reading_order": 7
  },
  {
    "label": "para",
    "bbox": [167, 1506, 717, 1844],
    "text": "The objective of the scaling laws from Hoff-\nmann et al. ( 2022 ) is to determine how to best\nscale the dataset and model sizes for a particular\ntraining compute budget. However, this objective\ndisregards the inference budget, which becomes\ncritical when serving a language model at scale.\nIn this context, given a target level of performance,\nthe preferred model is not the fastest to train but the\nfastest at inference, and although it may be cheaper\nto train a large model to reach a certain level of",
    "reading_order": 8
  },
  {
    "label": "para",
    "bbox": [753, 539, 1304, 734],
    "text": "performance, a smaller one trained longer will\nultimately be cheaper at inference. For instance,\nalthough Hoffmann et al. ( 2022 ) recommends\ntraining a 10B model on 200B tokens, we find\nthat the performance of a 7B model continues to\nimprove even after 1T tokens.",
    "reading_order": 9
  },
  {
    "label": "para",
    "bbox": [753, 769, 1305, 1236],
    "text": "The focus of this work is to train a series of\nlanguage models that achieve the best possible per-\nformance at various inference budgets, by training\non more tokens than what is typically used. The\nresulting models, called LLaMA , ranges from 7B\nto 65B parameters with competitive performance\ncompared to the best existing LLMs. For instance,\nLLaMA-13B outperforms GPT-3 on most bench-\nmarks, despite being 10 $\\times$ smaller. We believe that\nthis model will help democratize the access and\nstudy of LLMs, since it can be run on a single GPU.\nAt the higher-end of the scale, our 65B-parameter\nmodel is also competitive with the best large lan-\nguage models such as Chinchilla or PaLM-540B.",
    "reading_order": 10
  },
  {
    "label": "para",
    "bbox": [753, 1257, 1305, 1601],
    "text": "Unlike Chinchilla, PaLM, or GPT-3, we only\nuse publicly available data, making our work com-\npatible with open-sourcing, while most existing\nmodels rely on data which is either not publicly\navailable or undocumented (e.g. “ Books – 2TB ” or\n“ Social media conversations ” ). There exist some\nexceptions, notably OPT ( Zhang et al. , 2022 ) ,\nGPT-NeoX ( Black et al. , 2022 ) , BLOOM ( Scao\net al. , 2022 ) and GLM ( Zeng et al. , 2022 ) , but none\nthat are competitive with PaLM-62B or Chinchilla.",
    "reading_order": 11
  },
  {
    "label": "para",
    "bbox": [753, 1634, 1304, 1933],
    "text": "In the rest of this paper, we present an overview\nof the modifications we made to the transformer\narchitecture ( Vaswani et al. , 2017 ) , as well as our\ntraining method. We then report the performance of\nour models and compare with others LLMs on a set\nof standard benchmarks. Finally, we expose some\nof the biases and toxicity encoded in our models,\nusing some of the most recent benchmarks from\nthe responsible AI community.",
    "reading_order": 12
  },
  {
    "label": "fnote",
    "bbox": [167, 1844, 712, 1907],
    "text": "* Equal contribution.\nCorrespondence:\n{htouvron\nthibautlav,gizacard,egrave,glample}@meta.com",
    "reading_order": 13
  },
  {
    "label": "fnote",
    "bbox": [209, 1907, 632, 1931],
    "text": "https://github.com/facebookresearch/llama",
    "reading_order": 14
  },
  {
    "label": "watermark",
    "bbox": [20, 649, 83, 1530],
    "text": "arXiv:2302.1397lvl [cs.CL] 27 Feb 2023",
    "reading_order": 15
  }
]
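A minimal sketch, for illustration only, of how a page-level recognition JSON like this can be flattened into markdown by following reading_order (field and label names taken from the JSON files in this diff):

import json

with open("demo/page_imgs/recognition_json/page_1.json") as f:
    elements = json.load(f)

lines = []
for el in sorted(elements, key=lambda e: e.get("reading_order", 0)):
    if el["label"] == "fig" and "figure_path" in el:
        lines.append(f"![]({el['figure_path']})")     # figures are stored as relative paths
    elif el["label"] not in ("watermark", "foot"):     # skip page furniture
        lines.append(el["text"])

print("\n\n".join(lines))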
demo/page_imgs/recognition_json/test_page.json
DELETED
@@ -1,47 +0,0 @@
[
  {
    "label": "header",
    "bbox": [291, 90, 675, 120],
    "text": "Scaled Dot-Product Attention",
    "reading_order": 0
  },
  {
    "label": "fig",
    "text": "",
    "figure_path": "figures/test_page_figure_001.png",
    "bbox": [1274, 105, 1536, 627],
    "reading_order": 1
  },
  {
    "label": "cap",
    "bbox": [168, 719, 1413, 789],
    "text": "Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several\nattention layers running in parallel.",
    "reading_order": 2
  },
  {
    "label": "para",
    "bbox": [168, 858, 1413, 934],
    "text": "query with all keys, divide each by $\\sqrt{d_{k}}$, and apply a softmax function to obtain the weights on the\nvalues.",
    "reading_order": 3
  }
]
demo/page_imgs/recognition_json/test_page2.json
DELETED
@@ -1,102 +0,0 @@
[
  {
    "label": "fig",
    "text": "",
    "figure_path": "figures/test_page2_figure_000.png",
    "bbox": [394, 117, 897, 837],
    "reading_order": 0
  },
  {
    "label": "cap",
    "bbox": [445, 852, 856, 873],
    "text": "Figure 1: The Transformer - model architecture",
    "reading_order": 1
  },
  {
    "label": "para",
    "bbox": [218, 920, 1086, 1044],
    "text": "wise fully connected feed-forward network. We employ a residual connection [ 10 ] around each of\nthe two sub-layers, followed by layer normalization [ 1 ] . That is, the output of each sub-layer is\n$\\mathrm{LayerNorm}(x+\\mathrm{Sublayer}(x))$ , where $\\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer\nitself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding\nlayers, produce outputs of dimension $d_{\\text{model}}=512$ .",
    "reading_order": 2
  },
  {
    "label": "para",
    "bbox": [218, 1071, 1085, 1244],
    "text": "The The decoder is also composed of a stack of $N=6$ identical layers. In addition to the two\nsub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head\nattention over the output of the encoder stack. Similar to the encoder, we employ residual connections\naround each of the sub-layers, followed by layer normalization. We also modify the self-attention\nsub-layer in the decoder stack to prevent positions from attending to subsequent positions. This\nmasking, combined with fact that the output embeddings are offset by one position, ensures that the\npredictions for position $i$ can depend only on the known outputs at positions less than $i$ .",
    "reading_order": 3
  },
  {
    "label": "sub_sec",
    "bbox": [226, 1283, 344, 1305],
    "text": "3.2 Attention",
    "reading_order": 4
  },
  {
    "label": "para",
    "bbox": [218, 1322, 1087, 1422],
    "text": "An attention function can be described as mapping a query and a set of key-value pairs to an output,\nwhere the query, keys, values, and output are all vectors. The output is computed as a weighted sum\nof the values, where the weight assigned to each value is computed by a compatibility function of the\nquery with the corresponding key.",
    "reading_order": 5
  },
  {
    "label": "sub_sub_sec",
    "bbox": [218, 1456, 562, 1474],
    "text": "3.2.1 Scaled Dot-Product Attention",
    "reading_order": 6
  },
  {
    "label": "para",
    "bbox": [218, 1498, 1085, 1546],
    "text": "We call our particular attention \"Scaled Dot-Product Attention\" (Figure 2 ). The input consists of\nqueries and keys of dimension $d_k$ , and values of dimension $d_v$ . We compute the dot products of the",
    "reading_order": 7
  },
  {
    "label": "foot",
    "bbox": [646, 1590, 662, 1607],
    "text": "3",
    "reading_order": 8
  }
]
demo/page_imgs/recognition_json/test_page3.json
DELETED
@@ -1,124 +0,0 @@
[
  {
    "label": "fig",
    "text": "",
    "figure_path": "figures/test_page3_figure_000.png",
    "bbox": [331, 134, 984, 489],
    "reading_order": 0
  },
  {
    "label": "cap",
    "bbox": [198, 554, 1065, 603],
    "text": "Figure 2: (left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several\nattention layers running in parallel.",
    "reading_order": 1
  },
  {
    "label": "para",
    "bbox": [198, 652, 1065, 701],
    "text": "query with all keys, divide each by $\\sqrt{d_k}$ , and apply a softmax function to obtain the weights on the\nvalues.",
    "reading_order": 2
  },
  {
    "label": "para",
    "bbox": [198, 715, 1065, 881],
    "text": "In practice, we compute the attention function on a set of queries simultaneously, packed together\ninto a matrix $Q$ . The keys and values are also packed together into matrices $K$ and $V$ . We compute\nthe matrix of outputs as:\n\\[\n \\text{Attention}(Q, K, V) = \\mathrm{softmax}(\\frac{QK^T}{\\sqrt{d_k}})V\n\\]",
    "reading_order": 3
  },
  {
    "label": "para",
    "bbox": [198, 913, 1068, 1060],
    "text": "The two most commonly used attention functions are additive attention [2] , and dot-product (multi-\nplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor\nof $\\frac{1}{\\sqrt{d_k}}$ . Additive attention computes the compatibility function using a feed-forward network with\na single hidden layer. While the two are similar in theoretical complexity, dot-product attention is\nmuch faster and more space-efficient in practice, since it can be implemented using highly optimized\nmatrix multiplication code.",
    "reading_order": 4
  },
  {
    "label": "para",
    "bbox": [198, 1074, 1066, 1175],
    "text": "While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms\ndot product attention without scaling for larger values of $d_k$ [ 3 ] . We suspect that for large values of\n$d_k$ , the dot products grow large in magnitude, pushing the softmax function into regions where it has\nextremely small gradients 4 To counteract this effect, we scale the dot products by $\\frac{1}{\\sqrt{d_k}}$ .",
    "reading_order": 5
  },
  {
    "label": "sub_sub_sec",
    "bbox": [198, 1207, 467, 1225],
    "text": "3.2.2 Multi-Head Attention",
    "reading_order": 6
  },
  {
    "label": "para",
    "bbox": [198, 1253, 1067, 1395],
    "text": "Instead of performing a single attention function with $d_{\\text{model}}$ -dimensional keys, values and queries,\nwe found it beneficial to linearly project the queries, keys and values $h$ times with different, learned\nlinear projections to $d_k$ , $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of\nqueries, keys and values we then perform the attention function in parallel, yielding $d_v$ -dimensional\noutput values. These are concatenated and once again projected, resulting in the final values, as\ndepicted in Figure 2 .",
    "reading_order": 7
  },
  {
    "label": "para",
    "bbox": [198, 1403, 1065, 1453],
    "text": "Multihead attention allows the model to jointly attend to information from different representation\nsubspaces at different positions. With a single attention head, averaging inhibits this.",
    "reading_order": 8
  },
  {
    "label": "fnote",
    "bbox": [198, 1485, 1065, 1535],
    "text": "${ }^{4}$ To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random\nvariables with mean 0 and variance 1 . Then their dot product, $q \\cdot k=\\sum_{i=1}^{d_{k}} q_{i} k_{i}$, has mean 0 and variance $d_{k}$.",
    "reading_order": 9
  },
  {
    "label": "foot",
    "bbox": [625, 1578, 641, 1599],
    "text": "4",
    "reading_order": 10
  }
]
demo/page_imgs/test_page2.jpeg
DELETED
Git LFS Details

demo/page_imgs/test_page3.jpeg
DELETED
Git LFS Details
demo_element.py
DELETED
@@ -1,129 +0,0 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN
from utils.utils import *


def process_element(image_path, model, element_type, save_dir=None):
    """Process a single element image (text, table, formula)

    Args:
        image_path: Path to the element image
        model: DOLPHIN model instance
        element_type: Type of element ('text', 'table', 'formula')
        save_dir: Directory to save results (default: same as input directory)

    Returns:
        Parsed content of the element and recognition results
    """
    # Load and prepare image
    pil_image = Image.open(image_path).convert("RGB")
    pil_image = crop_margin(pil_image)

    # Select appropriate prompt based on element type
    if element_type == "table":
        prompt = "Parse the table in the image."
        label = "tab"
    elif element_type == "formula":
        prompt = "Read text in the image."
        label = "formula"
    else:  # Default to text
        prompt = "Read text in the image."
        label = "text"

    # Process the element
    result = model.chat(prompt, pil_image)

    # Create recognition result in the same format as the document parser
    recognition_result = [
        {
            "label": label,
            "text": result.strip(),
        }
    ]

    # Save results if save_dir is provided
    if save_dir:
        save_outputs(recognition_result, image_path, save_dir)
        print(f"Results saved to {save_dir}")

    return result, recognition_result


def main():
    parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
    parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
    parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
    parser.add_argument(
        "--element_type",
        type=str,
        choices=["text", "table", "formula"],
        default="text",
        help="Type of element to process (text, table, formula)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
    args = parser.parse_args()

    # Load Model
    config = OmegaConf.load(args.config)
    model = DOLPHIN(config)

    # Set save directory
    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    # Collect Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process images one by one
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            result, recognition_result = process_element(
                image_path=image_path,
                model=model,
                element_type=args.element_type,
                save_dir=save_dir,
            )

            if args.print_results:
                print("\nRecognition result:")
                print(result)
                print("-" * 40)

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()
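A sketch of using the functions above directly instead of the CLI; the ./results directory is a hypothetical output location, and the equivalent command line would be python demo_element.py --input_path demo/element_imgs/table_1.jpeg --element_type table --print_results.

from omegaconf import OmegaConf

from chat import DOLPHIN
from demo_element import process_element

# Load the config-based model once, then parse a single table crop.
model = DOLPHIN(OmegaConf.load("./config/Dolphin.yaml"))
result, recognition = process_element(
    image_path="demo/element_imgs/table_1.jpeg",
    model=model,
    element_type="table",
    save_dir="./results",  # hypothetical output directory
)
print(result)  # HTML <table> markup, as in the recognition JSON above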
demo_element_hf.py
DELETED
@@ -1,195 +0,0 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

from utils.utils import *


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model

        Args:
            model_id_or_path: Path to local model or Hugging Face model ID
        """
        # Load model from local path or Hugging Face hub
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
        self.model.eval()

        # Set device and precision
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model = self.model.half()  # Always use half precision by default

        # set tokenizer
        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image with the given prompt

        Args:
            prompt: Text prompt to guide the model
            image: PIL Image to process

        Returns:
            Generated text from the model
        """
        # Prepare image
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.half()

        # Prepare prompt
        prompt = f"<s>{prompt} <Answer/>"
        prompt_ids = self.tokenizer(
            prompt,
            add_special_tokens=False,
            return_tensors="pt"
        ).input_ids.to(self.device)

        decoder_attention_mask = torch.ones_like(prompt_ids)

        # Generate text
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.device),
            decoder_input_ids=prompt_ids,
            decoder_attention_mask=decoder_attention_mask,
            min_length=1,
            max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            do_sample=False,
            num_beams=1,
            repetition_penalty=1.1,
            temperature=1.0
        )

        # Process the output
        sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
        sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip()

        return sequence

def process_element(image_path, model, element_type, save_dir=None):
    """Process a single element image (text, table, formula)

    Args:
        image_path: Path to the element image
        model: HFModel model instance
        element_type: Type of element ('text', 'table', 'formula')
        save_dir: Directory to save results (default: same as input directory)

    Returns:
        Parsed content of the element and recognition results
    """
    # Load and prepare image
    pil_image = Image.open(image_path).convert("RGB")
    pil_image = crop_margin(pil_image)

    # Select appropriate prompt based on element type
    if element_type == "table":
        prompt = "Parse the table in the image."
        label = "tab"
    elif element_type == "formula":
        prompt = "Read text in the image."
        label = "formula"
    else:  # Default to text
        prompt = "Read text in the image."
        label = "text"

    # Process the element
    result = model.chat(prompt, pil_image)

    # Create recognition result in the same format as the document parser
    recognition_result = [
        {
            "label": label,
            "text": result.strip(),
        }
    ]

    # Save results if save_dir is provided
    if save_dir:
        save_outputs(recognition_result, image_path, save_dir)
        print(f"Results saved to {save_dir}")

    return result, recognition_result


def main():
    parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
    parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
    parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
    parser.add_argument(
        "--element_type",
        type=str,
        choices=["text", "table", "formula"],
        default="text",
        help="Type of element to process (text, table, formula)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
    args = parser.parse_args()

    # Load Model
    model = DOLPHIN(args.model_path)

    # Set save directory
    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    # Collect Images
    if os.path.isdir(args.input_path):
        image_files = []
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
            image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        image_files = sorted(image_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")
        image_files = [args.input_path]

    total_samples = len(image_files)
    print(f"\nTotal samples to process: {total_samples}")

    # Process images one by one
    for image_path in image_files:
        print(f"\nProcessing {image_path}")
        try:
            result, recognition_result = process_element(
                image_path=image_path,
                model=model,
                element_type=args.element_type,
                save_dir=save_dir,
            )

            if args.print_results:
                print("\nRecognition result:")
                print(result)
                print("-" * 40)
        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()
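A sketch of driving the Hugging Face wrapper above directly, assuming an exported model at ./hf_model (the script's default --model_path); a Hugging Face model ID should also work with from_pretrained.

from demo_element_hf import DOLPHIN, process_element

model = DOLPHIN("./hf_model")
table_html, recognition = process_element(
    image_path="demo/element_imgs/table_2.jpeg",
    model=model,
    element_type="table",
)
print(table_html)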
demo_page.py
DELETED
@@ -1,247 +0,0 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import cv2
from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN
from utils.utils import *


def process_document(document_path, model, save_dir, max_batch_size):
    """Parse documents - Handles both images and PDFs"""
    file_ext = os.path.splitext(document_path)[1].lower()

    if file_ext == '.pdf':
        # Process PDF file
        # Convert PDF to images
        images = convert_pdf_to_images(document_path)
        if not images:
            raise Exception(f"Failed to convert PDF {document_path} to images")

        all_results = []

        # Process each page
        for page_idx, pil_image in enumerate(images):
            print(f"Processing page {page_idx + 1}/{len(images)}")

            # Generate output name for this page
            base_name = os.path.splitext(os.path.basename(document_path))[0]
            page_name = f"{base_name}_page_{page_idx + 1:03d}"

            # Process this page (don't save individual page results)
            json_path, recognition_results = process_single_image(
                pil_image, model, save_dir, page_name, max_batch_size, save_individual=False
            )

            # Add page information to results
            page_results = {
                "page_number": page_idx + 1,
                "elements": recognition_results
            }
            all_results.append(page_results)

        # Save combined results for multi-page PDF
        combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir)

        return combined_json_path, all_results

    else:
        # Process regular image file
        pil_image = Image.open(document_path).convert("RGB")
        base_name = os.path.splitext(os.path.basename(document_path))[0]
        return process_single_image(pil_image, model, save_dir, base_name, max_batch_size)


def process_single_image(image, model, save_dir, image_name, max_batch_size, save_individual=True):
    """Process a single image (either from file or converted from PDF page)

    Args:
        image: PIL Image object
        model: DOLPHIN model instance
        save_dir: Directory to save results
        image_name: Name for the output file
        max_batch_size: Maximum batch size for processing
        save_individual: Whether to save individual results (False for PDF pages)

    Returns:
        Tuple of (json_path, recognition_results)
    """
    # Stage 1: Page-level layout and reading order parsing
    layout_output = model.chat("Parse the reading order of this document.", image)

    # Stage 2: Element-level content parsing
    padded_image, dims = prepare_image(image)
    recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name)

    # Save outputs only if requested (skip for PDF pages)
    json_path = None
    if save_individual:
        # Create a dummy image path for save_outputs function
        dummy_image_path = f"{image_name}.jpg"  # Extension doesn't matter, only basename is used
        json_path = save_outputs(recognition_results, dummy_image_path, save_dir)

    return json_path, recognition_results


def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
    """Parse all document elements with parallel decoding"""
    layout_results = parse_layout_string(layout_results)

    text_table_elements = []  # Elements that need processing
    figure_results = []  # Figure elements (no processing needed)
    previous_box = None
    reading_order = 0

    # Collect elements for processing
    for bbox, label in layout_results:
        try:
            # Adjust coordinates
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            # Crop and parse element
            cropped = padded_image[y1:y2, x1:x2]
            if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
                if label == "fig":
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

                    figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)

                    # For figure regions, store relative path instead of base64
                    figure_results.append(
                        {
                            "label": label,
                            "text": f"",
                            "figure_path": f"figures/{figure_filename}",
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "reading_order": reading_order,
                        }
                    )
                else:
                    # For text or table regions, prepare for parsing
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
                    text_table_elements.append(
                        {
                            "crop": pil_crop,
                            "prompt": prompt,
                            "label": label,
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "reading_order": reading_order,
                        }
                    )

            reading_order += 1

        except Exception as e:
            print(f"Error processing bbox with label {label}: {str(e)}")
            continue

    # Parse text/table elements in parallel
    recognition_results = figure_results
    if text_table_elements:
        crops_list = [elem["crop"] for elem in text_table_elements]
        prompts_list = [elem["prompt"] for elem in text_table_elements]

        # Inference in batch
        batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)

        # Add batch results to recognition_results
        for i, result in enumerate(batch_results):
            elem = text_table_elements[i]
            recognition_results.append(
                {
                    "label": elem["label"],
                    "bbox": elem["bbox"],
                    "text": result.strip(),
                    "reading_order": elem["reading_order"],
                }
            )

    # Sort elements by reading order
    recognition_results.sort(key=lambda x: x.get("reading_order", 0))

    return recognition_results


def main():
    parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN")
    parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
    parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files")
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Directory to save parsing results (default: same as input directory)",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=4,
        help="Maximum number of document elements to parse in a single batch (default: 4)",
    )
    args = parser.parse_args()

    # Load Model
    config = OmegaConf.load(args.config)
    model = DOLPHIN(config)

    # Collect Document Files (images and PDFs)
    if os.path.isdir(args.input_path):
        # Support both image and PDF files
        file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"]

        document_files = []
        for ext in file_extensions:
            document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
        document_files = sorted(document_files)
    else:
        if not os.path.exists(args.input_path):
            raise FileNotFoundError(f"Input path {args.input_path} does not exist")

        # Check if it's a supported file type
        file_ext = os.path.splitext(args.input_path)[1].lower()
        supported_exts = ['.jpg', '.jpeg', '.png', '.pdf']

        if file_ext not in supported_exts:
            raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}")

        document_files = [args.input_path]

    save_dir = args.save_dir or (
        args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
    )
    setup_output_dirs(save_dir)

    total_samples = len(document_files)
    print(f"\nTotal files to process: {total_samples}")

    # Process All Document Files
    for file_path in document_files:
        print(f"\nProcessing {file_path}")
        try:
            json_path, recognition_results = process_document(
                document_path=file_path,
                model=model,
                save_dir=save_dir,
                max_batch_size=args.max_batch_size,
            )

            print(f"Processing completed. Results saved to {save_dir}")

        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue


if __name__ == "__main__":
    main()
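A sketch of calling process_document directly, mirroring what main() above does for a single page image; ./results is a hypothetical output directory, and setup_output_dirs comes from the repository's utils.utils module used throughout these scripts.

from omegaconf import OmegaConf

from chat import DOLPHIN
from demo_page import process_document
from utils.utils import setup_output_dirs

model = DOLPHIN(OmegaConf.load("./config/Dolphin.yaml"))
setup_output_dirs("./results")  # hypothetical output directory
json_path, results = process_document(
    document_path="demo/page_imgs/page_1.jpeg",
    model=model,
    save_dir="./results",
    max_batch_size=4,
)
print(json_path)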
demo_page_hf.py
DELETED
@@ -1,365 +0,0 @@
"""
Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import glob
import os

import cv2
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

from utils.utils import *


class DOLPHIN:
    def __init__(self, model_id_or_path):
        """Initialize the Hugging Face model

        Args:
            model_id_or_path: Path to local model or Hugging Face model ID
        """
        # Load model from local path or Hugging Face hub
        self.processor = AutoProcessor.from_pretrained(model_id_or_path)
        self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
        self.model.eval()

        # Set device and precision
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model = self.model.half()  # Always use half precision by default

        # set tokenizer
        self.tokenizer = self.processor.tokenizer

    def chat(self, prompt, image):
        """Process an image or batch of images with the given prompt(s)

        Args:
            prompt: Text prompt or list of prompts to guide the model
            image: PIL Image or list of PIL Images to process

        Returns:
            Generated text or list of texts from the model
        """
        # Check if we're dealing with a batch
        is_batch = isinstance(image, list)

        if not is_batch:
            # Single image, wrap it in a list for consistent processing
            images = [image]
            prompts = [prompt]
        else:
            # Batch of images
            images = image
            prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)

        # Prepare image
        batch_inputs = self.processor(images, return_tensors="pt", padding=True)
        batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)

        # Prepare prompt
        prompts = [f"<s>{p} <Answer/>" for p in prompts]
        batch_prompt_inputs = self.tokenizer(
            prompts,
            add_special_tokens=False,
            return_tensors="pt"
        )

        batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
        batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)

        # Generate text
        outputs = self.model.generate(
            pixel_values=batch_pixel_values,
            decoder_input_ids=batch_prompt_ids,
            decoder_attention_mask=batch_attention_mask,
            min_length=1,
            max_length=4096,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[self.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            do_sample=False,
            num_beams=1,
            repetition_penalty=1.1,
            temperature=1.0
        )

        # Process output
        sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)

        # Clean prompt text from output
        results = []
        for i, sequence in enumerate(sequences):
            cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
            results.append(cleaned)

        # Return a single result for single image input
        if not is_batch:
            return results[0]
        return results


def process_document(document_path, model, save_dir, max_batch_size=None):
    """Parse documents with two stages - Handles both images and PDFs"""
    file_ext = os.path.splitext(document_path)[1].lower()

    if file_ext == '.pdf':
        # Process PDF file
        # Convert PDF to images
        images = convert_pdf_to_images(document_path)
        if not images:
            raise Exception(f"Failed to convert PDF {document_path} to images")

        all_results = []

        # Process each page
        for page_idx, pil_image in enumerate(images):
            print(f"Processing page {page_idx + 1}/{len(images)}")

            # Generate output name for this page
            base_name = os.path.splitext(os.path.basename(document_path))[0]
            page_name = f"{base_name}_page_{page_idx + 1:03d}"

            # Process this page (don't save individual page results)
            json_path, recognition_results = process_single_image(
                pil_image, model, save_dir, page_name, max_batch_size, save_individual=False
            )

            # Add page information to results
            page_results = {
                "page_number": page_idx + 1,
                "elements": recognition_results
            }
            all_results.append(page_results)

        # Save combined results for multi-page PDF
        combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir)

        return combined_json_path, all_results

    else:
        # Process regular image file
        pil_image = Image.open(document_path).convert("RGB")
        base_name = os.path.splitext(os.path.basename(document_path))[0]
        return process_single_image(pil_image, model, save_dir, base_name, max_batch_size)


def process_single_image(image, model, save_dir, image_name, max_batch_size=None, save_individual=True):
    """Process a single image (either from file or converted from PDF page)

    Args:
        image: PIL Image object
        model: DOLPHIN model instance
        save_dir: Directory to save results
        image_name: Name for the output file
        max_batch_size: Maximum batch size for processing
        save_individual: Whether to save individual results (False for PDF pages)

    Returns:
        Tuple of (json_path, recognition_results)
    """
    # Stage 1: Page-level layout and reading order parsing
    layout_output = model.chat("Parse the reading order of this document.", image)

    # Stage 2: Element-level content parsing
    padded_image, dims = prepare_image(image)
    recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name)

    # Save outputs only if requested (skip for PDF pages)
    json_path = None
    if save_individual:
        # Create a dummy image path for save_outputs function
        dummy_image_path = f"{image_name}.jpg"  # Extension doesn't matter, only basename is used
        json_path = save_outputs(recognition_results, dummy_image_path, save_dir)

    return json_path, recognition_results


def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
    """Parse all document elements with parallel decoding"""
    layout_results = parse_layout_string(layout_results)

    # Store text and table elements separately
    text_elements = []  # Text elements
    table_elements = []  # Table elements
    figure_results = []  # Image elements (no processing needed)
    previous_box = None
    reading_order = 0

    # Collect elements to process and group by type
    for bbox, label in layout_results:
        try:
            # Adjust coordinates
            x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
                bbox, padded_image, dims, previous_box
            )

            # Crop and parse element
            cropped = padded_image[y1:y2, x1:x2]
            if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
                if label == "fig":
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

                    figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)

                    # For figure regions, store relative path instead of base64
                    figure_results.append(
                        {
                            "label": label,
                            "text": f"",
                            "figure_path": f"figures/{figure_filename}",
                            "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                            "reading_order": reading_order,
                        }
                    )
                else:
                    # Prepare element for parsing
                    pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
                    element_info = {
                        "crop": pil_crop,
                        "label": label,
                        "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
                        "reading_order": reading_order,
                    }

                    # Group by type
                    if label == "tab":
                        table_elements.append(element_info)
                    else:  # Text elements
                        text_elements.append(element_info)

            reading_order += 1

        except Exception as e:
            print(f"Error processing bbox with label {label}: {str(e)}")
            continue

    # Initialize results list
    recognition_results = figure_results.copy()

    # Process text elements (in batches)
    if text_elements:
        text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
        recognition_results.extend(text_results)

    # Process table elements (in batches)
    if table_elements:
        table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
        recognition_results.extend(table_results)

    # Sort elements by reading order
    recognition_results.sort(key=lambda x: x.get("reading_order", 0))

    return recognition_results


def process_element_batch(elements, model, prompt, max_batch_size=None):
    """Process elements of the same type in batches"""
    results = []

    # Determine batch size
    batch_size = len(elements)
    if max_batch_size is not None and max_batch_size > 0:
        batch_size = min(batch_size, max_batch_size)

    # Process in batches
    for i in range(0, len(elements), batch_size):
        batch_elements = elements[i:i+batch_size]
        crops_list = [elem["crop"] for elem in batch_elements]

        # Use the same prompt for all elements in the batch
        prompts_list = [prompt] * len(crops_list)

        # Batch inference
        batch_results = model.chat(prompts_list, crops_list)

        # Add results
        for j, result in enumerate(batch_results):
            elem = batch_elements[j]
            results.append({
                "label": elem["label"],
                "bbox": elem["bbox"],
                "text": result.strip(),
                "reading_order": elem["reading_order"],
            })
-
|
292 |
-
return results
|
293 |
-
|
294 |
-
|
295 |
-
def main():
|
296 |
-
parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN")
|
297 |
-
parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
|
298 |
-
parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files")
|
299 |
-
parser.add_argument(
|
300 |
-
"--save_dir",
|
301 |
-
type=str,
|
302 |
-
default=None,
|
303 |
-
help="Directory to save parsing results (default: same as input directory)",
|
304 |
-
)
|
305 |
-
parser.add_argument(
|
306 |
-
"--max_batch_size",
|
307 |
-
type=int,
|
308 |
-
default=16,
|
309 |
-
help="Maximum number of document elements to parse in a single batch (default: 16)",
|
310 |
-
)
|
311 |
-
args = parser.parse_args()
|
312 |
-
|
313 |
-
# Load Model
|
314 |
-
model = DOLPHIN(args.model_path)
|
315 |
-
|
316 |
-
# Collect Document Files (images and PDFs)
|
317 |
-
if os.path.isdir(args.input_path):
|
318 |
-
# Support both image and PDF files
|
319 |
-
file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"]
|
320 |
-
|
321 |
-
document_files = []
|
322 |
-
for ext in file_extensions:
|
323 |
-
document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
|
324 |
-
document_files = sorted(document_files)
|
325 |
-
else:
|
326 |
-
if not os.path.exists(args.input_path):
|
327 |
-
raise FileNotFoundError(f"Input path {args.input_path} does not exist")
|
328 |
-
|
329 |
-
# Check if it's a supported file type
|
330 |
-
file_ext = os.path.splitext(args.input_path)[1].lower()
|
331 |
-
supported_exts = ['.jpg', '.jpeg', '.png', '.pdf']
|
332 |
-
|
333 |
-
if file_ext not in supported_exts:
|
334 |
-
raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}")
|
335 |
-
|
336 |
-
document_files = [args.input_path]
|
337 |
-
|
338 |
-
save_dir = args.save_dir or (
|
339 |
-
args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
|
340 |
-
)
|
341 |
-
setup_output_dirs(save_dir)
|
342 |
-
|
343 |
-
total_samples = len(document_files)
|
344 |
-
print(f"\nTotal files to process: {total_samples}")
|
345 |
-
|
346 |
-
# Process All Document Files
|
347 |
-
for file_path in document_files:
|
348 |
-
print(f"\nProcessing {file_path}")
|
349 |
-
try:
|
350 |
-
json_path, recognition_results = process_document(
|
351 |
-
document_path=file_path,
|
352 |
-
model=model,
|
353 |
-
save_dir=save_dir,
|
354 |
-
max_batch_size=args.max_batch_size,
|
355 |
-
)
|
356 |
-
|
357 |
-
print(f"Processing completed. Results saved to {save_dir}")
|
358 |
-
|
359 |
-
except Exception as e:
|
360 |
-
print(f"Error processing {file_path}: {str(e)}")
|
361 |
-
continue
|
362 |
-
|
363 |
-
|
364 |
-
if __name__ == "__main__":
|
365 |
-
main()
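The deleted script's pipeline can also be driven without going through `main()`. Below is a minimal usage sketch, assuming the functions above are in scope; the paths are placeholders, and the exact directory layout created by `setup_output_dirs` is not shown in this diff.

```
model = DOLPHIN("./hf_model")                      # same default checkpoint path as --model_path
setup_output_dirs("./results")                     # prepare the output directory before processing
json_path, results = process_document(
    document_path="./demo/page_imgs/page_1.jpeg",  # placeholder input; a PDF is handled the same way
    model=model,
    save_dir="./results",
    max_batch_size=16,
)
```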
deployment/ReadMe.md
DELETED
@@ -1,12 +0,0 @@
<h1 align="center">
🚀 Dolphin Inference/Serving
</h1>

## vLLM
> [Doc](./vllm/ReadMe.md)

## TensorRT-LLM
> [Doc](./tensorrt_llm/ReadMe.md)

## Others
deployment/tensorrt_llm/ReadMe.md
DELETED
@@ -1,89 +0,0 @@
<h1 align="center">
🚀 Dolphin TensorRT-LLM Demo
</h1>

## ✅ Introduction
The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
its `architectures` field is specified as "VisionEncoderDecoderModel". **Dolphin**, **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)**, and **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** share the same model architecture, and TensorRT-LLM already supports Nougat.
Following Nougat's conversion script, we have implemented Dolphin on TensorRT-LLM.

**Note:** [prompt_ids](./dolphin_runner.py#L120) MUST be of **int32** type, otherwise TensorRT-LLM will produce incorrect results.

## 🛠️ Installation
> We have only tested TensorRT-LLM 0.18.1 on Linux.

https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html


## ⚡ Offline Inference
```
export MODEL_NAME="Dolphin"

# predict elements reading order
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the reading order of this document." \
    --image_path "../../demo/page_imgs/page_1.jpeg"

# recognize text/latex
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/block_formula.jpeg"

python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Read text in the image." \
    --image_path "../../demo/element_imgs/para_1.jpg"

# recognize table
python run_dolphin.py \
    --batch_size 1 \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_new_tokens 4096 \
    --repetition_penalty 1.0 \
    --input_text "Parse the table in the image." \
    --image_path "../../demo/element_imgs/table_1.jpeg"
```


## ⚡ Online Inference
```
# 1. Start the API server
export MODEL_NAME="Dolphin"

python api_server.py \
    --hf_model_dir tmp/hf_models/${MODEL_NAME} \
    --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
    --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
    --max_batch_size 16

# 2. Predict
# predict elements reading order
python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."

# recognize text/latex
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."

# recognize table
python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
```
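The int32 note in the Introduction is easy to miss when adapting the runner. Here is a hedged, stand-alone sketch of the cast; the tokenizer loading call and the checkpoint path are assumptions for illustration, not code taken from dolphin_runner.py:

```
import torch
from transformers import AutoTokenizer  # assumption: the Dolphin checkpoint exposes a standard tokenizer

tokenizer = AutoTokenizer.from_pretrained("tmp/hf_models/Dolphin")  # placeholder path
prompt = "Parse the reading order of this document."
# Cast to int32 before handing the ids to TensorRT-LLM; int64 ids produce incorrect results per the note above.
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch.int32)
```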
deployment/tensorrt_llm/api_client.py
DELETED
@@ -1,100 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for `vllm.entrypoints.api_server`
Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import base64
import json
from argparse import Namespace
from collections.abc import Iterable

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = "\033[1A"
    LINE_CLEAR = "\x1b[2K"
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def encode_image_base64(image_path: str) -> str:
    """Encode local image to base64 format."""

    with open(image_path, "rb") as f:
        image_data = f.read()
    result = base64.b64encode(image_data).decode("utf-8")

    return result


def post_http_request(
    prompt: str, image_path: str, api_url: str, stream: bool = False
) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "image_base64": encode_image_base64(image_path),
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=stream)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> list[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
    parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    image_path = args.image_path
    api_url = f"http://{args.host}:{args.port}/generate"
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, image_path, api_url, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Response {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        print(f"Response: {output!r}", flush=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)
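Beyond the command line, the helpers in this client can be reused directly from Python. A short sketch against the script's default server address, using one of the demo assets as the image:

```
api_url = "http://localhost:8000/generate"
resp = post_http_request(
    prompt="Parse the table in the image.",
    image_path="./demo/element_imgs/table_1.jpeg",
    api_url=api_url,
)
print(get_response(resp))  # prints the "text" field returned by the server
```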
deployment/tensorrt_llm/api_server.py
DELETED
@@ -1,112 +0,0 @@
# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py

#!/usr/bin/env python
import asyncio
import base64
import io
import logging
import signal
from http import HTTPStatus
from PIL import Image
from typing import Optional

import click
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response

from tensorrt_llm.executor import CppExecutorError, RequestError
from dolphin_runner import DolphinRunner, InferenceConfig

TIMEOUT_KEEP_ALIVE = 5  # seconds.


async def decode_image(image_base64: str) -> Image.Image:
    image_data = base64.b64decode(image_base64)
    image = Image.open(io.BytesIO(image_data))
    return image


class LlmServer:
    def __init__(self, runner: DolphinRunner):
        self.runner = runner
        self.app = FastAPI()
        self.register_routes()

    def register_routes(self):
        self.app.add_api_route("/health", self.health, methods=["GET"])
        self.app.add_api_route("/generate", self.generate, methods=["POST"])

    async def health(self) -> Response:
        return Response(status_code=200)

    async def generate(self, request: Request) -> Response:
        """Generate completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - image_base64: the image to use for the generation.
        """
        request_dict = await request.json()

        prompt = request_dict.pop("prompt", "")
        logging.info(f"request prompt: {prompt}")
        image_base64 = request_dict.pop("image_base64", "")
        image = await decode_image(image_base64)

        try:
            output_texts = self.runner.run([prompt], [image], 4024)
            output_texts = [texts[0] for texts in output_texts]
            return JSONResponse({"text": output_texts[0]})
        except RequestError as e:
            return JSONResponse(content=str(e),
                                status_code=HTTPStatus.BAD_REQUEST)
        except CppExecutorError:
            # If internal executor error is raised, shutdown the server
            signal.raise_signal(signal.SIGINT)

    async def __call__(self, host, port):
        config = uvicorn.Config(self.app,
                                host=host,
                                port=port,
                                log_level="info",
                                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
        await uvicorn.Server(config).serve()


@click.command()
@click.option("--hf_model_dir", type=str, required=True)
@click.option("--visual_engine_dir", type=str, required=True)
@click.option("--llm_engine_dir", type=str, required=True)
@click.option("--max_batch_size", type=int, default=16)
@click.option("--max_new_tokens", type=int, default=4024)
@click.option("--host", type=str, default=None)
@click.option("--port", type=int, default=8000)
def entrypoint(hf_model_dir: str,
               visual_engine_dir: str,
               llm_engine_dir: str,
               max_batch_size: int,
               max_new_tokens: int,
               host: Optional[str] = None,
               port: int = 8000):
    host = host or "0.0.0.0"
    port = port or 8000
    logging.info(f"Starting server at {host}:{port}")

    config = InferenceConfig(
        max_new_tokens=max_new_tokens,
        batch_size=max_batch_size,
        log_level="info",
        hf_model_dir=hf_model_dir,
        visual_engine_dir=visual_engine_dir,
        llm_engine_dir=llm_engine_dir,
    )

    dolphin_runner = DolphinRunner(config)
    server = LlmServer(runner=dolphin_runner)

    asyncio.run(server(host, port))


if __name__ == "__main__":
    entrypoint()
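The request/response contract implemented by `generate()` above can also be exercised with plain `requests`. A hedged sketch, with placeholder host and image path:

```
import base64
import requests

# Build the JSON body described in the generate() docstring: prompt + base64-encoded image.
with open("./demo/page_imgs/page_1.jpeg", "rb") as f:
    payload = {
        "prompt": "Parse the reading order of this document.",
        "image_base64": base64.b64encode(f.read()).decode("utf-8"),
    }

# On success the server replies with JSON of the form {"text": ...}.
reply = requests.post("http://localhost:8000/generate", json=payload)
print(reply.json()["text"])
```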