# rag.py
import os
import re
from typing import List

from PIL import Image
from google import genai

from utils import encode_image
class Rag:
    def _clean_raw_token_response(self, response_text):
        """
        Clean raw token responses that contain undecoded token IDs.

        This handles cases where models return raw tokens instead of decoded text.
        """
        if not response_text:
            return response_text

        # Special-token patterns that indicate an undecoded response
        token_patterns = [
            r'<unused\d+>',     # unused tokens
            r'<bos>',           # beginning of sequence
            r'<eos>',           # end of sequence
            r'<unk>',           # unknown tokens
            r'<mask>',          # mask tokens
            r'<pad>',           # padding tokens
            r'\[multimodal\]',  # multimodal tokens
        ]

        # If the response contains raw tokens, try to clean them
        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
        if has_raw_tokens:
            print("⚠️ Detected raw token response, attempting to clean...")
            cleaned_text = response_text
            # Remove unused tokens
            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
            # Remove special tokens
            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
            # Remove multimodal tokens
            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
            # Collapse extra whitespace
            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

            # If almost nothing survives, the response was essentially all tokens
            if len(cleaned_text.strip()) < 10:
                return ("❌ **Model Response Error**: The model returned raw token IDs "
                        "instead of decoded text. This may be due to model configuration "
                        "issues. Please try:\n\n"
                        "1. Restarting the Ollama server\n"
                        "2. Using a different model\n"
                        "3. Checking model compatibility with multimodal inputs")
            return cleaned_text

        return response_text
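    # Illustrative behaviour (inputs are assumed, not from this repo): a response
    # such as "<bos>There were <unused5>about 4,000 cases<eos>" cleans to
    # "There were about 4,000 cases", while a response made up almost entirely
    # of special tokens returns the error message above instead.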
    def get_answer_from_gemini(self, query: str, image_paths: List[str]) -> str:
        print(f"Querying Gemini 2.5 Pro for query={query}, image_paths={image_paths}")
        try:
            # Read the key from the GEMINI_API_KEY environment variable and
            # fail early with a clear message if it is missing.
            api_key = os.environ.get('GEMINI_API_KEY')
            if not api_key:
                return "Error: GEMINI_API_KEY is not set."
            client = genai.Client(api_key=api_key)

            # Load the page images, skipping any that cannot be opened
            images = []
            for p in image_paths:
                try:
                    images.append(Image.open(p))
                except Exception as e:
                    print(f"Could not open image {p}: {e}")

            # Single-turn multimodal request: images first, then the text query
            response = client.models.generate_content(
                model='gemini-2.5-pro',
                contents=[*images, query],
            )
            return response.text
        except Exception as e:
            print(f"Gemini error: {e}")
            return f"Error: {str(e)}"
    def get_answer_from_openai(self, query, imagesPaths):
        # Load environment variables from the project's .env file, if present
        import dotenv
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        # This function formerly queried Ollama; it now delegates to Gemini 2.5 Pro.
        print(f"Querying Gemini (replacement for Ollama) for query={query}, imagesPaths={imagesPaths}")
        try:
            enhanced_query = f"Use all {len(imagesPaths)} pages to answer comprehensively.\n\nQuery: {query}"
            return self.get_answer_from_gemini(enhanced_query, imagesPaths)
        except Exception as e:
            print(f"Gemini replacement error: {e}")
            return None
    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        # Legacy helper from the Ollama path: builds an OpenAI-style chat
        # payload for a local vision model. Kept for reference; the Gemini
        # path above does not use it.
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # cap output length to reduce processing time
        }
        return payload
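# A minimal sketch of how the legacy payload could be posted to an
# OpenAI-compatible endpoint. The URL below is an assumption based on
# Ollama's default port, not something this repo configures:
#
#   import requests
#   rag = Rag()
#   payload = rag._Rag__get_openai_api_payload(query, imagesPaths)  # name-mangled private helper
#   resp = requests.post("http://localhost:11434/v1/chat/completions", json=payload)
#   print(resp.json()["choices"][0]["message"]["content"])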
# Minimal smoke test, assuming GEMINI_API_KEY is set and the sample slide
# image from the original example exists alongside this file:
if __name__ == "__main__":
    rag = Rag()
    query = "Based on the attached images, how many new cases were reported during the second wave peak?"
    imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
    print(rag.get_answer_from_gemini(query, imagesPaths))