File size: 5,639 Bytes
6459986
6c79114
d1dce8a
6c79114
d2aded5
6c79114
1d8d466
7e4f227
5f554b3
316d102
6018547
62c24e3
 
 
 
316d102
 
d2aded5
 
6c79114
62c24e3
 
6c64ea6
62c24e3
 
 
5f554b3
cdb1e78
d2aded5
62c24e3
 
 
cdb1e78
 
 
d2aded5
62c24e3
d2aded5
62c24e3
 
 
 
 
d2aded5
62c24e3
d2aded5
 
 
 
62c24e3
 
6c79114
d1dce8a
d2aded5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c79114
62c24e3
 
 
 
 
 
6c64ea6
70da8e5
62c24e3
6c64ea6
 
 
 
 
62c24e3
6c64ea6
d2aded5
62c24e3
6c64ea6
d2aded5
6c64ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
d2aded5
6c64ea6
d2aded5
 
6c64ea6
 
d2aded5
6c64ea6
d2aded5
 
 
 
 
 
6c64ea6
 
 
 
 
 
 
62c24e3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import re
import io
import streamlit as st
from PIL import Image, ImageDraw, ImageFont
from google import genai
from google.genai import types
from pdf2image import convert_from_bytes

# Constants
# Prompt sent together with each full-page image. It asks the model to reply
# with nothing but a Python-style list of [xmin, ymin, xmax, ymax] boxes whose
# coordinates are normalized to the 0-1 range; parse_list_boxes() consumes
# that reply.
DETECTION_PROMPT = """\
Identify ALL text regions in this document. Return bounding boxes as a Python list of lists 
in format [[xmin, ymin, xmax, ymax]] where coordinates are normalized between 0-1. 
Only return the list, nothing else. Example:
[[0.05, 0.12, 0.25, 0.18], [0.30, 0.40, 0.50, 0.55]]
"""

# Prompt sent together with each cropped region image to transcribe its text.
TEXT_EXTRACTION_PROMPT = "Extract the text in this image. Return only the exact text, nothing else."

# One non-negative decimal such as "0", "0.25", "1." or ".5".
_BOX_NUMBER = r"(\d+(?:\.\d*)?|\.\d+)"
# A bracketed group of exactly four such numbers: "[a, b, c, d]".
_BOX_PATTERN = re.compile(
    r"\[\s*" + r"\s*,\s*".join([_BOX_NUMBER] * 4) + r"\s*\]"
)


def _is_box(candidate):
    """Return True if *candidate* is a list/tuple of exactly four real numbers."""
    return (
        isinstance(candidate, (list, tuple))
        and len(candidate) == 4
        and all(
            # bool is an int subclass; True/False are not coordinates.
            isinstance(value, (int, float)) and not isinstance(value, bool)
            for value in candidate
        )
    )


def parse_list_boxes(text):
    """Parse a model reply into a list of ``[xmin, ymin, xmax, ymax]`` boxes.

    The reply is first parsed as a Python literal. If that fails, or the
    literal is not a sequence of four-number sequences, every
    ``[a, b, c, d]`` group of non-negative decimals found anywhere in the
    text (for example inside a Markdown code fence) is extracted instead.

    Args:
        text: Raw reply text. ``None``, non-string and empty values are
            treated as "no boxes".

    Returns:
        A list of boxes, each a list of four floats. Empty when nothing
        box-like is found. Values are not range-checked here.
    """
    import ast  # local import so this function does not depend on file-level imports

    if not isinstance(text, str) or not text.strip():
        return []

    # SECURITY: the reply comes from a remote model and is untrusted input.
    # It used to be passed to eval(), which would execute arbitrary code;
    # ast.literal_eval only accepts literals (numbers, lists, tuples, ...).
    try:
        parsed = ast.literal_eval(text.strip())
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, (list, tuple)) and all(_is_box(box) for box in parsed):
        return [[float(value) for value in box] for box in parsed]

    # Fallback: pull box-shaped groups out of free-form text. This also
    # rescues a reply that is a single flat box instead of a list of boxes.
    return [
        [float(value) for value in match]
        for match in _BOX_PATTERN.findall(text)
    ]

def draw_bounding_boxes(image, boxes):
    """Outline every box on *image* and tag it with its 1-based index.

    Args:
        image: Image to draw on; it is modified in place.
        boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` boxes normalized to
            0-1. Each coordinate is clamped to that range before scaling.

    Returns:
        The same *image* object. It is returned untouched when *boxes* is
        empty or ``None``. A box that fails to draw is reported through
        ``st.error`` and the remaining boxes are still processed.
    """
    if boxes:
        canvas = ImageDraw.Draw(image)
        img_w, img_h = image.size
        # Scale factor for each coordinate, in box order: x, y, x, y.
        scales = (img_w, img_h, img_w, img_h)

        for number, box in enumerate(boxes, start=1):
            try:
                # Clamp to [0, 1], then convert to pixel units.
                left, top, right, bottom = (
                    max(0.0, min(1.0, box[position])) * scale
                    for position, scale in enumerate(scales)
                )

                # Green outline around the region.
                canvas.rectangle(
                    [left, top, right, bottom], outline="#00FF00", width=3
                )

                # Red index just inside the top-left corner.
                canvas.text((left + 5, top + 5), str(number), fill="red")
            except Exception as exc:
                st.error(f"Error drawing box: {str(exc)}")

    return image

def extract_text_from_region(client, image, box, model="gemini-2.0-flash-exp"):
    """Crop one normalized box out of *image* and ask Gemini to transcribe it.

    Args:
        client: Object exposing ``models.generate_content`` (the
            ``genai.Client`` created by the UI code).
        image: Page image providing ``size``, ``crop`` and ``save``.
        box: ``[xmin, ymin, xmax, ymax]`` normalized to 0-1; each value is
            clamped to that range before being scaled to pixels.
        model: Name of the Gemini model to call. Defaults to the model this
            function previously hard-coded.

    Returns:
        The transcribed text with surrounding whitespace removed. ``""`` is
        returned when the box has no area after rounding to pixels, when
        the response carries no text, or when any error occurs (the error
        is reported through ``st.error``).
    """
    try:
        width, height = image.size
        # Convert normalized coordinates to integer pixel values.
        xmin = int(max(0.0, min(1.0, box[0])) * width)
        ymin = int(max(0.0, min(1.0, box[1])) * height)
        xmax = int(max(0.0, min(1.0, box[2])) * width)
        ymax = int(max(0.0, min(1.0, box[3])) * height)

        # Degenerate or inverted boxes have nothing to crop; skip the API call.
        if xmin >= xmax or ymin >= ymax:
            return ""

        # Crop the region and serialize it as PNG bytes for the request.
        cropped = image.crop((xmin, ymin, xmax, ymax))
        img_byte_arr = io.BytesIO()
        cropped.save(img_byte_arr, format='PNG')

        # Call Gemini API
        response = client.models.generate_content(
            model=model,
            contents=[
                TEXT_EXTRACTION_PROMPT,
                types.Part.from_bytes(
                    data=img_byte_arr.getvalue(),
                    mime_type="image/png"
                )
            ]
        )
        # response.text can be None when the reply has no text part; treat
        # that as "no text" instead of failing on None.strip().
        return (response.text or "").strip()
    except Exception as e:
        st.error(f"Text extraction error: {str(e)}")
        return ""

# Streamlit UI
# Flow: upload a PDF -> rasterize every page -> ask Gemini for text boxes ->
# transcribe each box -> show the original and annotated page side by side.


def _render_page(client, image):
    """Detect, transcribe and display the text regions of one page image."""
    col1, col2 = st.columns(2)

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favor of use_container_width - confirm against the installed version.
    with col1:
        st.image(image, caption="Original", use_column_width=True)

    with col2:
        # Get bounding boxes
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        response = client.models.generate_content(
            model="gemini-2.0-flash-exp",
            contents=[
                DETECTION_PROMPT,
                types.Part.from_bytes(
                    data=img_byte_arr.getvalue(),
                    mime_type="image/png"
                )
            ]
        )

        # response.text can be None when the reply has no text part;
        # normalize it once so parsing and the debug view both get a string.
        raw_text = response.text or ""

        boxes = parse_list_boxes(raw_text)
        texts = [extract_text_from_region(client, image, box) for box in boxes]

        # Draw on a copy so the "Original" image stays unannotated.
        annotated = draw_bounding_boxes(image.copy(), boxes)
        st.image(annotated,
                 caption=f"Detected {len(boxes)} text regions",
                 use_column_width=True)

        # Display extracted texts (skipped when every region came back empty).
        if any(texts):
            st.subheader("Extracted Texts:")
            for i, text in enumerate(texts, 1):
                st.write(f"{i}. {text if text else 'No text detected'}")

        # Debug section
        debug_expander = st.expander("Debug Details")
        with debug_expander:
            st.write("**Raw API Response:**")
            st.code(raw_text)
            st.write("**Parsed Boxes:**")
            st.write(boxes)


st.title("PDF Text Detection")
uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])

if uploaded_file and st.button("Analyze"):
    with st.spinner("Processing..."):
        try:
            images = convert_from_bytes(uploaded_file.read(), dpi=300)
            # The API key is read from the KEY environment variable.
            client = genai.Client(api_key=os.getenv("KEY"))

            tabs = st.tabs([f"Page {n}" for n in range(1, len(images) + 1)])

            for page_number, (tab, image) in enumerate(zip(tabs, images), start=1):
                with tab:
                    # Handle failures per page so one bad page does not
                    # leave every following tab empty.
                    try:
                        _render_page(client, image)
                    except Exception as e:
                        st.error(f"Error on page {page_number}: {str(e)}")

        except Exception as e:
            st.error(f"Error: {str(e)}")