Spaces:
Running
Running
from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer | |
from PIL import Image | |
import requests | |
import torch | |
from threading import Thread | |
import gradio as gr | |
from gradio import FileData | |
import time | |
import spaces | |
import fitz # PyMuPDF | |
import io | |
import numpy as np | |
# Hugging Face Hub checkpoint that provides both the model weights and the processor.
ckpt = "Daemontatox/DocumentCogito"

# The processor is used below both for chat templating and for turning text/images into tensors.
processor = AutoProcessor.from_pretrained(ckpt)

# Load the weights in bfloat16 and move the whole model onto the GPU.
model = MllamaForConditionalGeneration.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
).to("cuda")
class DocumentState:
    """Container for the most recently uploaded document.

    Attributes:
        current_doc_images: list of page/photo images (empty when nothing is loaded).
        current_doc_text: extracted text ("" when nothing is loaded or for images).
        doc_type: 'pdf', 'image', or None when nothing is loaded.
    """

    def __init__(self):
        # Start from the same empty state that clear() produces.
        self.clear()

    def clear(self):
        """Drop any loaded document, resetting every field to its empty value."""
        self.current_doc_images = []
        self.current_doc_text = ""
        self.doc_type = None


# Single module-level instance read and written by the handlers in this file.
# NOTE(review): being a module global, it is shared by every user of the app
# in this process; per-session state (e.g. gr.State) may be intended — confirm.
doc_state = DocumentState()
def process_pdf_file(file_path, dpi=300):
    """Convert PDF to images and extract text using PyMuPDF.

    Args:
        file_path: Path of the PDF file to open.
        dpi: Resolution used to render each page (defaults to 300).

    Returns:
        A ``(images, text)`` tuple: one RGB ``PIL.Image`` per page, and the
        text of all pages, each preceded by a ``=== Page N ===`` header line.
    """
    # PDF geometry is in points (72 per inch), so scale by dpi/72. The matrix is
    # the same for every page, so build it once.
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    text_parts = []
    # The context manager closes the document even if a page fails to render.
    with fitz.open(file_path) as doc:
        for page_num, page in enumerate(doc, start=1):
            text_parts.append(f"\n=== Page {page_num} ===\n")
            text_parts.append(page.get_text() + "\n")
            pix = page.get_pixmap(matrix=matrix)
            img = Image.open(io.BytesIO(pix.tobytes("png")))
            images.append(img.convert("RGB"))
    # join() avoids the quadratic cost of repeated string concatenation.
    return images, "".join(text_parts)
def process_file(file):
    """Process either PDF or image file and update document state.

    Any previously loaded document is discarded first.

    Args:
        file: A path string, a dict with a ``"path"`` key, or an object with a
            ``.path`` attribute (presumably Gradio's upload payload shapes —
            confirm against the Gradio version in use).

    Returns:
        A status message for the user.
    """
    doc_state.clear()
    if isinstance(file, dict):
        file_path = file["path"]
    else:
        # Accept FileData-like objects as well as plain path strings.
        file_path = getattr(file, "path", file)

    if file_path.lower().endswith('.pdf'):
        images, text = process_pdf_file(file_path)
        # Only record the document once it has been loaded successfully.
        doc_state.current_doc_images, doc_state.current_doc_text = images, text
        doc_state.doc_type = 'pdf'
        return f"PDF processed successfully. {len(doc_state.current_doc_images)} pages loaded. You can now ask questions about the content."

    # convert() returns an independent, fully-loaded copy, so the file handle
    # opened by Image.open can be closed right away instead of leaking.
    with Image.open(file_path) as img:
        doc_state.current_doc_images = [img.convert("RGB")]
    doc_state.doc_type = 'image'
    return "Image loaded successfully. You can now ask questions about the content."
# Maximum number of document images attached to one request; longer documents
# are sub-sampled to this many evenly spaced pages.
_MAX_DOC_IMAGES = 12


def _history_to_messages(history):
    """Convert Gradio chat history into text-only chat-template messages.

    Handles tuple-style ``[user, assistant]`` pairs and messages-style
    ``{"role": ..., "content": ...}`` dicts. Turns whose user part is not a
    string (file uploads such as ``(path,)``) are skipped. The active document
    lives in ``doc_state`` and is re-attached to the current turn. An image
    placeholder here without a matching image would make the number of image
    tokens disagree with the number of images given to the processor.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            role, content = entry.get("role"), entry.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                messages.append({"role": role, "content": [{"type": "text", "text": content}]})
            continue
        user_text, bot_text = entry[0], entry[1]
        if not isinstance(user_text, str):
            continue  # file-upload turn
        messages.append({"role": "user", "content": [{"type": "text", "text": user_text}]})
        if isinstance(bot_text, str):
            messages.append({"role": "assistant", "content": [{"type": "text", "text": bot_text}]})
    return messages


def _select_pages(images, max_images=_MAX_DOC_IMAGES):
    """Return ``(pages, sampled)``.

    ``pages`` holds at most ``max_images`` evenly spaced images.
    ``sampled`` tells whether any images were dropped.
    """
    if len(images) <= max_images:
        return list(images), False
    indices = np.linspace(0, len(images) - 1, max_images, dtype=int)
    return [images[i] for i in indices], True


def bot_streaming(message, history, max_new_tokens=2048):
    """Stream the model's reply to ``message``, using the loaded document as context.

    Args:
        message: Multimodal textbox payload with a ``"text"`` key and an
            optional ``"files"`` list. Only the first file is loaded.
        history: Prior chat turns as supplied by ``gr.ChatInterface``.
        max_new_tokens: Upper bound on generated tokens (from the UI slider).

    Yields:
        The reply text accumulated so far, after each streamed chunk.
    """
    # NOTE(review): `spaces` is imported at the top of the file but never used.
    # If this Space runs on ZeroGPU hardware, this function likely needs
    # `@spaces.GPU` — confirm.
    txt = message["text"]
    if message.get("files"):
        process_file(message["files"][0])

    messages = _history_to_messages(history)
    images = []
    if doc_state.current_doc_images:
        # Sample pages BEFORE building the prompt so the note reaches the model
        # and the placeholder count below matches the images actually sent.
        images, sampled = _select_pages(doc_state.current_doc_images)
        current_msg = txt
        if sampled:
            current_msg += f"\n(Note: Analyzing {len(images)} evenly distributed pages from the document)"
        if doc_state.doc_type == 'pdf':
            current_msg += f"\nContext from PDF:\n{doc_state.current_doc_text}"
        # Exactly one image placeholder per attached image.
        content = [{"type": "text", "text": current_msg}]
        content += [{"type": "image"} for _ in images]
        messages.append({"role": "user", "content": content})
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})

    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    if images:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
    # The slider value may arrive as a float; generate() needs an int here.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=int(max_new_tokens))
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    # The streamer is exhausted once generation ends, so this join is immediate.
    thread.join()
def clear_context():
    """Forget the loaded document and return a confirmation message for the user."""
    doc_state.clear()
    status = "Document context cleared. You can upload a new document."
    return status
# Create the Gradio interface: a multimodal document chat plus a button that
# drops the currently loaded document.
with gr.Blocks() as demo:
    gr.Markdown("# Document Analyzer with Chat Support")
    gr.Markdown("Upload a PDF or image and chat about its contents. The context is maintained throughout the conversation.")
    chatbot = gr.ChatInterface(
        fn=bot_streaming,
        title="Document Chat",
        examples=[
            [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]}, 200],
            [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]}, 250],
            [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]}, 250],
            [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]}, 250],
            [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]}, 250],
        ],
        # Accepted upload types are set at construction time. Gradio expects
        # file extensions with a leading dot (".pdf"); a bare "pdf" is not a
        # recognised type.
        textbox=gr.MultimodalTextbox(file_types=["image", ".pdf"]),
        additional_inputs=[
            gr.Slider(
                minimum=10,
                maximum=2048,
                value=2048,
                step=10,
                label="Maximum number of new tokens to generate",
            )
        ],
        cache_examples=False,
        stop_btn="Stop Generation",
        fill_height=True,
        multimodal=True,
    )
    clear_btn = gr.Button("Clear Document Context")
    # Shows the confirmation string returned by clear_context.
    clear_status = gr.Markdown()
    clear_btn.click(fn=clear_context, outputs=clear_status)

# Launch the interface
demo.launch(debug=True)