import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import re
# Load models
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
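
# Sketch: move the model to the GPU once at load time rather than on every
# request (assumes CUDA may be unavailable, e.g. on a CPU-only Space).
device = "cuda" if torch.cuda.is_available() else "cpu"
qwen_model.to(device)
qwen_model.eval()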
extracted_text = "" # Store the extracted text globally for keyword search

def ocr_and_extract(image, text_query=None):
    global extracted_text
    if image is None:
        return "Please upload an image first."
    # Fall back to a generic OCR prompt when no query is provided
    if not text_query:
        text_query = "Extract all the text in the image."
    try:
        # Save the uploaded image temporarily
        temp_image_path = "temp_image.jpg"
        image.save(temp_image_path)

        # Index the image with Byaldi so it can be retrieved by the query
        rag_model.index(
            input_path=temp_image_path,
            index_name="image_index",
            store_collection_with_index=False,
            overwrite=True,
        )

        # Retrieve the most relevant page for the query; with a single
        # indexed image this always returns that image
        results = rag_model.search(text_query, k=1)

        # Prepare the chat-style input for Qwen2-VL
        image_data = Image.open(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Build the prompt text and pixel inputs for Qwen2-VL
        text_input = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, _ = process_vision_info(messages)
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(qwen_model.device) for k, v in inputs.items()}

        # Generate the output; raise max_new_tokens for longer documents
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)
        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Store the extracted text for keyword search
        extracted_text = output_text[0]
        os.remove(temp_image_path)
        return extracted_text
    except Exception as e:
        traceback.print_exc()
        return f"Error: {str(e)}"

def search_keywords(keyword):
    global extracted_text
    if not extracted_text:
        return "No text extracted yet. Please upload an image."
    if not keyword:
        return "Please enter a keyword to search."
    # Case-insensitive whole-word search within the extracted text
    if re.search(rf"\b{re.escape(keyword)}\b", extracted_text, re.IGNORECASE):
        highlighted_text = re.sub(
            rf"({re.escape(keyword)})", r"<mark>\1</mark>", extracted_text, flags=re.IGNORECASE
        )
        return f"Keyword found! {highlighted_text}"
    else:
        return "Keyword not found in the extracted text."

# Gradio components (a component instance can only be attached to one
# interface, so everything is wired into the single combined app below)
image_input = gr.Image(type="pil")
text_output = gr.Textbox(label="Extracted Text", interactive=True)
keyword_search = gr.Textbox(label="Enter keywords to search")
search_output = gr.HTML()

def combined_interface(image, keyword):
    # Run OCR first so the keyword search sees the freshly extracted text
    ocr_text = ocr_and_extract(image)
    search_result = search_keywords(keyword)
    return ocr_text, search_result

# Runs on submit (not live) so the GPU OCR pipeline is not re-triggered
# on every keystroke in the keyword box
combined_iface = gr.Interface(
    fn=combined_interface,
    inputs=[image_input, keyword_search],
    outputs=[text_output, search_output],
    title="Image OCR & Keyword Search",
    description="Upload an image containing Hindi and English text for OCR with Byaldi + Qwen2-VL, then search the extracted text for specific keywords.",
)

# Launch the app
combined_iface.launch()
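
# A minimal sketch of the environment this script assumes (package names taken
# from the imports above; versions are not pinned here):
#   pip install gradio byaldi transformers qwen-vl-utils torch pillow
# Then run locally with:
#   python app.py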