Spaces:
Running
Running
import ast
import io
import os
import re

import streamlit as st
from google import genai
from google.genai import types
from pdf2image import convert_from_bytes
from PIL import Image, ImageDraw, ImageFont
# Prompt sent together with a full page image. It asks the model to answer with
# nothing but a Python list of [xmin, ymin, xmax, ymax] boxes normalized to 0-1.
# The leading backslash suppresses the newline right after the opening quotes.
DETECTION_PROMPT = """\
Analyze this document image and identify text regions following these rules:
1. GROUP RELATED CONTENT:
- Full tables as SINGLE regions (including headers and all rows)
- Paragraphs as SINGLE rectangular blocks (multiple lines as one box)
- Keep text columns intact
- Treat list items as single region if visually grouped
2. TEXT REGION REQUIREMENTS:
- Boundaries must tightly wrap text content
- Include 2% padding around text clusters
- Exclude isolated decorative elements
- Merge adjacent text fragments with ≤1% spacing
3. COORDINATE FORMAT:
Python list of lists [[xmin, ymin, xmax, ymax]]
- Normalized 0-1 with 3 decimal places
- Ordered top-to-bottom, left-to-right
- Table example: [[0.12, 0.35, 0.88, 0.65]] for full table
4. SPECIAL CASES:
- Table cells should NOT have individual boxes
- Page headers/footers as separate regions
- Text wrapped around images as distinct regions
Example response for table + 2 paragraphs:
[[0.07, 0.12, 0.93, 0.28], # Header
[0.12, 0.35, 0.88, 0.65], # Full table
[0.10, 0.70, 0.90, 0.85], # First paragraph
[0.10, 0.88, 0.90, 0.95]] # Second paragraph
ONLY RETURN THE PYTHON LIST! No explanations.
"""

# Prompt sent together with a single cropped region to transcribe its text.
TEXT_EXTRACTION_PROMPT = "Extract the text in this image. Return only the exact text, nothing else."
def _as_box(candidate):
    """Return *candidate* as a list of four floats, or None if it is not a box."""
    if not isinstance(candidate, (list, tuple)) or len(candidate) != 4:
        return None
    # bool is a subclass of int; it is never a meaningful coordinate.
    if not all(
        isinstance(v, (int, float)) and not isinstance(v, bool) for v in candidate
    ):
        return None
    return [float(v) for v in candidate]


def parse_list_boxes(text):
    """Parse a model reply into a list of ``[xmin, ymin, xmax, ymax]`` boxes.

    The reply is first parsed as a Python literal (never executed). If that
    fails, or yields no usable box, every bracketed group of four numbers
    found anywhere in the text is used instead.

    Args:
        text: Raw reply text; may be ``None`` or empty.

    Returns:
        A list of boxes, each a list of four floats. Entries that are not
        four-number sequences are dropped. Returns ``[]`` when nothing
        parseable is found.
    """
    if not isinstance(text, str) or not text.strip():
        return []

    # Replies are often wrapped in a Markdown code fence despite the prompt.
    cleaned = re.sub(r"```[A-Za-z]*", "", text).strip()

    # SECURITY: the reply is untrusted input, so it is parsed with
    # ast.literal_eval (literals only) rather than eval().
    try:
        parsed = ast.literal_eval(cleaned)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (list, tuple)):
        # A bare [xmin, ymin, xmax, ymax] reply is treated as a single box.
        single = _as_box(parsed)
        if single is not None:
            return [single]
        boxes = [box for box in map(_as_box, parsed) if box is not None]
        if boxes:
            return boxes

    # Fallback: pull out every "[n, n, n, n]" group from free-form text.
    number = r"(\d+(?:\.\d*)?|\.\d+)"
    pattern = r"\[\s*" + r"\s*,\s*".join([number] * 4) + r"\s*\]"
    return [[float(v) for v in match] for match in re.findall(pattern, text)]
def draw_bounding_boxes(image, boxes):
    """Draw numbered rectangles for *boxes* onto *image* and return it.

    The image is modified in place; pass a copy to keep the original intact.

    Args:
        image: PIL image to draw on.
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` in normalized 0-1
            coordinates. Values outside 0-1 are clamped.

    Returns:
        The same image object. It is returned unchanged when *boxes* is
        empty or ``None``.
    """
    if not boxes:
        return image

    draw = ImageDraw.Draw(image)
    width, height = image.size

    for number, box in enumerate(boxes, start=1):
        try:
            # Clamp to the unit square and scale to pixels. Sorting each axis
            # keeps a box given with swapped corners drawable: recent Pillow
            # releases reject rectangles whose x1 < x0 or y1 < y0.
            xmin, xmax = sorted(
                max(0.0, min(1.0, box[i])) * width for i in (0, 2)
            )
            ymin, ymax = sorted(
                max(0.0, min(1.0, box[i])) * height for i in (1, 3)
            )

            draw.rectangle([xmin, ymin, xmax, ymax], outline="#00FF00", width=3)

            # 1-based label just inside the top-left corner of the box.
            draw.text((xmin + 5, ymin + 5), str(number), fill="red")
        except Exception as e:
            # Best effort: one malformed box must not stop the rest being drawn.
            st.error(f"Error drawing box: {str(e)}")

    return image
def extract_text_from_region(client, image, box, model="gemini-2.5-pro-exp-03-25"):
    """Crop one region out of *image* and ask Gemini to transcribe it.

    Args:
        client: ``genai.Client`` used to issue the request.
        image: PIL image of the full page.
        box: ``[xmin, ymin, xmax, ymax]`` in normalized 0-1 coordinates.
            Values outside 0-1 are clamped.
        model: Name of the model used for transcription.

    Returns:
        The transcribed text with surrounding whitespace removed. Returns
        ``""`` when the region has zero area, when the model returns no
        text, or when any error occurs (the error is shown via ``st.error``).
    """
    try:
        width, height = image.size

        # Clamp to the unit square and scale to whole pixels. Sorting each
        # axis lets a box given with swapped corners map to a valid crop.
        xmin, xmax = sorted(
            int(max(0.0, min(1.0, box[i])) * width) for i in (0, 2)
        )
        ymin, ymax = sorted(
            int(max(0.0, min(1.0, box[i])) * height) for i in (1, 3)
        )
        if xmin == xmax or ymin == ymax:
            # Zero-area region: nothing to send to the model.
            return ""

        # Encode the crop as PNG bytes in memory for the API call.
        cropped = image.crop((xmin, ymin, xmax, ymax))
        buffer = io.BytesIO()
        cropped.save(buffer, format="PNG")

        response = client.models.generate_content(
            model=model,
            contents=[
                TEXT_EXTRACTION_PROMPT,
                types.Part.from_bytes(
                    data=buffer.getvalue(),
                    mime_type="image/png",
                ),
            ],
        )
        # response.text is None when the reply carries no text part.
        return (response.text or "").strip()
    except Exception as e:
        # UI boundary: report the failure and keep processing other regions.
        st.error(f"Text extraction error: {str(e)}")
        return ""
# Streamlit UI: upload a PDF, detect text regions on every page, transcribe
# each region, and show the original and annotated pages side by side.
st.title("PDF Text Detection")
uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])

if uploaded_file and st.button("Analyze"):
    with st.spinner("Processing..."):
        try:
            # One PIL image per PDF page, rendered at 300 dpi.
            images = convert_from_bytes(uploaded_file.read(), dpi=300)
            client = genai.Client(api_key=os.getenv("KEY"))
            tabs = st.tabs([f"Page {n}" for n in range(1, len(images) + 1)])

            for tab, image in zip(tabs, images):
                with tab:
                    col1, col2 = st.columns(2)

                    with col1:
                        st.image(image, caption="Original", use_container_width=True)

                    with col2:
                        # Get bounding boxes
                        img_byte_arr = io.BytesIO()
                        image.save(img_byte_arr, format="PNG")
                        response = client.models.generate_content(
                            model="gemini-2.0-flash-exp",
                            contents=[
                                DETECTION_PROMPT,
                                types.Part.from_bytes(
                                    data=img_byte_arr.getvalue(),
                                    mime_type="image/png",
                                ),
                            ],
                        )
                        # response.text is None when the reply has no text
                        # part; use "" so parsing and the debug view still work.
                        raw_response = response.text or ""
                        boxes = parse_list_boxes(raw_response)
                        texts = [
                            extract_text_from_region(client, image, box)
                            for box in boxes
                        ]

                        # Draw on a copy so the "Original" image stays clean.
                        annotated = draw_bounding_boxes(image.copy(), boxes)
                        st.image(
                            annotated,
                            caption=f"Detected {len(boxes)} text regions",
                            use_container_width=True,
                        )

                        # Display extracted texts, numbered like the drawn boxes.
                        if any(texts):
                            st.subheader("Extracted Texts:")
                            for i, text in enumerate(texts, 1):
                                st.write(f"{i}. {text if text else 'No text detected'}")

                    # Debug section
                    with st.expander("Debug Details"):
                        st.write("**Raw API Response:**")
                        st.code(raw_response)
                        st.write("**Parsed Boxes:**")
                        st.write(boxes)
        except Exception as e:
            # Top-level UI boundary: surface the failure instead of crashing.
            st.error(f"Error: {str(e)}")