import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
import time
# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the Byaldi and Qwen2-VL models
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali") # Do not move Byaldi to GPU
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to GPU
# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)
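# Note: Qwen2-VL-7B-Instruct in bfloat16 needs roughly 15-16 GB of GPU memory for
# the weights alone (an estimate, not a measured figure); if memory is tight, a
# smaller checkpoint such as Qwen/Qwen2-VL-2B-Instruct can be swapped in above.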
@spaces.GPU # Decorate the function for GPU management
def ocr_and_extract(image, text_query):
    temp_image_path = "temp_image.jpg"
    try:
        # Save the uploaded image temporarily
        image.save(temp_image_path)

        # Generate a unique index name using the current timestamp
        unique_index_name = f"image_index_{int(time.time())}"

        # Index the image with Byaldi
        rag_model.index(
            input_path=temp_image_path,
            index_name=unique_index_name,  # Use the unique index name
            store_collection_with_index=False,
            overwrite=True,  # Ensure the index is overwritten if it already exists
        )

        # Run the search query against the indexed image. With a single page the
        # result is not used further; the image itself goes straight to Qwen2-VL.
        rag_model.search(text_query, k=1)

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Build the chat prompt and extract the vision inputs
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        # Move the tokenized text and image tensors to the device
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)
        output_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Drop echoed prompt lines ("You are a helpful assistant", role labels, etc.)
        filtered_output = [
            line
            for line in output_text[0].split("\n")
            if not any(kw in line.lower() for kw in ["you are a helpful assistant", "assistant", "user", "system"])
        ]

        return "\n".join(filtered_output).strip()

    except Exception as e:
        traceback.print_exc()
        return f"Error: {e}"

    finally:
        # Clean up the temporary file even if indexing or generation failed
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
# Gradio interface for image input
iface = gr.Interface(
    fn=ocr_and_extract,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Enter your query (optional)"),
    ],
    outputs="text",
    title="Image OCR with Byaldi + Qwen2-VL",
    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR.",
)
# Launch the Gradio app
iface.launch()
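
# To run locally (a sketch; the pip package names are inferred from the imports above):
#   pip install gradio byaldi transformers qwen-vl-utils torch pillow
#   python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default. On Hugging Face Spaces,
# the @spaces.GPU decorator requests GPU time only for the decorated call.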