profchaos committed on
Commit e799756 · verified · 1 Parent(s): 03a03be

Upload ocr_using_qwenvl_by_ps.py

Files changed (1)
  1. ocr_using_qwenvl_by_ps.py +119 -0
ocr_using_qwenvl_by_ps.py ADDED
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
"""OCR_using_QwenVL_by_PS.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1lNLVl8FzVRrSv4dMd9vXqnz8SYtKoebf
"""

# Import libraries
import cv2
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import torch
from byaldi import RAGMultiModalModel
from google.colab import files  # used only by the commented-out upload helper below
from IPython.display import display, HTML  # only `display` is referenced, in a commented-out call below
import os
import re

# Detect CUDA (GPU) if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Loading models
RAG = RAGMultiModalModel.from_pretrained("vidore/colpali", verbose=0)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

torch.cuda.empty_cache()

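# Note: RAG (ColPali via byaldi) is loaded above but never used in the rest of
# this script. A minimal sketch of how byaldi is typically used, assuming a
# local folder of documents (folder, index name, and query below are invented):
#
#   RAG.index(input_path="docs/", index_name="ocr_docs",
#             store_collection_with_index=False, overwrite=True)
#   results = RAG.search("invoice total", k=3)  # each hit carries doc_id, page_num, score
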
# Upload image
# def upload_image():
#     uploaded = files.upload()
#     for filename in uploaded.keys():
#         print(f'Uploaded file: {filename}')
#         return filename

# image_path = upload_image()

# Preprocessing using OpenCV
def preprocess_image(image_path):
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Image not found at the path: {image_path}")

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Maintain aspect ratio
    height, width = gray.shape
    if height > width:
        new_height = 1024
        new_width = int((width / height) * new_height)
    else:
        new_width = 1024
        new_height = int((height / width) * new_width)

    resized_image = cv2.resize(gray, (new_width, new_height))

    blurred = cv2.GaussianBlur(resized_image, (5, 5), 0)
    thresholded = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
    denoised = cv2.fastNlMeansDenoising(thresholded, h=30)
    pil_image = Image.fromarray(denoised)

    return pil_image

# Call the function and store the result
# pil_image = preprocess_image(image_path)

# display(pil_image)  # Now pil_image is accessible here

# Extract the text
def extract_text(image_path):
    try:
        processed_image = preprocess_image(image_path)
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Please extract both the Hindi and English text as they appear in the image"}]}
        ]
        text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[processed_image], padding=True, return_tensors="pt").to(device)
        output_ids = model.generate(**inputs, max_new_tokens=1042)
        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
        extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
        return extracted_text
    except Exception as e:
        return f"An error occurred during text extraction: {e}"

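# Usage sketch (the path below is hypothetical, not from the original notebook):
#   text = extract_text("/content/sample_page.png")
#   print(text)
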
# Keyword searching
def keyword_search(extracted_text, keywords):
    # Always return a (text, status) pair so callers get a consistent shape
    if not keywords:
        return extracted_text, "Please enter a keyword to search and highlight."
    keywords = [keyword.strip() for keyword in keywords.split(",") if keyword.strip()]

    highlighted_text = ""

    lines = extracted_text.split('\n')
    for line in lines:
        for keyword in keywords:
            pattern = re.compile(re.escape(keyword), re.IGNORECASE)
            line = pattern.sub(lambda m: f'<span style="color: red;">{m.group()}</span>', line)
        highlighted_text += line + '\n'
    return highlighted_text, ""

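# Illustrative call (inputs invented for demonstration):
#   keyword_search("Invoice No. 123\nTotal: 500", "invoice,total")
# wraps each case-insensitive match in '<span style="color: red;">...</span>'
# and returns the highlighted text together with an empty status string.
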
# OCR and keyword search interface
def ocr_interface(image):
    image_path = "temp_image.png"
    image.save(image_path)
    extracted_text = extract_text(image_path)
    os.remove(image_path)

    return extracted_text, ""

def keyword_interface(extracted_text, keywords):
    highlighted_text, status = keyword_search(extracted_text, keywords)
    return highlighted_text, status
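
The two `*_interface` functions are shaped like Gradio event handlers (each returns a text/status pair), but the uploaded file stops before building a UI. A minimal wiring sketch follows, assuming gradio is installed; the Blocks layout, component labels, and launch call are my assumptions, not part of this commit:

# --- Hypothetical Gradio wiring (not part of the uploaded file) ---
import gradio as gr

with gr.Blocks(title="OCR with Qwen2-VL") as demo:
    image_input = gr.Image(type="pil", label="Document image")
    ocr_button = gr.Button("Extract text")
    extracted_box = gr.Textbox(label="Extracted text", lines=10)
    status_box = gr.Textbox(label="Status")

    keyword_input = gr.Textbox(label="Keywords (comma-separated)")
    search_button = gr.Button("Highlight keywords")
    highlighted_html = gr.HTML(label="Highlighted text")

    # Both handlers return a (text, status) pair, matching the outputs lists below.
    ocr_button.click(ocr_interface, inputs=image_input,
                     outputs=[extracted_box, status_box])
    search_button.click(keyword_interface, inputs=[extracted_box, keyword_input],
                        outputs=[highlighted_html, status_box])

demo.launch()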