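# NOTE(review): the original first lines here ("Spaces:", "Sleeping", "Sleeping")
# appear to be web-page header text captured along with this source; they are
# kept as a comment so the module can be parsed.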
import ast
import io
import os
import re

import streamlit as st
from google import genai
from google.genai import types
from pdf2image import convert_from_bytes
from PIL import Image, ImageDraw, ImageFont
# Prompt templates sent to Gemini.

# Detection prompt: asks for every text region as normalized
# [xmin, ymin, xmax, ymax] boxes, formatted as a bare Python list literal so the
# reply can be parsed directly.
DETECTION_PROMPT = (
    "Identify ALL text regions in this document. Return bounding boxes as a Python list of lists\n"
    "in format [[xmin, ymin, xmax, ymax]] where coordinates are normalized between 0-1.\n"
    "Only return the list, nothing else. Example:\n"
    "[[0.05, 0.12, 0.25, 0.18], [0.30, 0.40, 0.50, 0.55]]\n"
)

# Transcription prompt: sent together with a single cropped region image.
TEXT_EXTRACTION_PROMPT = (
    "Extract the text in this image. Return only the exact text, nothing else."
)
def parse_list_boxes(text):
    """Parse the model's bounding-box reply into ``[xmin, ymin, xmax, ymax]`` lists.

    The reply is first read with ``ast.literal_eval``, which accepts only
    Python literals. ``eval`` is deliberately not used: the reply comes from a
    model whose output can be steered by text inside the uploaded document, so
    evaluating it would allow arbitrary code execution. If literal parsing
    fails or yields no usable boxes (for example when the list is wrapped in
    Markdown code fences or surrounded by prose), every bracketed group of
    four numbers is extracted with a regular expression instead.

    Args:
        text: Raw response text. ``None`` or an empty string yields ``[]``.

    Returns:
        A list of four-float lists. Entries that are not exactly four numbers
        are dropped, so the result may be empty.
    """
    if not text:
        return []

    def is_box(candidate):
        # Exactly four real numbers; bool is excluded even though it subclasses int.
        return (
            isinstance(candidate, (list, tuple))
            and len(candidate) == 4
            and all(
                isinstance(value, (int, float)) and not isinstance(value, bool)
                for value in candidate
            )
        )

    try:
        parsed = ast.literal_eval(text.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None

    # Accept a single flat box as well as a list of boxes.
    if is_box(parsed):
        parsed = [parsed]
    if isinstance(parsed, (list, tuple)):
        boxes = [[float(value) for value in box] for box in parsed if is_box(box)]
        # An explicitly empty list is a valid "no regions" answer; otherwise
        # fall through to the regex when nothing usable was found.
        if boxes or not parsed:
            return boxes

    # Signed decimal with optional exponent, e.g. "0", "0.5", ".5", "-1e-3".
    number = r"\s*(-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)\s*"
    pattern = r"\[" + ",".join([number] * 4) + r"\]"
    return [[float(value) for value in match] for match in re.findall(pattern, text)]
def draw_bounding_boxes(image, boxes, *, outline="#00FF00", line_width=3,
                        label_fill="red", label_font=None):
    """Draw a numbered outline for each normalized box onto ``image``.

    ``image`` is modified in place (pass a copy to keep the original) and is
    also returned. A box that cannot be drawn is reported with ``st.error`` and
    skipped; the remaining boxes are still drawn.

    Args:
        image: PIL image to draw on.
        boxes: Sequence of ``[xmin, ymin, xmax, ymax]`` boxes normalized to
            0-1. Values outside that range are clamped, and swapped corners
            are reordered.
        outline: Rectangle outline colour.
        line_width: Rectangle outline width in pixels.
        label_fill: Colour of the 1-based index label drawn inside each box.
        label_font: Optional PIL font for the label; ``None`` uses Pillow's
            default font.

    Returns:
        The same ``image`` object.
    """
    if not boxes:
        return image

    draw = ImageDraw.Draw(image)
    width, height = image.size

    def clamp(value):
        return max(0.0, min(1.0, value))

    for number, box in enumerate(boxes, start=1):
        try:
            # Sort each axis: Pillow's rectangle() rejects x1 < x0 / y1 < y0,
            # which would otherwise drop a box whose corners came back swapped.
            x0, x1 = sorted((clamp(box[0]) * width, clamp(box[2]) * width))
            y0, y1 = sorted((clamp(box[1]) * height, clamp(box[3]) * height))
            draw.rectangle([x0, y0, x1, y1], outline=outline, width=line_width)
            # Label matches the 1-based numbering of the extracted-text list.
            draw.text((x0 + 5, y0 + 5), str(number), fill=label_fill, font=label_font)
        except Exception as e:
            st.error(f"Error drawing box {number}: {str(e)}")
    return image
def extract_text_from_region(client, image, box, *, model="gemini-2.0-flash-exp"):
    """Crop one normalized box out of ``image`` and transcribe it with Gemini.

    Args:
        client: Gemini client whose ``models.generate_content`` is called.
        image: PIL image the box coordinates refer to.
        box: ``[xmin, ymin, xmax, ymax]`` normalized to 0-1. Values outside
            that range are clamped, and swapped corners are reordered.
        model: Gemini model name to query.

    Returns:
        The stripped transcription. Returns ``""`` in any of these cases:
        the box has zero area after conversion to pixels, the reply contains
        no text, or any error occurs. Errors are also reported via ``st.error``.
    """
    try:
        width, height = image.size

        def clamp(value):
            return max(0.0, min(1.0, value))

        # Sort each axis so a box with swapped corners is still cropped
        # instead of being discarded as empty.
        xmin, xmax = sorted((int(clamp(box[0]) * width), int(clamp(box[2]) * width)))
        ymin, ymax = sorted((int(clamp(box[1]) * height), int(clamp(box[3]) * height)))
        if xmin >= xmax or ymin >= ymax:
            return ""

        buffer = io.BytesIO()
        image.crop((xmin, ymin, xmax, ymax)).save(buffer, format="PNG")

        response = client.models.generate_content(
            model=model,
            contents=[
                TEXT_EXTRACTION_PROMPT,
                types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png"),
            ],
        )
        # ``response.text`` can be None when the reply has no text part
        # (e.g. blocked or empty). Treat that as "no text" rather than
        # letting .strip() raise and surface a spurious error.
        return (response.text or "").strip()
    except Exception as e:
        st.error(f"Text extraction error: {str(e)}")
        return ""
# Streamlit UI


def _png_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _detect_boxes(client, image):
    """Ask Gemini for the text regions on one page image.

    Returns:
        ``(boxes, raw_text)``: the parsed ``[xmin, ymin, xmax, ymax]`` boxes
        and the unparsed reply text, which is shown in the debug panel.
        ``raw_text`` is ``""`` when the response has no text.
    """
    response = client.models.generate_content(
        model="gemini-2.0-flash-exp",
        contents=[
            DETECTION_PROMPT,
            types.Part.from_bytes(data=_png_bytes(image), mime_type="image/png"),
        ],
    )
    # ``response.text`` can be None (no text part in the reply); normalize to
    # "" so the parser and the debug output never receive None.
    raw_text = response.text or ""
    return parse_list_boxes(raw_text), raw_text


def _render_page(client, image):
    """Show one page: the original, the annotated detections, extracted texts and debug info."""
    col1, col2 = st.columns(2)
    with col1:
        # NOTE(review): newer Streamlit releases deprecate ``use_column_width``
        # in favour of ``use_container_width``; kept here for older versions.
        st.image(image, caption="Original", use_column_width=True)
    with col2:
        boxes, raw_text = _detect_boxes(client, image)
        # One Gemini call per detected region, made sequentially.
        texts = [extract_text_from_region(client, image, box) for box in boxes]

        # Draw on a copy so the original page shown in col1 is left untouched.
        annotated = draw_bounding_boxes(image.copy(), boxes)
        st.image(
            annotated,
            caption=f"Detected {len(boxes)} text regions",
            use_column_width=True,
        )

        if any(texts):
            st.subheader("Extracted Texts:")
            for i, text in enumerate(texts, 1):
                st.write(f"{i}. {text if text else 'No text detected'}")

        with st.expander("Debug Details"):
            st.write("**Raw API Response:**")
            st.code(raw_text)
            st.write("**Parsed Boxes:**")
            st.write(boxes)


st.title("PDF Text Detection")
uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
if uploaded_file and st.button("Analyze"):
    with st.spinner("Processing..."):
        try:
            # getvalue() returns the whole upload regardless of the current
            # stream position, unlike read().
            images = convert_from_bytes(uploaded_file.getvalue(), dpi=300)
            client = genai.Client(api_key=os.getenv("KEY"))
        except Exception as e:
            st.error(f"Error: {str(e)}")
        else:
            if not images:
                # st.tabs() needs at least one label.
                st.warning("No pages could be rendered from this PDF.")
            else:
                tabs = st.tabs([f"Page {i + 1}" for i in range(len(images))])
                for page_number, (tab, image) in enumerate(zip(tabs, images), start=1):
                    with tab:
                        # Isolate failures so one bad page does not hide the others.
                        try:
                            _render_page(client, image)
                        except Exception as e:
                            st.error(f"Error on page {page_number}: {str(e)}")