import requests
import os
import re
from typing import List
from utils import encode_image
from PIL import Image
from ollama import chat
import torch
import subprocess
import psutil
from transformers import AutoModel, AutoTokenizer
from google import genai

class Rag:

    def _clean_raw_token_response(self, response_text):
        """
        Clean raw token responses that contain undecoded token IDs.
        This handles cases where models return raw tokens instead of decoded text.
        """
        if not response_text:
            return response_text

        # Check if the response contains raw token patterns
        token_patterns = [
            r'<unused\d+>',     # unused tokens
            r'<bos>',           # beginning of sequence
            r'<eos>',           # end of sequence
            r'<unk>',           # unknown tokens
            r'<mask>',          # mask tokens
            r'<pad>',           # padding tokens
            r'\[multimodal\]',  # multimodal tokens
        ]

        # If the response contains raw tokens, try to clean them
        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)

        if has_raw_tokens:
            print("⚠️ Detected raw token response, attempting to clean...")

            # Remove common raw token patterns
            cleaned_text = response_text
            # Remove unused tokens
            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
            # Remove special tokens
            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
            # Remove multimodal tokens
            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
            # Clean up extra whitespace
            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()

            # If we still have mostly tokens, return an error message
            if len(cleaned_text.strip()) < 10:
                return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"

            return cleaned_text

        return response_text
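
    # Illustrative example of what the cleaner above produces (values made up for illustration):
    #   _clean_raw_token_response("<unused12> The answer is 42 <eos> [multimodal]")
    #   -> "The answer is 42"
    # If stripping the tokens leaves fewer than 10 characters, the method returns the
    # "❌ **Model Response Error**" message instead of the cleaned text.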

    def get_answer_from_gemini(self, query, imagePaths):

        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
            # NOTE: hardcoded API key; consider loading it from an environment variable instead.
            client = genai.Client(api_key="AIzaSyCwRr9054tCuh2S8yGpwKFvOAxYMT4WNIs")

            images = [Image.open(path) for path in imagePaths]

            response = client.models.generate_content(
                model="gemini-2.5-pro",
                contents=[images, query],
            )

            print(response.text)
            answer = response.text

            return answer

        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    #os.environ['OPENAI_API_KEY'] = "for the love of Jesus let this work"

    def get_answer_from_openai(self, query, imagesPaths):
        # Despite the name, this path queries a local Ollama model via the `ollama` client.

        # Import environment variables from .env
        import dotenv

        # Load the .env file
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

        # Ollama method below
        torch.cuda.empty_cache()  # release CUDA memory so that Ollama can use the GPU
        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn']  # expected to be the string "1"

        if os.environ['ollama'] == "minicpm-v":
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0"  # set to the quantized version
        elif os.environ['ollama'] == "gemma3":
            os.environ['ollama'] = "gemma3:12b"  # set to the upscaled 12b version when needed
            # Add specific environment variables for Gemma3 to prevent raw token issues
            os.environ['OLLAMA_KEEP_ALIVE'] = "5m"
            os.environ['OLLAMA_ORIGINS'] = "*"

        # Close model thread (colpali)
        print(f"Querying Ollama for query={query}, imagesPaths={imagesPaths}")

        try:
            # Enhanced prompt for more detailed responses with explicit page usage
            enhanced_query = f"""
            Please provide a comprehensive and detailed answer to the following query.
            Use ALL available information from the provided document images to give a thorough response.

            Query: {query}

            CRITICAL INSTRUCTIONS:
            - You have been provided with {len(imagesPaths)} document page(s)
            - You MUST reference information from ALL {len(imagesPaths)} page(s) in your response
            - Do not skip any pages - each page contains relevant information
            - If you mention one page, you must also mention the others
            - Ensure your response reflects the complete information from all pages

            Instructions for detailed response:
            1. Provide extensive background information and context
            2. Include specific details, examples, and data points from ALL documents
            3. Explain concepts thoroughly with step-by-step breakdowns
            4. Provide comprehensive analysis rather than simple answers when requested
            5. Explicitly reference each page and what information it contributes
            6. Cross-reference information between pages when relevant
            7. Ensure no page is left unmentioned in your analysis

            SPECIAL INSTRUCTIONS FOR TABULAR DATA:
            - If the query requests a table, list, or structured data, organize your response in a clear, structured format
            - Use numbered lists, bullet points, or clear categories when appropriate
            - Include specific data points or comparisons when available
            - Structure information in a way that can be easily converted to a table format

            IMPORTANT: Respond with natural, human-readable text only. Do not include any special tokens, codes, or technical identifiers in your response.

            Make sure to acknowledge and use information from all {len(imagesPaths)} provided pages.
            """

            # Try with the current model first
            current_model = os.environ['ollama']

            # Set different options based on the model
            if "gemma3" in current_model.lower():
                # Specific options for Gemma3 to prevent raw token issues
                model_options = {
                    "num_predict": 1024,    # Shorter responses for Gemma3
                    "stop": ["<eos>", "<|endoftext|>", "</s>", "<|im_end|>"],  # More stop tokens
                    "top_k": 20,            # Lower top_k for more focused generation
                    "top_p": 0.8,           # Lower top_p for more deterministic output
                    "repeat_penalty": 1.2,  # Higher repeat penalty
                    "seed": 42,             # Consistent results
                    "temperature": 0.7,     # Lower temperature for more focused responses
                }
            else:
                # Default options for other models
                model_options = {
                    "num_predict": 2048,    # Limit response length
                    "stop": ["<eos>", "<|endoftext|>", "</s>"],  # Stop at end tokens
                    "top_k": 40,            # Reduce randomness
                    "top_p": 0.9,           # Nucleus sampling
                    "repeat_penalty": 1.1,  # Prevent repetition
                    "seed": 42,             # Consistent results
                }

            response = chat(
                model=current_model,
                messages=[
                    {
                        'role': 'user',
                        'content': enhanced_query,
                        'images': imagesPaths,
                        "temperature": float(os.environ['temperature']),  # test if temp makes a diff
                    }
                ],
                options=model_options,
            )

            answer = response.message.content

            # Clean the response to handle raw token issues
            cleaned_answer = self._clean_raw_token_response(answer)

            # If the cleaned answer is still problematic, try fallback models
            if cleaned_answer and "❌ **Model Response Error**" in cleaned_answer:
                print(f"⚠️ Primary model {current_model} failed, trying fallback models...")

                # List of fallback models to try
                fallback_models = [
                    "llama3.2-vision:latest",
                    "llava:latest",
                    "bakllava:latest",
                    "llama3.2:latest"
                ]

                for fallback_model in fallback_models:
                    try:
                        print(f"🔄 Trying fallback model: {fallback_model}")
                        response = chat(
                            model=fallback_model,
                            messages=[
                                {
                                    'role': 'user',
                                    'content': enhanced_query,
                                    'images': imagesPaths,
                                    "temperature": float(os.environ['temperature']),
                                }
                            ],
                            options={
                                "num_predict": 2048,
                                "stop": ["<eos>", "<|endoftext|>", "</s>"],
                                "top_k": 40,
                                "top_p": 0.9,
                                "repeat_penalty": 1.1,
                                "seed": 42,
                            }
                        )

                        fallback_answer = response.message.content
                        cleaned_fallback = self._clean_raw_token_response(fallback_answer)

                        if cleaned_fallback and "❌ **Model Response Error**" not in cleaned_fallback:
                            print(f"✅ Fallback model {fallback_model} succeeded")
                            return cleaned_fallback

                    except Exception as fallback_error:
                        print(f"❌ Fallback model {fallback_model} failed: {fallback_error}")
                        continue

                # If all fallbacks fail, return the original error
                return cleaned_answer

            print(f"Original response: {answer}")
            print(f"Cleaned response: {cleaned_answer}")

            return cleaned_answer

        except Exception as e:
            print(f"An error occurred while querying the Ollama model: {e}")
            return None

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        image_payload = []

        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # reduce token count to reduce processing time
        }

        return payload
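
# __get_openai_api_payload builds an OpenAI-style chat payload but is not wired up
# anywhere in this file. Below is a minimal sketch of how it could be sent with
# `requests` to an OpenAI-compatible endpoint; the helper name, base URL, and its
# use here are assumptions, not part of the original flow.
def _post_openai_payload_sketch(payload, base_url="http://localhost:11434/v1"):
    """Hypothetical helper: POST an OpenAI-style payload and return the answer text."""
    resp = requests.post(
        f"{base_url}/chat/completions",  # assumed OpenAI-compatible route
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=120,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]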

# if __name__ == "__main__":
#     rag = Rag()
#     query = "Based on attached images, how many new cases were reported during second wave peak"
#     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
#     rag.get_answer_from_gemini(query, imagesPaths)
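
# Minimal usage sketch for the Ollama-backed path. The env values below are
# assumptions mirroring what get_answer_from_openai reads ('ollama', 'flashattn',
# 'temperature'); they are not defaults defined by this project.
if __name__ == "__main__":
    os.environ.setdefault("ollama", "gemma3")       # assumed model selector value
    os.environ.setdefault("flashattn", "1")         # enable Ollama flash attention
    os.environ.setdefault("temperature", "0.7")     # assumed sampling temperature

    rag = Rag()
    answer = rag.get_answer_from_openai(
        "Based on attached images, how many new cases were reported during second wave peak",
        ["covid_slides_page_8.png"],
    )
    print(answer)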