# demo/rag.py
import requests
import os
import re
from typing import List
from utils import encode_image
from PIL import Image
from ollama import chat
import torch
import subprocess
import psutil
from transformers import AutoModel, AutoTokenizer
import google.generativeai as genai  # genai.configure / genai.GenerativeModel live in google.generativeai
class Rag:
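    """
    Small retrieval-augmented-generation helper.

    Given a query and a list of document page images, it asks either Gemini
    (cloud) or a local Ollama vision model for an answer, and cleans raw
    token artifacts out of the model output.
    """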
    def _clean_raw_token_response(self, response_text):
        """
        Clean responses that contain undecoded token IDs.

        Handles cases where a model returns raw special tokens (e.g. <bos>, <unused42>)
        instead of decoded text.
        """
if not response_text:
return response_text
# Check if response contains raw token patterns
token_patterns = [
r'<unused\d+>', # unused tokens
r'<bos>', # beginning of sequence
r'<eos>', # end of sequence
r'<unk>', # unknown tokens
r'<mask>', # mask tokens
r'<pad>', # padding tokens
r'\[multimodal\]', # multimodal tokens
]
# If response contains raw tokens, try to clean them
has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
if has_raw_tokens:
print("⚠️ Detected raw token response, attempting to clean...")
# Remove common raw token patterns
cleaned_text = response_text
# Remove unused tokens
cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
# Remove special tokens
cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
# Remove multimodal tokens
cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
# Clean up extra whitespace
cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
# If we still have mostly tokens, return an error message
if len(cleaned_text.strip()) < 10:
return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"
return cleaned_text
return response_text
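    # Illustrative example (not executed): the cleaner turns a raw-token reply such as
    #   "<bos>The answer is <unused12>on page 3<eos>"
    # into
    #   "The answer is on page 3",
    # and falls back to the error message above when almost no readable text remains.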
def get_answer_from_gemini(self, query, imagePaths):
print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")
        try:
            # The API key is read from the environment (GEMINI_API_KEY is an assumed
            # variable name) rather than being hard-coded in source control.
            genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
            model = genai.GenerativeModel('gemini-2.0-flash')
            images = [Image.open(path) for path in imagePaths]
            chat_session = model.start_chat()  # renamed so it does not shadow the imported ollama `chat`
            response = chat_session.send_message([*images, query])
answer = response.text
print(answer)
return answer
except Exception as e:
print(f"An error occurred while querying Gemini: {e}")
return f"Error: {str(e)}"
    def get_answer_from_openai(self, query, imagesPaths):
        # Despite the name, this path queries a local Ollama model; configuration
        # (model name, temperature, flash attention) comes from the .env file.
        import dotenv
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)
        torch.cuda.empty_cache()  # release CUDA memory so that Ollama can use the GPU
        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn']  # expected to be "0" or "1"
if os.environ['ollama'] == "minicpm-v":
os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0" #set to quantized version
elif os.environ['ollama'] == "gemma3":
os.environ['ollama'] = "gemma3:12b" #set to upscaled version
# Add specific environment variables for Gemma3 to prevent raw token issues
os.environ['OLLAMA_KEEP_ALIVE'] = "5m"
os.environ['OLLAMA_ORIGINS'] = "*"
        # By this point the colpali model thread should be closed so Ollama has the GPU
        print(f"Querying Ollama for query={query}, imagesPaths={imagesPaths}")
try:
# Enhanced prompt for more detailed responses with explicit page usage
enhanced_query = f"""
Please provide a comprehensive and detailed answer to the following query.
Use ALL available information from the provided document images to give a thorough response.
Query: {query}
CRITICAL INSTRUCTIONS:
- You have been provided with {len(imagesPaths)} document page(s)
- You MUST reference information from ALL {len(imagesPaths)} page(s) in your response
- Do not skip any pages - each page contains relevant information
- If you mention one page, you must also mention the others
- Ensure your response reflects the complete information from all pages
Instructions for detailed response:
1. Provide extensive background information and context
2. Include specific details, examples, and data points from ALL documents
3. Explain concepts thoroughly with step-by-step breakdowns
4. Provide comprehensive analysis rather than simple answers when requested
5. Explicitly reference each page and what information it contributes
6. Cross-reference information between pages when relevant
7. Ensure no page is left unmentioned in your analysis
SPECIAL INSTRUCTIONS FOR TABULAR DATA:
- If the query requests a table, list, or structured data, organize your response in a clear, structured format
- Use numbered lists, bullet points, or clear categories when appropriate
- Include specific data points or comparisons when available
- Structure information in a way that can be easily converted to a table format
IMPORTANT: Respond with natural, human-readable text only. Do not include any special tokens, codes, or technical identifiers in your response.
Make sure to acknowledge and use information from all {len(imagesPaths)} provided pages.
"""
# Try with current model first
current_model = os.environ['ollama']
# Set different options based on the model
if "gemma3" in current_model.lower():
# Specific options for Gemma3 to prevent raw token issues
model_options = {
"num_predict": 1024, # Shorter responses for Gemma3
"stop": ["<eos>", "<|endoftext|>", "</s>", "<|im_end|>"], # More stop tokens
"top_k": 20, # Lower top_k for more focused generation
"top_p": 0.8, # Lower top_p for more deterministic output
"repeat_penalty": 1.2, # Higher repeat penalty
"seed": 42, # Consistent results
"temperature": 0.7, # Lower temperature for more focused responses
}
else:
# Default options for other models
model_options = {
"num_predict": 2048, # Limit response length
"stop": ["<eos>", "<|endoftext|>", "</s>"], # Stop at end tokens
"top_k": 40, # Reduce randomness
"top_p": 0.9, # Nucleus sampling
"repeat_penalty": 1.1, # Prevent repetition
"seed": 42, # Consistent results
}
            response = chat(
                model=current_model,
                messages=[
                    {
                        'role': 'user',
                        'content': enhanced_query,
                        'images': imagesPaths,
                    }
                ],
                # temperature is a sampling option, so it is passed via `options`
                # rather than inside the message dict
                options={**model_options, "temperature": float(os.environ['temperature'])},
            )
answer = response.message.content
# Clean the response to handle raw token issues
cleaned_answer = self._clean_raw_token_response(answer)
# If the cleaned answer is still problematic, try fallback models
if cleaned_answer and "❌ **Model Response Error**" in cleaned_answer:
print(f"⚠️ Primary model {current_model} failed, trying fallback models...")
# List of fallback models to try
fallback_models = [
"llama3.2-vision:latest",
"llava:latest",
"bakllava:latest",
"llama3.2:latest"
]
for fallback_model in fallback_models:
try:
print(f"πŸ”„ Trying fallback model: {fallback_model}")
response = chat(
model=fallback_model,
messages=[
{
'role': 'user',
'content': enhanced_query,
'images': imagesPaths,
"temperature":float(os.environ['temperature']),
}
],
options={
"num_predict": 2048,
"stop": ["<eos>", "<|endoftext|>", "</s>"],
"top_k": 40,
"top_p": 0.9,
"repeat_penalty": 1.1,
"seed": 42,
}
)
fallback_answer = response.message.content
cleaned_fallback = self._clean_raw_token_response(fallback_answer)
if cleaned_fallback and "❌ **Model Response Error**" not in cleaned_fallback:
print(f"βœ… Fallback model {fallback_model} succeeded")
return cleaned_fallback
except Exception as fallback_error:
print(f"❌ Fallback model {fallback_model} failed: {fallback_error}")
continue
# If all fallbacks fail, return the original error
return cleaned_answer
print(f"Original response: {answer}")
print(f"Cleaned response: {cleaned_answer}")
return cleaned_answer
except Exception as e:
print(f"An error occurred while querying OpenAI: {e}")
return None
def __get_openai_api_payload(self, query:str, imagesPaths:List[str]):
image_payload = []
for imagePath in imagesPaths:
base64_image = encode_image(imagePath)
image_payload.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
})
payload = {
"model": "Llama3.2-vision", #change model here as needed
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
*image_payload
]
}
],
"max_tokens": 1024 #reduce token size to reduce processing time
}
return payload
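    # Sketch (assumption, not wired up anywhere in this file): the payload built by
    # __get_openai_api_payload follows the OpenAI chat-completions format, so it could
    # be POSTed to an OpenAI-compatible server. The endpoint URL below (Ollama's
    # /v1/chat/completions on localhost) and the use of `requests` are illustrative
    # assumptions, not part of the original flow.
    def __post_openai_payload_example(self, query: str, imagesPaths: List[str]):
        payload = self.__get_openai_api_payload(query, imagesPaths)
        response = requests.post(
            "http://localhost:11434/v1/chat/completions",  # assumed local OpenAI-compatible endpoint
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        # Standard OpenAI-style response shape: first choice's message content
        return response.json()["choices"][0]["message"]["content"]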
# if __name__ == "__main__":
# rag = Rag()
# query = "Based on attached images, how many new cases were reported during second wave peak"
# imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
# rag.get_answer_from_gemini(query, imagesPaths)
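# Example of the local Ollama path (commented out, like the block above; assumes the
# `ollama`, `flashattn` and `temperature` variables are defined in the .env file):
# if __name__ == "__main__":
#     rag = Rag()
#     answer = rag.get_answer_from_openai(
#         "Summarize the key findings in the attached pages",
#         ["covid_slides_page_8.png"],
#     )
#     print(answer)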