import gradio as gr
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import os
import traceback
import spaces
import time

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Byaldi and Qwen2-VL models
rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")  # Do not move Byaldi to GPU
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)  # Move Qwen2-VL to GPU

# Processor for Qwen2-VL
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True)

@spaces.GPU  # Decorate the function for GPU management
def ocr_and_extract(image, text_query):
    try:
        # Save the uploaded image temporarily
        temp_image_path = "temp_image.jpg"
        image.save(temp_image_path)

        # Generate a unique index name using the current timestamp
        unique_index_name = f"image_index_{int(time.time())}"

        # Index the image with Byaldi
        rag_model.index(
            input_path=temp_image_path,
            index_name=unique_index_name,  # Use the unique index name
            store_collection_with_index=False,
            overwrite=True  # Ensure the index is overwritten if it already exists
        )

        # Perform the search query on the indexed image (the retrieved result is
        # computed here but not used further; the image itself is passed to Qwen2-VL)
        results = rag_model.search(text_query, k=1)

        # Prepare the input for Qwen2-VL
        image_data = Image.open(temp_image_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_data},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Process the message and prepare for Qwen2-VL
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, _ = process_vision_info(messages)

        # Move the image inputs and processor outputs to CUDA
        inputs = processor(
            text=[text_input],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        ).to(device)

        # Generate the output with Qwen2-VL
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=50)

        # Decode only the newly generated tokens, trimming the echoed prompt
        # (system/user/assistant labels) rather than filtering lines by keyword,
        # which could also drop legitimate OCR output containing those words
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Clean up the temporary file
        os.remove(temp_image_path)

        return output_text[0].strip()

    except Exception as e:
        error_message = str(e)
        traceback.print_exc()
        return f"Error: {error_message}"

# Gradio interface for image input
iface = gr.Interface(
    fn=ocr_and_extract,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Enter your query (optional)"),
    ],
    outputs="text",
    title="Image OCR with Byaldi + Qwen2-VL",
    description="Upload an image (JPEG/PNG) containing Hindi and English text for OCR.",
)

# Launch the Gradio app
iface.launch()