import requests
import os
import re

from typing import List
from utils import encode_image
from PIL import Image
from ollama import chat
import torch
import subprocess
import psutil
from transformers import AutoModel, AutoTokenizer
import google.generativeai as genai  # provides genai.configure / GenerativeModel used below


class Rag:
    
    def _clean_raw_token_response(self, response_text):
        """
        Clean raw token responses that contain undecoded token IDs
        This handles cases where models return raw tokens instead of decoded text
        """
        if not response_text:
            return response_text
            
        # Check if response contains raw token patterns
        token_patterns = [
            r'<unused\d+>',  # unused tokens
            r'<bos>',        # beginning of sequence
            r'<eos>',        # end of sequence
            r'<unk>',        # unknown tokens
            r'<mask>',       # mask tokens
            r'<pad>',        # padding tokens
            r'\[multimodal\]', # multimodal tokens
        ]
        
        # If response contains raw tokens, try to clean them
        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
        
        if has_raw_tokens:
            print("⚠️  Detected raw token response, attempting to clean...")
            
            # Remove common raw token patterns
            cleaned_text = response_text
            
            # Remove unused tokens
            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
            
            # Remove special tokens
            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
            
            # Remove multimodal tokens
            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
            
            # Clean up extra whitespace
            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
            
            # If we still have mostly tokens, return an error message
            if len(cleaned_text.strip()) < 10:
                return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"
            
            return cleaned_text
        
        return response_text
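
    # A quick, hypothetical before/after for the cleaner above (illustrative
    # only, not from the original code):
    #   _clean_raw_token_response("<bos> Peak was <unused12> 4,300 cases <eos>")
    #   -> "Peak was 4,300 cases"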
    
    def get_answer_from_gemini(self, query, imagePaths):
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
            # API key is read from the environment (the variable name GEMINI_API_KEY
            # is an assumption) instead of being hard-coded in the source
            genai.configure(api_key=os.environ['GEMINI_API_KEY'])
            model = genai.GenerativeModel('gemini-2.0-flash')
            
            images = [Image.open(path) for path in imagePaths]

            # Named chat_session to avoid shadowing the ollama `chat` import
            chat_session = model.start_chat()

            response = chat_session.send_message([*images, query])

            answer = response.text

            print(answer)
            
            return answer
        
        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"
    
    #os.environ['OPENAI_API_KEY'] = "for the love of Jesus let this work"
    
    def get_answer_from_openai(self, query, imagesPaths):
        # Import environment variables from the .env file
        import dotenv

        # Load the .env file
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)
        
        # Ollama method below

        torch.cuda.empty_cache()  # Release CUDA memory so that Ollama can use the GPU

        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn']  # expects "1" or "0" as a string
        if os.environ['ollama'] == "minicpm-v":
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0"  # switch to the quantized build
        elif os.environ['ollama'] == "gemma3":
            os.environ['ollama'] = "gemma3:12b"  # switch to the larger 12B variant
            # Add specific environment variables for Gemma3 to prevent raw token issues
            os.environ['OLLAMA_KEEP_ALIVE'] = "5m"
            os.environ['OLLAMA_ORIGINS'] = "*"
        

        # Close model thread (colpali)
        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

        try:    
            
            # Enhanced prompt for more detailed responses with explicit page usage
            enhanced_query = f"""
            Please provide a comprehensive and detailed answer to the following query. 
            Use ALL available information from the provided document images to give a thorough response.
            
            Query: {query}
            
            CRITICAL INSTRUCTIONS:
            - You have been provided with {len(imagesPaths)} document page(s)
            - You MUST reference information from ALL {len(imagesPaths)} page(s) in your response
            - Do not skip any pages - each page contains relevant information
            - If you mention one page, you must also mention the others
            - Ensure your response reflects the complete information from all pages
            
            Instructions for detailed response:
            1. Provide extensive background information and context
            2. Include specific details, examples, and data points from ALL documents
            3. Explain concepts thoroughly with step-by-step breakdowns 
            4. Provide comprehensive analysis rather than simple answers when requested
            5. Explicitly reference each page and what information it contributes
            6. Cross-reference information between pages when relevant
            7. Ensure no page is left unmentioned in your analysis
            
            SPECIAL INSTRUCTIONS FOR TABULAR DATA:
            - If the query requests a table, list, or structured data, organize your response in a clear, structured format
            - Use numbered lists, bullet points, or clear categories when appropriate
            - Include specific data points or comparisons when available
            - Structure information in a way that can be easily converted to a table format
            
            IMPORTANT: Respond with natural, human-readable text only. Do not include any special tokens, codes, or technical identifiers in your response.
            
            Make sure to acknowledge and use information from all {len(imagesPaths)} provided pages.
            """
            
            # Try with current model first
            current_model = os.environ['ollama']
            
            # Set different options based on the model
            if "gemma3" in current_model.lower():
                # Specific options for Gemma3 to prevent raw token issues
                model_options = {
                    "num_predict": 1024,  # Shorter responses for Gemma3
                    "stop": ["<eos>", "<|endoftext|>", "</s>", "<|im_end|>"],  # More stop tokens
                    "top_k": 20,  # Lower top_k for more focused generation
                    "top_p": 0.8,  # Lower top_p for more deterministic output
                    "repeat_penalty": 1.2,  # Higher repeat penalty
                    "seed": 42,  # Consistent results
                    "temperature": 0.7,  # Lower temperature for more focused responses
                }
            else:
                # Default options for other models
                model_options = {
                    "num_predict": 2048,  # Limit response length
                    "stop": ["<eos>", "<|endoftext|>", "</s>"],  # Stop at end tokens
                    "top_k": 40,  # Reduce randomness
                    "top_p": 0.9,  # Nucleus sampling
                    "repeat_penalty": 1.1,  # Prevent repetition
                    "seed": 42,  # Consistent results
                }
            
            # Ollama reads sampling parameters from `options`; a "temperature" key
            # inside the message dict is ignored, so pass it through options instead
            # (without overriding the Gemma3-specific value set above).
            model_options.setdefault("temperature", float(os.environ['temperature']))  # test if temp makes a diff

            response = chat(
                model=current_model,
                messages=[
                    {
                        'role': 'user',
                        'content': enhanced_query,
                        'images': imagesPaths,
                    }
                ],
                options=model_options,
            )
    
            answer = response.message.content
            
            # Clean the response to handle raw token issues
            cleaned_answer = self._clean_raw_token_response(answer)
            
            # If the cleaned answer is still problematic, try fallback models
            if cleaned_answer and "❌ **Model Response Error**" in cleaned_answer:
                print(f"⚠️  Primary model {current_model} failed, trying fallback models...")
                
                # List of fallback models to try
                fallback_models = [
                    "llama3.2-vision:latest",
                    "llava:latest", 
                    "bakllava:latest",
                    "llama3.2:latest"
                ]
                
                for fallback_model in fallback_models:
                    try:
                        print(f"πŸ”„ Trying fallback model: {fallback_model}")
                        # As above, temperature belongs in `options`, not in the message
                        response = chat(
                            model=fallback_model,
                            messages=[
                                {
                                    'role': 'user',
                                    'content': enhanced_query,
                                    'images': imagesPaths,
                                }
                            ],
                            options={
                                "num_predict": 2048,
                                "stop": ["<eos>", "<|endoftext|>", "</s>"],
                                "top_k": 40,
                                "top_p": 0.9,
                                "repeat_penalty": 1.1,
                                "seed": 42,
                                "temperature": float(os.environ['temperature']),
                            },
                        )
                        
                        fallback_answer = response.message.content
                        cleaned_fallback = self._clean_raw_token_response(fallback_answer)
                        
                        if cleaned_fallback and "❌ **Model Response Error**" not in cleaned_fallback:
                            print(f"βœ… Fallback model {fallback_model} succeeded")
                            return cleaned_fallback
                            
                    except Exception as fallback_error:
                        print(f"❌ Fallback model {fallback_model} failed: {fallback_error}")
                        continue
                
                # If all fallbacks fail, return the original error
                return cleaned_answer
    
            print(f"Original response: {answer}")
            print(f"Cleaned response: {cleaned_answer}")
    
            return cleaned_answer
    
        except Exception as e:
            print(f"An error occurred while querying OpenAI: {e}")
            return None
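
    # Usage sketch (illustrative, not part of the original file). It assumes a
    # .env file supplying the variables read above, for example:
    #   ollama=gemma3
    #   flashattn=1
    #   temperature=0.7
    # and then:
    #   rag = Rag()
    #   answer = rag.get_answer_from_openai("Summarize page 3", ["page_3.png"])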
        


    def __get_openai_api_payload(self, query:str, imagesPaths:List[str]):
        image_payload = []

        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision", #change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024 #reduce token size to reduce processing time
        }

        return payload
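
    # The payload built above follows the OpenAI-style /v1/chat/completions
    # schema. A minimal sketch of how it could be posted with `requests`
    # (endpoint URL and env var name are assumptions, not taken from this repo):
    #
    #   headers = {"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"}
    #   resp = requests.post("https://api.openai.com/v1/chat/completions",
    #                        headers=headers, json=payload, timeout=120)
    #   answer = resp.json()["choices"][0]["message"]["content"]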
    


# if __name__ == "__main__":
#     rag = Rag()
    
#     query = "Based on attached images, how many new cases were reported during second wave peak"
#     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
    
#     rag.get_answer_from_gemini(query, imagesPaths)