import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import re
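
# Gradio demo: index an uploaded image with Byaldi (ColPali), extract its text
# with Qwen2-VL, and let the user keyword-search the extracted text.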

# Load models once at startup
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

# Move the model to the GPU once here rather than on every request;
# fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
qwen_model.to(device)
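
# A possible alternative if GPU memory is tight (assumes the `accelerate`
# package is installed): let transformers place the model automatically with
#   Qwen2VLForConditionalGeneration.from_pretrained(
#       "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
#   )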

extracted_text = ""  # Store the extracted text globally for keyword search

def ocr_and_extract(image, text_query="Extract all of the text visible in this image."):
    global extracted_text
    if image is None:
        return "Please upload an image first."
    temp_image_path = "temp_image.jpg"
    try:
        # Save the uploaded image temporarily so Byaldi can index it from disk
        image.save(temp_image_path)

        # Index the image with Byaldi
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",
            store_collection_with_index=False,
            overwrite=True
        )

        # Query the index; with a single indexed image the top hit is that image,
        # and the result is not used further (the image goes to Qwen2-VL directly)
        results = rag_model.search(text_query, k=1)

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)

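        # Build a chat-style message that pairs the uploaded image with the query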
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Process input for Qwen2-VL
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )

        qwen_model.to("cuda")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate with Qwen2-VL, then trim the prompt tokens so that only the
        # newly generated text is decoded
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)
        trimmed_ids = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
        output_text = processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # Store the extracted text for keyword search
        extracted_text = output_text[0].strip()
        return extracted_text

    except Exception as e:
        traceback.print_exc()
        return f"Error: {e}"
    finally:
        # Remove the temporary file whether or not extraction succeeded
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)

def search_keywords(keyword):
    global extracted_text
    if not extracted_text:
        return "No text extracted yet. Please upload an image."
    if not keyword:
        return "Please enter a keyword to search for."

    # Whole-word, case-insensitive search within the extracted text
    pattern = rf"\b{re.escape(keyword)}\b"
    if re.search(pattern, extracted_text, re.IGNORECASE):
        highlighted_text = re.sub(f"({pattern})", r"<mark>\1</mark>", extracted_text, flags=re.IGNORECASE)
        return f"Keyword found! {highlighted_text}"
    return "Keyword not found in the extracted text."

# Gradio components
image_input = gr.Image(type="pil", label="Upload Image")
keyword_search = gr.Textbox(label="Enter keywords to search")
text_output = gr.Textbox(label="Extracted Text", interactive=True)
search_output = gr.HTML()

# Run OCR first so the keyword search sees the freshly extracted text
def combined_interface(image, keyword):
    ocr_text = ocr_and_extract(image)
    search_result = search_keywords(keyword)
    return ocr_text, search_result

combined_iface = gr.Interface(
    fn=combined_interface,
    inputs=[image_input, keyword_search],
    outputs=[text_output, search_output],
    title="Image OCR & Keyword Search",
    description=(
        "Upload an image containing Hindi and English text for OCR with "
        "Byaldi + Qwen2-VL, then search the extracted text for specific keywords."
    ),
)

# Launch the app
combined_iface.launch()
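
# If inference is slow, Gradio's request queue can help under concurrent load:
#   combined_iface.queue().launch()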