import os
import re
import subprocess
from typing import List

import dotenv
import psutil
import requests
import torch
from PIL import Image
from google import genai
from transformers import AutoModel, AutoTokenizer

from utils import encode_image
class Rag:

    def _clean_raw_token_response(self, response_text):
        """
        Clean raw token responses that contain undecoded token IDs.

        This handles cases where a model returns raw special tokens
        instead of decoded text.
        """
        if not response_text:
            return response_text
        # Check whether the response contains raw token patterns.
        token_patterns = [
            r'<unused\d+>',     # unused tokens
            r'<bos>',           # beginning of sequence
            r'<eos>',           # end of sequence
            r'<unk>',           # unknown tokens
            r'<mask>',          # mask tokens
            r'<pad>',           # padding tokens
            r'\[multimodal\]',  # multimodal tokens
        ]

        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
        if has_raw_tokens:
            print("⚠️ Detected raw token response, attempting to clean...")
            cleaned_text = response_text
            # Strip unused tokens, special tokens, and multimodal markers.
            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
            # Collapse leftover whitespace.
            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

            # If almost nothing survives cleaning, the output was all tokens.
            if len(cleaned_text) < 10:
                return ("❌ **Model Response Error**: The model returned raw token IDs "
                        "instead of decoded text. This may be due to model configuration "
                        "issues. Please try:\n\n"
                        "1. Restarting the Ollama server\n"
                        "2. Using a different model\n"
                        "3. Checking model compatibility with multimodal inputs")
            return cleaned_text

        return response_text
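    # Illustrative example (input string is hypothetical): given a raw decode
    # such as "<bos>The peak was <unused12>on day 42<eos>",
    # _clean_raw_token_response() returns "The peak was on day 42".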
    def get_answer_from_gemini(self, query: str, image_paths: List[str]) -> str:
        print(f"Querying Gemini 2.5 Pro for query={query}, image_paths={image_paths}")
        try:
            # The google-genai SDK ("from google import genai") uses a Client;
            # the legacy genai.configure()/GenerativeModel() calls belong to the
            # older google.generativeai package and would fail here.
            api_key = os.environ.get('GEMINI_API_KEY')
            if not api_key:
                return "Error: GEMINI_API_KEY is not set."
            client = genai.Client(api_key=api_key)

            # Load the page images; skip any that cannot be opened.
            images = []
            for path in image_paths:
                try:
                    images.append(Image.open(path))
                except Exception as img_err:
                    print(f"Could not open image {path}: {img_err}")

            # PIL images can be passed directly alongside the text prompt.
            response = client.models.generate_content(
                model='gemini-2.5-pro',
                contents=[*images, query],
            )
            return response.text
        except Exception as e:
            print(f"Gemini error: {e}")
            return f"Error: {str(e)}"
    def get_answer_from_openai(self, query, imagesPaths):
        # Load environment variables (e.g. GEMINI_API_KEY) from a .env file.
        dotenv.load_dotenv(dotenv.find_dotenv())

        # This function formerly called a local Ollama server; it now
        # delegates to Gemini 2.5 Pro instead.
        print(f"Querying Gemini (replacement for Ollama) for query={query}, imagesPaths={imagesPaths}")
        try:
            enhanced_query = (
                f"Use all {len(imagesPaths)} pages to answer comprehensively.\n\n"
                f"Query: {query}"
            )
            return self.get_answer_from_gemini(enhanced_query, imagesPaths)
        except Exception as e:
            print(f"Gemini replacement error: {e}")
            return None
    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        # Build an OpenAI-style chat payload with base64-encoded images,
        # as expected by Ollama's OpenAI-compatible endpoint.
        image_payload = []
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # smaller token budget keeps processing time down
        }
        return payload
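    # Illustrative sketch: how the payload above would be posted to a local
    # Ollama server via its OpenAI-compatible endpoint, with the raw-token
    # cleaner applied to the reply. This helper and the URL are assumptions
    # for demonstration; the class currently routes everything to Gemini.
    def __post_payload_to_ollama(self, payload):
        url = "http://localhost:11434/v1/chat/completions"  # assumed local endpoint
        response = requests.post(url, json=payload, timeout=300)
        response.raise_for_status()
        text = response.json()["choices"][0]["message"]["content"]
        # Strip any undecoded special tokens before returning.
        return self._clean_raw_token_response(text)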
# if __name__ == "__main__":
#     rag = Rag()
#     query = "Based on the attached images, how many new cases were reported during the second wave peak?"
#     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
#     print(rag.get_answer_from_gemini(query, imagesPaths))