Spaces:
Running
Running
import os | |
import io | |
import time | |
import base64 | |
import logging | |
import fitz # PyMuPDF | |
from PIL import Image | |
import gradio as gr | |
from openai import OpenAI # Use the OpenAI client that supports multimodal messages | |
# Load API key from environment variable (secrets) | |
HF_API_KEY = os.getenv("OPENAI_TOKEN") | |
if not HF_API_KEY: | |
raise ValueError("HF_API_KEY environment variable not set") | |
# Create the client pointing to the Hugging Face Inference endpoint | |
client = OpenAI( | |
base_url="https://openrouter.ai/api/v1", | |
api_key=HF_API_KEY | |
) | |
# Set up logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# ------------------------------- | |
# Document State and File Processing | |
# ------------------------------- | |
class DocumentState: | |
def __init__(self): | |
self.current_doc_images = [] | |
self.current_doc_text = "" | |
self.doc_type = None | |
def clear(self): | |
self.current_doc_images = [] | |
self.current_doc_text = "" | |
self.doc_type = None | |
doc_state = DocumentState() | |
def process_pdf_file(file_path): | |
"""Convert PDF pages to images and extract text using PyMuPDF.""" | |
try: | |
doc = fitz.open(file_path) | |
images = [] | |
text = "" | |
for page_num in range(doc.page_count): | |
try: | |
page = doc[page_num] | |
page_text = page.get_text("text") | |
if page_text.strip(): | |
text += f"Page {page_num + 1}:\n{page_text}\n\n" | |
# Render page as an image with a zoom factor | |
zoom = 3 | |
mat = fitz.Matrix(zoom, zoom) | |
pix = page.get_pixmap(matrix=mat, alpha=False) | |
img_data = pix.tobytes("png") | |
img = Image.open(io.BytesIO(img_data)).convert("RGB") | |
# Resize if image is too large | |
max_size = 1600 | |
if max(img.size) > max_size: | |
ratio = max_size / max(img.size) | |
new_size = tuple(int(dim * ratio) for dim in img.size) | |
img = img.resize(new_size, Image.Resampling.LANCZOS) | |
images.append(img) | |
except Exception as e: | |
logger.error(f"Error processing page {page_num}: {str(e)}") | |
continue | |
doc.close() | |
if not images: | |
raise ValueError("No valid images could be extracted from the PDF") | |
return images, text | |
except Exception as e: | |
logger.error(f"Error processing PDF file: {str(e)}") | |
raise | |
def process_uploaded_file(file): | |
"""Process an uploaded file (PDF or image) and update document state.""" | |
try: | |
doc_state.clear() | |
if file is None: | |
return "No file uploaded. Please upload a file." | |
# Get the file path from the Gradio upload (may be a dict or file-like object) | |
if isinstance(file, dict): | |
file_path = file["name"] | |
else: | |
file_path = file.name | |
file_ext = file_path.lower().split('.')[-1] | |
image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'} | |
if file_ext == 'pdf': | |
doc_state.doc_type = 'pdf' | |
try: | |
doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path) | |
return f"PDF processed successfully. Total pages: {len(doc_state.current_doc_images)}. You can now ask questions about the content." | |
except Exception as e: | |
return f"Error processing PDF: {str(e)}. Please try a different PDF file." | |
elif file_ext in image_extensions: | |
doc_state.doc_type = 'image' | |
try: | |
img = Image.open(file_path).convert("RGB") | |
max_size = 1600 | |
if max(img.size) > max_size: | |
ratio = max_size / max(img.size) | |
new_size = tuple(int(dim * ratio) for dim in img.size) | |
img = img.resize(new_size, Image.Resampling.LANCZOS) | |
doc_state.current_doc_images = [img] | |
return "Image loaded successfully. You can now ask questions about the content." | |
except Exception as e: | |
return f"Error processing image: {str(e)}. Please try a different image file." | |
else: | |
return f"Unsupported file type: {file_ext}. Please upload a PDF or image file (PNG, JPG, JPEG, GIF, BMP, WEBP)." | |
except Exception as e: | |
logger.error(f"Error in process_uploaded_file: {str(e)}") | |
return "An error occurred while processing the file. Please try again." | |
# ------------------------------- | |
# Bot Streaming Function Using the Multimodal API | |
# ------------------------------- | |
def bot_streaming(prompt_option, max_new_tokens=500): | |
""" | |
Build a multimodal message payload and call the inference API. | |
The payload includes: | |
- A text segment (the selected prompt and any document context). | |
- If available, an image as a data URI (using a base64-encoded PNG). | |
""" | |
try: | |
# Predetermined prompts (you can adjust these as needed) | |
prompts = { | |
"NOC Timesheet": ( | |
"""Extract structured information from the provided timesheet. The extracted details should include: | |
Name | |
Position Title | |
Work Location | |
Contractor | |
NOC ID | |
Month and Year | |
Regular Service Days (ONSHORE) | |
Standby Days (ONSHORE in Doha) | |
Offshore Days | |
Standby & Extended Hitch Days (OFFSHORE) | |
Extended Hitch Days (ONSHORE Rotational) | |
Service during Weekends & Public Holidays | |
ONSHORE Overtime Hours (Over 8 hours) | |
OFFSHORE Overtime Hours (Over 12 hours) | |
Per Diem Days (ONSHORE/OFFSHORE Rotational Personnel) | |
Training Days | |
Travel Days | |
Noc representative appoval's name as approved_by | |
Noc representative's date approval_date | |
Noc representative status as approval_status | |
Format the output as valid JSON. | |
""" | |
), | |
"NOC Basic": ( | |
"Based on the provided timesheet details, extract the following information:\n" | |
" - Full name\n" | |
" - Position title\n" | |
" - Work location\n" | |
" - Contractor's name\n" | |
" - NOC ID\n" | |
" - Month and year (MM/YYYY)" | |
), | |
"Aramco Full structured": ( | |
"""You are a document parsing assistant designed to extract structured data from various documents such as invoices, timesheets, purchase orders, and travel bookings. Return only valid JSON with no extra text. | |
""" | |
), | |
"Aramco Timesheet only": ( | |
"""Extract time tracking, work details, and approvals. | |
Return a JSON object following the specified structure. | |
""" | |
), | |
"NOC Invoice": ( | |
"""You are a highly accurate data extraction system. Analyze the provided invoice image and extract all data into the following JSON format: | |
{ | |
"invoiceDetails": { ... }, | |
"from": { ... }, | |
"to": { ... }, | |
"services": [ ... ], | |
"totals": { ... }, | |
"bankDetails": { ... } | |
} | |
""" | |
) | |
} | |
# Select the appropriate prompt | |
selected_prompt = prompts.get(prompt_option, "Invalid prompt selected.") | |
context = "" | |
if doc_state.current_doc_images and doc_state.current_doc_text: | |
context = "\nDocument context:\n" + doc_state.current_doc_text | |
full_prompt = selected_prompt + context | |
# Build the message payload in the expected format. | |
# The content field is a list of objects—one for text, and (if an image is available) one for the image. | |
messages = [ | |
{ | |
"role": "user", | |
"content": [ | |
{ | |
"type": "text", | |
"text": full_prompt | |
} | |
] | |
} | |
] | |
# If an image is available, encode it as a data URI and append it as an image_url message. | |
if doc_state.current_doc_images: | |
buffered = io.BytesIO() | |
doc_state.current_doc_images[0].save(buffered, format="PNG") | |
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
# Create a data URI (many APIs accept this format in place of a public URL) | |
data_uri = f"data:image/png;base64,{img_b64}" | |
messages[0]["content"].append({ | |
"type": "image_url", | |
"image_url": {"url": data_uri} | |
}) | |
# Call the inference API with streaming enabled. | |
stream = client.chat.completions.create( | |
model="qwen/qwen2.5-vl-72b-instruct:free", | |
messages=messages, | |
max_tokens=max_new_tokens, | |
stream=True | |
) | |
buffer = "" | |
for chunk in stream: | |
# The response structure is similar to the reference: each chunk contains a delta. | |
delta = chunk.choices[0].delta.content | |
buffer += delta | |
time.sleep(0.01) | |
yield buffer | |
except Exception as e: | |
logger.error(f"Error in bot_streaming: {str(e)}") | |
yield "An error occurred while processing your request. Please try again." | |
def clear_context(): | |
"""Clear the current document context.""" | |
doc_state.clear() | |
return "Document context cleared. You can upload a new document." | |
# ------------------------------- | |
# Create the Gradio Interface | |
# ------------------------------- | |
with gr.Blocks() as demo: | |
gr.Markdown("# Document Analyzer with Predetermined Prompts") | |
gr.Markdown("Upload a PDF or image (PNG, JPG, JPEG, GIF, BMP, WEBP) and select a prompt to analyze its contents.") | |
with gr.Row(): | |
file_upload = gr.File( | |
label="Upload Document", | |
file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"] | |
) | |
upload_status = gr.Textbox(label="Upload Status", interactive=False) | |
with gr.Row(): | |
prompt_dropdown = gr.Dropdown( | |
label="Select Prompt", | |
choices=["NOC Timesheet", "NOC Basic", "Aramco Full structured", "Aramco Timesheet only", "NOC Invoice"], | |
value="NOC Timesheet" | |
) | |
generate_btn = gr.Button("Generate") | |
clear_btn = gr.Button("Clear Document Context") | |
output_text = gr.Textbox(label="Output", interactive=False) | |
file_upload.change(fn=process_uploaded_file, inputs=[file_upload], outputs=[upload_status]) | |
generate_btn.click(fn=bot_streaming, inputs=[prompt_dropdown], outputs=[output_text]) | |
clear_btn.click(fn=clear_context, outputs=[upload_status]) | |
demo.launch(debug=True) | |