Swekerr committed on
Commit 5e1aec2 · verified · 1 Parent(s): 591cd15

Create app.py

Files changed (1)
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
+ import gradio as gr
+ from PIL import Image
+ import json
+ import os
+ import tempfile
+ from byaldi import RAGMultiModalModel
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import torch
+
+ # Load the ColPali retriever and the Qwen2-VL model once at startup
+ def load_models():
+     RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2-VL-2B-Instruct",
+         trust_remote_code=True,
+         torch_dtype=torch.float32,  # float32 for CPU inference
+     )
+     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+     return RAG, model, processor
+
+ RAG, model, processor = load_models()
+
+ # Function for OCR and keyword search on the uploaded image
+ def ocr_and_search(image, keyword):
+     # Hardcoded query to extract text in English, Sanskrit, and Hindi
+     text_query = "Extract all the text in English, Sanskrit, and Hindi from the image."
+
+     # byaldi searches an index built from files on disk, so save the uploaded
+     # image to a temporary file and index it before querying (re-indexed per request for simplicity)
+     with tempfile.TemporaryDirectory() as tmp_dir:
+         image_path = os.path.join(tmp_dir, "uploaded_image.png")
+         image.save(image_path)
+         RAG.index(input_path=image_path, index_name="uploaded_image",
+                   store_collection_with_index=False, overwrite=True)
+     results = RAG.search(text_query, k=1)
+
+     if not results:  # Check if results are empty
+         return "No results found for the given query.", [], "{}"
+
+     # Prepare message for Qwen model
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {"type": "text", "text": text_query},
+             ],
+         }
+     ]
+
+     # Process the image
+     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     ).to("cpu")  # Use CPU
+
+     # Generate text
+     with torch.no_grad():
+         generated_ids = model.generate(**inputs, max_new_tokens=50)
+     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
+     extracted_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     # Save extracted text to JSON
+     output_json = {"query": text_query, "extracted_text": extracted_text}
+     json_output = json.dumps(output_json, ensure_ascii=False, indent=4)
+
+     # Perform keyword search
+     keyword_lower = keyword.lower()
+     sentences = extracted_text.split('. ')
+     matched_sentences = [sentence for sentence in sentences if keyword_lower in sentence.lower()]
+
+     return extracted_text, matched_sentences, json_output
+
+ # Gradio App function
+ def app(image, keyword):
+     # Call OCR and search function
+     extracted_text, search_results, json_output = ocr_and_search(image, keyword)
+
+     search_results_str = "\n".join(search_results) if search_results else "No matches found."
+
+     return extracted_text, search_results_str, json_output
+
+ # Gradio Interface
+ iface = gr.Interface(
+     fn=app,
+     inputs=[
+         gr.Image(type="pil", label="Upload an Image"),
+         gr.Textbox(label="Enter keyword to search in extracted text", value="")
+     ],
+     outputs=[
+         gr.Textbox(label="Extracted Text"),
+         gr.Textbox(label="Search Results"),
+         gr.JSON(label="JSON Output")
+     ],
+     title="OCR and Keyword Search in Images",
+ )
+
+ # Launch Gradio App
+ iface.launch()
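
For a quick local sanity check without the web UI, the pipeline above can be called directly in a Python session where load_models() and ocr_and_search() are available (for example, with the iface.launch() call temporarily commented out). This is a minimal sketch: the image path and keyword below are placeholders, not part of the commit.

    from PIL import Image

    # Hypothetical manual test of the OCR + keyword-search pipeline
    img = Image.open("sample_page.png")  # placeholder path to an image with English/Sanskrit/Hindi text
    extracted, matches, payload = ocr_and_search(img, keyword="mantra")
    print(extracted)  # raw text returned by Qwen2-VL
    print(matches)    # sentences containing the keyword
    print(payload)    # the same text wrapped in the JSON structure built above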