MilanM committed on
Commit
c634380
·
verified ·
1 Parent(s): 8dfc6ff

Create vision_llm_text_extraction.py

Browse files
new_templates/vision_llm_text_extraction.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def extract_text_from_images_deployable():
    """
    Deployable watsonx.ai function that extracts text from multiple images/PDFs
    using a vision-capable foundation model.

    Expected input payload:
    {
        "input_data": [{
            "values": [["<image_url_1>", "<image_url_2>", ...], ["<optional_extraction_prompt>"]]
        }]
    }

    Returns:
    {
        "predictions": [{
            "fields": ["extracted_texts"],
            "values": [[["<extracted_text_1>", "<extracted_text_2>", ...]]]
        }]
    }
    """

    import base64
    import mimetypes
    import os
    from urllib.parse import urlparse

    import requests
    import fitz  # PyMuPDF — used to rasterize PDF pages into PNG images
    from ibm_watsonx_ai import APIClient, Credentials
    from ibm_watsonx_ai.foundation_models import ModelInference

    # Credentials and model selection come from the deployment environment.
    WX_URL = os.getenv('WX_URL', "")
    WX_APIKEY = os.getenv('WX_APIKEY', "")
    PROJECT_ID = os.getenv('PROJECT_ID', "")
    CHAT_MODEL = os.getenv('CHAT_MODEL', 'mistralai/mistral-medium-2505')

    DEFAULT_EXTRACTION_PROMPT = '''Extract all text within the image in a markdown form as close as possible to the original, free of any additional outputs that are not in the text, including descriptions of the element, comments about making outputs, etc.'''

    wx_credentials = Credentials(
        url=WX_URL,
        api_key=WX_APIKEY
    )
    client = APIClient(credentials=wx_credentials, project_id=PROJECT_ID)

    def create_data_url(source, filename=None):
        """Create a base64 data URL from bytes, a file path, or an http(s) URL.

        Returns a single data-URL string for images, or a list of per-page
        PNG data URLs when the source is a PDF.

        Raises:
            ValueError: if raw bytes are passed without a filename (the
                filename is needed to guess the MIME type).
            requests.HTTPError: if a URL source returns an error status.
        """
        if isinstance(source, str) and source.startswith(('http://', 'https://')):
            # BUGFIX: add a timeout and surface HTTP errors instead of
            # silently base64-encoding an error page as "image" bytes.
            response = requests.get(source, timeout=60)
            response.raise_for_status()
            content = response.content
            filename = filename or urlparse(source).path.split('/')[-1] or 'file'
        elif isinstance(source, str):
            with open(source, 'rb') as f:
                content = f.read()
            filename = filename or source
        else:
            content = source
            if not filename:
                raise ValueError("filename required for bytes input")

        mime_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

        if mime_type == 'application/pdf':
            # Rasterize each page at 1.5x zoom; higher resolution helps the
            # vision model read small text.
            doc = fitz.open(stream=content, filetype="pdf")
            try:
                result = []
                for page in doc:
                    pix = page.get_pixmap(matrix=fitz.Matrix(1.5, 1.5))
                    encoded = base64.b64encode(pix.tobytes("png")).decode('utf-8')
                    result.append(f"data:image/png;base64,{encoded}")
            finally:
                # BUGFIX: close the document even if rasterization fails.
                doc.close()
            return result

        encoded = base64.b64encode(content).decode('utf-8')
        return f"data:{mime_type};base64,{encoded}"

    def score(payload):
        """
        Score function called for each prediction request.

        Args:
            payload: Input payload containing a list of image URLs/paths and
                an optional extraction prompt (see the module docstring).

        Returns:
            Dictionary with predictions containing the list of extracted
            texts; on failure, an error message in the predictions schema.
        """
        try:
            # Extract input data from payload
            input_values = payload.get("input_data")[0].get("values")
            image_urls = input_values[0]  # List of URLs/paths

            # BUGFIX: the documented payload wraps the prompt in a list
            # (["<optional_extraction_prompt>"]); the original code passed
            # that list itself into the message text. Unwrap it, and fall
            # back to the default when absent or empty.
            extraction_prompt = DEFAULT_EXTRACTION_PROMPT
            if len(input_values) > 1 and input_values[1]:
                prompt_value = input_values[1]
                if isinstance(prompt_value, (list, tuple)):
                    prompt_value = prompt_value[0]
                if prompt_value:
                    extraction_prompt = prompt_value

            # Model parameters
            params = {
                "temperature": 1.0,
                "max_tokens": 6553,
                "top_p": 1.0,
                "stop": [
                    "</s>",
                    "<|end_of_text|>"
                ]
            }

            # Hoisted out of the loops: one model handle serves all requests.
            chat_model = ModelInference(api_client=client, model_id=CHAT_MODEL, params=params)

            def _extract_single(data_url, prompt_text):
                """Send one image data URL to the chat model; return its text."""
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt_text
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": data_url,
                                }
                            }
                        ]
                    }
                ]
                model_response = chat_model.chat(messages=messages)
                return model_response["choices"][0]["message"]["content"]

            extracted_texts = []

            # Process each image URL
            for image_url in image_urls:
                # Convert image (or PDF) to data URL(s)
                image_data_url = create_data_url(image_url)

                if isinstance(image_data_url, list):
                    # PDF case: one request per page, stitched into one
                    # markdown document with per-page headings.
                    page_sections = []
                    for page_num, page_url in enumerate(image_data_url, start=1):
                        page_text = _extract_single(
                            page_url, f"Page {page_num}:\n{extraction_prompt}"
                        )
                        page_sections.append(f"## Page {page_num}\n\n{page_text}")
                    extracted_text = "\n\n".join(page_sections)
                else:
                    # Single image case
                    extracted_text = _extract_single(image_data_url, extraction_prompt)

                extracted_texts.append(extracted_text)

            # Return in required format
            return {
                'predictions': [{
                    'fields': ['extracted_texts'],
                    'values': [extracted_texts]
                }]
            }

        except Exception as e:
            # Deployment boundary: report failure in the predictions schema
            # rather than raising, so the scoring endpoint returns cleanly.
            return {
                'predictions': [{
                    'fields': ['extracted_texts', 'error'],
                    'values': [[], str(e)]
                }]
            }

    return score