import os
import io
import base64
import logging

import fitz  # PyMuPDF
from PIL import Image
import gradio as gr
from openai import OpenAI  # OpenAI-compatible client that supports multimodal messages

# Load the API key from the environment
API_KEY = os.getenv("OPENAI_TOKEN")
if not API_KEY:
    raise ValueError("OPENAI_TOKEN environment variable not set")

# Create the client pointing to the inference endpoint (e.g., OpenRouter)
client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=API_KEY
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
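# Optional connectivity check, a minimal sketch (assumes the free Qwen-VL
# route used below is enabled for your OpenRouter account):
#   resp = client.chat.completions.create(
#       model="qwen/qwen-vl-plus:free",
#       messages=[{"role": "user", "content": "ping"}],
#       max_tokens=5,
#   )
#   logger.info(resp.choices[0].message.content)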
# -------------------------------
# Document State and File Processing
# -------------------------------
class DocumentState:
    def __init__(self):
        self.current_doc_images = []
        self.current_doc_text = ""
        self.doc_type = None

    def clear(self):
        self.current_doc_images = []
        self.current_doc_text = ""
        self.doc_type = None

doc_state = DocumentState()
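# Note: doc_state is module-level, so it is shared by all Gradio callbacks
# (and, on a shared deployment, by all concurrent sessions). Illustrative
# flow, with a hypothetical file name:
#   process_uploaded_file({"name": "sample_invoice.pdf"})  # fills doc_state
#   doc_state.current_doc_images[0]                        # first page image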
def process_pdf_file(file_path):
    """Convert PDF pages to images and extract text using PyMuPDF."""
    try:
        doc = fitz.open(file_path)
        images = []
        text = ""
        for page_num in range(doc.page_count):
            try:
                page = doc[page_num]
                page_text = page.get_text("text")
                if page_text.strip():
                    text += f"Page {page_num + 1}:\n{page_text}\n\n"
                # Render the page as an image with a zoom factor
                zoom = 3
                mat = fitz.Matrix(zoom, zoom)
                pix = page.get_pixmap(matrix=mat, alpha=False)
                img_data = pix.tobytes("png")
                img = Image.open(io.BytesIO(img_data)).convert("RGB")
                # Downscale if the image is too large
                max_size = 1600
                if max(img.size) > max_size:
                    ratio = max_size / max(img.size)
                    new_size = tuple(int(dim * ratio) for dim in img.size)
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
                images.append(img)
            except Exception as e:
                logger.error(f"Error processing page {page_num}: {str(e)}")
                continue
        doc.close()
        if not images:
            raise ValueError("No valid images could be extracted from the PDF")
        return images, text
    except Exception as e:
        logger.error(f"Error processing PDF file: {str(e)}")
        raise
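# Standalone usage sketch (assumes a "sample.pdf" exists next to this script):
#   pages, extracted = process_pdf_file("sample.pdf")
#   print(f"{len(pages)} page image(s), {len(extracted)} characters of text")
# At zoom=3 each page renders at three times its nominal 72-dpi size (~216 dpi)
# before the 1600-px downscale cap is applied.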
def process_uploaded_file(file):
    """Process an uploaded file (PDF or image) and update the document state."""
    try:
        doc_state.clear()
        if file is None:
            return "No file uploaded. Please upload a file."
        # Gradio may pass a dict, a plain file path string, or a file-like object
        if isinstance(file, dict):
            file_path = file["name"]
        elif isinstance(file, str):
            file_path = file
        else:
            file_path = file.name
        file_ext = file_path.lower().split('.')[-1]
        image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
        if file_ext == 'pdf':
            doc_state.doc_type = 'pdf'
            try:
                doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path)
                return f"PDF processed successfully. Total pages: {len(doc_state.current_doc_images)}. You can now chat with the bot."
            except Exception as e:
                return f"Error processing PDF: {str(e)}. Please try a different PDF file."
        elif file_ext in image_extensions:
            doc_state.doc_type = 'image'
            try:
                img = Image.open(file_path).convert("RGB")
                max_size = 1600
                if max(img.size) > max_size:
                    ratio = max_size / max(img.size)
                    new_size = tuple(int(dim * ratio) for dim in img.size)
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
                doc_state.current_doc_images = [img]
                return "Image loaded successfully. You can now chat with the bot."
            except Exception as e:
                return f"Error processing image: {str(e)}. Please try a different image file."
        else:
            return f"Unsupported file type: {file_ext}. Please upload a PDF or image file (PNG, JPG, JPEG, GIF, BMP, WEBP)."
    except Exception as e:
        logger.error(f"Error in process_uploaded_file: {str(e)}")
        return "An error occurred while processing the file. Please try again."
def clear_context():
    """Clear the current document context, the chat display, and the chat history."""
    doc_state.clear()
    return "Document context cleared. You can upload a new document.", [], []
# -------------------------------
# Predetermined Prompts
# -------------------------------
predetermined_prompts = {
    "NOC Timesheet": (
        "Extract structured information from the provided timesheet. The extracted details should include:\n"
        "Name, Position Title, Work Location, Contractor, NOC ID, Month and Year, Regular Service Days, "
        "Standby Days, Offshore Days, Extended Hitch Days, and approvals. Format the output as valid JSON."
    ),
    "Aramco Full structured": (
        "You are a document parsing assistant designed to extract structured data from various documents such as "
        "invoices, timesheets, purchase orders, and travel bookings. Return only valid JSON with no extra text."
    ),
    "Aramco Timesheet only": (
        "Extract time tracking, work details, and approvals. Return a JSON object following the specified structure."
    ),
    "NOC Invoice": (
        "You are a highly accurate data extraction system. Analyze the provided invoice image and extract all data "
        "into the following JSON format:\n"
        "{\n 'invoiceDetails': { ... },\n 'from': { ... },\n 'to': { ... },\n 'services': [ ... ],\n "
        "'totals': { ... },\n 'bankDetails': { ... }\n}"
    ),
    "Software Tester": (
        "Act as a software tester. Analyze the uploaded image of a software interface and generate comprehensive "
        "test cases for its features. For each feature, provide test steps, expected results, and any necessary "
        "preconditions. Be as detailed as possible."
    )
}
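# On the first turn, chat_respond (below) prepends the selected prompt to any
# text the user typed, e.g. (illustrative):
#   predetermined_prompts["Software Tester"] + "\n" + "Focus on the login form."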
# -------------------------------
# Chat Function (Non-streaming Version)
# -------------------------------
def history_to_messages(history):
    """Convert [user_text, assistant_text] pairs into the role/content
    dictionaries that gr.Chatbot(type='messages') expects for display."""
    display = []
    for user_msg, assistant_msg in history:
        display.append({"role": "user", "content": user_msg})
        if assistant_msg:
            display.append({"role": "assistant", "content": assistant_msg})
    return display

def chat_respond(user_message, history, prompt_option):
    """
    Append the user message to the conversation history, call the API, and
    return the full response.

    Each message passed to the API is a dictionary with a string value for
    'content'. If an image was uploaded, its data URI is appended to the
    first user message. The conversation history is a list of
    [user_text, assistant_text] pairs.
    """
    # On the first message, fall back to (or prepend) the predetermined prompt.
    if history == []:
        if not user_message.strip():
            user_message = predetermined_prompts.get(prompt_option, "Hello")
        else:
            user_message = predetermined_prompts.get(prompt_option, "") + "\n" + user_message
    history = history + [[user_message, ""]]

    # Build the messages list, each message being a dictionary with a role and
    # a string content.
    messages = []
    for i, (user_msg, assistant_msg) in enumerate(history):
        # For the very first user message, attach the image (if available) by
        # appending its data URI.
        if i == 0 and doc_state.current_doc_images:
            buffered = io.BytesIO()
            doc_state.current_doc_images[0].save(buffered, format="PNG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            data_uri = f"data:image/png;base64,{img_b64}"
            text_to_send = user_msg + "\n[Attached Image: " + data_uri + "]"
        else:
            text_to_send = user_msg
        messages.append({"role": "user", "content": text_to_send})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    try:
        # Call the API without streaming.
        response = client.chat.completions.create(
            model="qwen/qwen-vl-plus:free",
            messages=messages,
            max_tokens=500
        )
    except Exception as e:
        logger.error(f"Error calling the API: {str(e)}")
        history[-1][1] = "An error occurred while processing your request. Please check your API credentials."
        return history_to_messages(history), history

    # Extract the assistant's reply. With the OpenAI v1 client the message is
    # an object, so the content is read as an attribute, not a dict key.
    try:
        full_response = response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error extracting API response: {str(e)}")
        full_response = "An error occurred while processing the API response."

    history[-1][1] = full_response
    return history_to_messages(history), history
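# Note: the image travels inside the text as a data URI above, which relies on
# the endpoint tolerating very long text content. Vision models behind
# OpenAI-compatible APIs generally also accept the structured content format;
# a minimal alternative sketch for the first user message (assuming the
# OpenRouter route accepts image_url content parts for this model):
#   messages.append({
#       "role": "user",
#       "content": [
#           {"type": "text", "text": user_msg},
#           {"type": "image_url", "image_url": {"url": data_uri}},
#       ],
#   })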
# -------------------------------
# Create the Gradio Interface
# -------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Document Analyzer & Software Testing Chatbot")
    gr.Markdown(
        "Upload a PDF or an image (PNG, JPG, JPEG, GIF, BMP, WEBP), then choose a prompt from the dropdown. "
        "For example, select **Software Tester** to have the bot analyze an image of a software interface "
        "and generate test cases. You can also chat with the model; the conversation history is preserved."
    )
    with gr.Row():
        file_upload = gr.File(
            label="Upload Document",
            file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        )
        upload_status = gr.Textbox(label="Upload Status", interactive=False)
    with gr.Row():
        prompt_dropdown = gr.Dropdown(
            label="Select Prompt",
            choices=[
                "NOC Timesheet",
                "Aramco Full structured",
                "Aramco Timesheet only",
                "NOC Invoice",
                "Software Tester"
            ],
            value="Software Tester"
        )
        clear_btn = gr.Button("Clear Document Context & Chat History")
    # type='messages' avoids the deprecation warning for the tuple format.
    chatbot = gr.Chatbot(label="Chat History", type="messages", elem_id="chatbot")
    with gr.Row():
        user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", show_label=False)
        send_btn = gr.Button("Send")
    # State holding the conversation history as [user, assistant] pairs
    chat_state = gr.State([])

    # When a file is uploaded, process it.
    file_upload.change(fn=process_uploaded_file, inputs=file_upload, outputs=upload_status)
    # Clear the document context, the chat display, and the chat history.
    clear_btn.click(fn=clear_context, outputs=[upload_status, chatbot, chat_state])
    # When the user clicks Send, process the message and update the chat.
    send_btn.click(
        fn=chat_respond,
        inputs=[user_input, chat_state, prompt_dropdown],
        outputs=[chatbot, chat_state]
    )

demo.launch(debug=True)
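# When run locally with `python app.py`, Gradio serves the app on
# http://127.0.0.1:7860 by default; on Hugging Face Spaces the script is
# executed as-is and the port is managed by the platform.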