Kazel committed on
Commit aada01f · verified · 1 Parent(s): c70cfb9

Upload 5 files

Files changed (4)
  1. app.py +0 -0
  2. middleware.py +4 -4
  3. milvus_manager.py +9 -6
  4. rag.py +179 -36
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
middleware.py CHANGED
@@ -46,16 +46,16 @@ class Middleware:



-    def search(self, search_queries: list[str]):
-        print(f"Searching for {len(search_queries)} queries")
+    def search(self, search_queries: list[str], topk: int = 10):
+        print(f"Searching for {len(search_queries)} queries with topk={topk}")

        final_res = []

        for query in search_queries:
            print(f"Searching for query: {query}")
            query_vec = colpali_manager.process_text([query])[0]
-            search_res = self.milvus_manager.search(query_vec, topk=1)
-            print(f"Search result: {search_res} for query: {query}")
+            search_res = self.milvus_manager.search(query_vec, topk=topk)
+            print(f"Search result: {len(search_res)} results for query: {query}")
            final_res.append(search_res)

        return final_res
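
The search() signature now threads a caller-supplied topk through to MilvusManager.search instead of the previous hard-coded topk=1. A minimal usage sketch, assuming an already-constructed Middleware instance named middleware (the instance name and query list are illustrative, not part of the commit):

# Hypothetical caller; Middleware construction and document indexing are not shown in this diff.
queries = ["Based on attached images, how many new cases were reported during second wave peak"]
results = middleware.search(queries, topk=5)  # one result list per query, each holding up to topk Milvus hits

for query, hits in zip(queries, results):
    print(f"{query}: {len(hits)} hits")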
milvus_manager.py CHANGED
@@ -13,7 +13,7 @@ class MilvusManager:
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)

-        self.client = MilvusClient(uri=milvus_uri)
+        self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
        self.collection_name = collection_name
        self.dim = dim

@@ -50,10 +50,13 @@ class MilvusManager:

        index_params.add_index(
            field_name="vector",
-            metric_type="COSINE",
-            index_type="IVF_FLAT",
            index_name="vector_index",
-            params={ "nlist": 128 }
+            index_type="HNSW", #use HNSW option if got more mem, if not use IVF for faster processing
+            metric_type=os.environ["metrictype"], #"IP"
+            params={
+                "M": int(os.environ["mnum"]), #M:16 for HNSW, capital M
+                "efConstruction": int(os.environ["efnum"]), #500 for HNSW
+            },
        )

        self.client.create_index(
@@ -65,7 +68,7 @@ class MilvusManager:
        collections = self.client.list_collections()

        # Set search parameters (here, using Inner Product metric).
-        search_params = {"metric_type": "COSINE", "params": {}} #default metric type is "IP"
+        search_params = {"metric_type": os.environ["metrictype"], "params": {}} #default metric type is "IP"

        # Set to store unique (doc_id, collection_name) pairs across all collections.
        doc_collection_pairs = set()
@@ -121,7 +124,7 @@ class MilvusManager:
        # Unload the collection after search to free memory.
        self.client.release_collection(collection_name=collection)

-        return scores[:topk] if len(scores) >= topk else scores
+        return scores[:topk] if len(scores) >= topk else scores #topk is the number of scores to return back
        """
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
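
The index definition now pulls its metric and HNSW parameters from environment variables instead of the previous hard-coded COSINE/IVF_FLAT settings, so those variables must exist (for example in the .env file loaded in __init__) before the index is created. A sketch of the expected keys, using example values taken from the inline comments; the actual values are deployment-specific assumptions:

import os

# Example values mirroring the comments in the diff ("IP", M: 16, efConstruction: 500).
os.environ.setdefault("metrictype", "IP")  # metric used for both index creation and search
os.environ.setdefault("mnum", "16")        # HNSW graph degree M (capital M in the index params)
os.environ.setdefault("efnum", "500")      # HNSW efConstruction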
rag.py CHANGED
@@ -1,27 +1,77 @@
import requests
import os
+import re

from typing import List
from utils import encode_image
from PIL import Image
+from ollama import chat
import torch
import subprocess
import psutil
import torch
from transformers import AutoModel, AutoTokenizer
-import google.generativeai as genai
-
+from google import genai


class Rag:

+    def _clean_raw_token_response(self, response_text):
+        """
+        Clean raw token responses that contain undecoded token IDs
+        This handles cases where models return raw tokens instead of decoded text
+        """
+        if not response_text:
+            return response_text
+
+        # Check if response contains raw token patterns
+        token_patterns = [
+            r'<unused\d+>',     # unused tokens
+            r'<bos>',           # beginning of sequence
+            r'<eos>',           # end of sequence
+            r'<unk>',           # unknown tokens
+            r'<mask>',          # mask tokens
+            r'<pad>',           # padding tokens
+            r'\[multimodal\]',  # multimodal tokens
+        ]
+
+        # If response contains raw tokens, try to clean them
+        has_raw_tokens = any(re.search(pattern, response_text) for pattern in token_patterns)
+
+        if has_raw_tokens:
+            print("⚠️ Detected raw token response, attempting to clean...")
+
+            # Remove common raw token patterns
+            cleaned_text = response_text
+
+            # Remove unused tokens
+            cleaned_text = re.sub(r'<unused\d+>', '', cleaned_text)
+
+            # Remove special tokens
+            cleaned_text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', cleaned_text)
+
+            # Remove multimodal tokens
+            cleaned_text = re.sub(r'\[multimodal\]', '', cleaned_text)
+
+            # Clean up extra whitespace
+            cleaned_text = re.sub(r'\s+', ' ', cleaned_text).strip()
+
+            # If we still have mostly tokens, return an error message
+            if len(cleaned_text.strip()) < 10:
+                return "❌ **Model Response Error**: The model returned raw token IDs instead of decoded text. This may be due to model configuration issues. Please try:\n\n1. Restarting the Ollama server\n2. Using a different model\n3. Checking model compatibility with multimodal inputs"
+
+            return cleaned_text
+
+        return response_text
+
    def get_answer_from_gemini(self, query, imagePaths):
+

        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
-            genai.configure(api_key="AIzaSyBF-MJKxRROIr-X6YiG1_8uOHrFZDX3IBI")
-            model = genai.GenerativeModel('gemini-2.5-flash')
+            genai.configure(api_key='AIzaSyCwRr9054tCuh2S8yGpwKFvOAxYMT4WNIs')
+            model = genai.GenerativeModel('gemini-2.0-flash')

            images = [Image.open(path) for path in imagePaths]

@@ -45,35 +95,10 @@ class Rag:
        #import environ variables from .env
        import dotenv

-        # Load the .env file
+        # Load the .env file
        dotenv_file = dotenv.find_dotenv()
        dotenv.load_dotenv(dotenv_file)
-        """ #scuffed local hf inference (transformers incompatible to colpali version req, use ollama, more reliable, easier to use plus web server ready)
-        print(f"Querying for query={query}, imagesPaths={imagesPaths}")
-
-        model = AutoModel.from_pretrained(
-            'openbmb/MiniCPM-o-2_6-int4',
-            trust_remote_code=True,
-            attn_implementation='flash_attention_2', # sdpa or flash_attention_2
-            torch_dtype=torch.bfloat16,
-            init_vision=True,
-        )
-
-
-        model = model.eval().cuda()
-        tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2_6-int4', trust_remote_code=True)
-        image = Image.open(imagesPaths[0]).convert('RGB')

-        msgs = [{'role': 'user', 'content': [image, query]}]
-        answer = model.chat(
-            image=None,
-            msgs=msgs,
-            tokenizer=tokenizer
-        )
-        print(answer)
-        return answer
-        """
-
        #ollama method below

        torch.cuda.empty_cache() #release cuda so that ollama can use gpu!
@@ -82,31 +107,149 @@ class Rag:
        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn'] #int "1"
        if os.environ['ollama'] == "minicpm-v":
            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0" #set to quantized version
+        elif os.environ['ollama'] == "gemma3":
+            os.environ['ollama'] = "gemma3:12b" #set to upscaled version
+            # Add specific environment variables for Gemma3 to prevent raw token issues
+            os.environ['OLLAMA_KEEP_ALIVE'] = "5m"
+            os.environ['OLLAMA_ORIGINS'] = "*"


        # Close model thread (colpali)
        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")
-        from ollama import chat

        try:

+            # Enhanced prompt for more detailed responses with explicit page usage
+            enhanced_query = f"""
+            Please provide a comprehensive and detailed answer to the following query.
+            Use ALL available information from the provided document images to give a thorough response.
+
+            Query: {query}
+
+            CRITICAL INSTRUCTIONS:
+            - You have been provided with {len(imagesPaths)} document page(s)
+            - You MUST reference information from ALL {len(imagesPaths)} page(s) in your response
+            - Do not skip any pages - each page contains relevant information
+            - If you mention one page, you must also mention the others
+            - Ensure your response reflects the complete information from all pages
+
+            Instructions for detailed response:
+            1. Provide extensive background information and context
+            2. Include specific details, examples, and data points from ALL documents
+            3. Explain concepts thoroughly with step-by-step breakdowns
+            4. Provide comprehensive analysis rather than simple answers when requested
+            5. Explicitly reference each page and what information it contributes
+            6. Cross-reference information between pages when relevant
+            7. Ensure no page is left unmentioned in your analysis
+
+            SPECIAL INSTRUCTIONS FOR TABULAR DATA:
+            - If the query requests a table, list, or structured data, organize your response in a clear, structured format
+            - Use numbered lists, bullet points, or clear categories when appropriate
+            - Include specific data points or comparisons when available
+            - Structure information in a way that can be easily converted to a table format
+
+            IMPORTANT: Respond with natural, human-readable text only. Do not include any special tokens, codes, or technical identifiers in your response.
+
+            Make sure to acknowledge and use information from all {len(imagesPaths)} provided pages.
+            """
+
+            # Try with current model first
+            current_model = os.environ['ollama']
+
+            # Set different options based on the model
+            if "gemma3" in current_model.lower():
+                # Specific options for Gemma3 to prevent raw token issues
+                model_options = {
+                    "num_predict": 1024,  # Shorter responses for Gemma3
+                    "stop": ["<eos>", "<|endoftext|>", "</s>", "<|im_end|>"],  # More stop tokens
+                    "top_k": 20,  # Lower top_k for more focused generation
+                    "top_p": 0.8,  # Lower top_p for more deterministic output
+                    "repeat_penalty": 1.2,  # Higher repeat penalty
+                    "seed": 42,  # Consistent results
+                    "temperature": 0.7,  # Lower temperature for more focused responses
+                }
+            else:
+                # Default options for other models
+                model_options = {
+                    "num_predict": 2048,  # Limit response length
+                    "stop": ["<eos>", "<|endoftext|>", "</s>"],  # Stop at end tokens
+                    "top_k": 40,  # Reduce randomness
+                    "top_p": 0.9,  # Nucleus sampling
+                    "repeat_penalty": 1.1,  # Prevent repetition
+                    "seed": 42,  # Consistent results
+                }
+
            response = chat(
-                model=os.environ['ollama'],
+                model=current_model,
                messages=[
                    {
                        'role': 'user',
-                        'content': query,
+                        'content': enhanced_query,
                        'images': imagesPaths,
                        "temperature":float(os.environ['temperature']), #test if temp makes a diff
                    }
                ],
+                options=model_options
            )

            answer = response.message.content
+
+            # Clean the response to handle raw token issues
+            cleaned_answer = self._clean_raw_token_response(answer)
+
+            # If the cleaned answer is still problematic, try fallback models
+            if cleaned_answer and "❌ **Model Response Error**" in cleaned_answer:
+                print(f"⚠️ Primary model {current_model} failed, trying fallback models...")
+
+                # List of fallback models to try
+                fallback_models = [
+                    "llama3.2-vision:latest",
+                    "llava:latest",
+                    "bakllava:latest",
+                    "llama3.2:latest"
+                ]
+
+                for fallback_model in fallback_models:
+                    try:
+                        print(f"🔄 Trying fallback model: {fallback_model}")
+                        response = chat(
+                            model=fallback_model,
+                            messages=[
+                                {
+                                    'role': 'user',
+                                    'content': enhanced_query,
+                                    'images': imagesPaths,
+                                    "temperature":float(os.environ['temperature']),
+                                }
+                            ],
+                            options={
+                                "num_predict": 2048,
+                                "stop": ["<eos>", "<|endoftext|>", "</s>"],
+                                "top_k": 40,
+                                "top_p": 0.9,
+                                "repeat_penalty": 1.1,
+                                "seed": 42,
+                            }
+                        )
+
+                        fallback_answer = response.message.content
+                        cleaned_fallback = self._clean_raw_token_response(fallback_answer)
+
+                        if cleaned_fallback and "❌ **Model Response Error**" not in cleaned_fallback:
+                            print(f"✅ Fallback model {fallback_model} succeeded")
+                            return cleaned_fallback
+
+                    except Exception as fallback_error:
+                        print(f"❌ Fallback model {fallback_model} failed: {fallback_error}")
+                        continue
+
+                # If all fallbacks fail, return the original error
+                return cleaned_answer

-            print(answer)
+            print(f"Original response: {answer}")
+            print(f"Cleaned response: {cleaned_answer}")

-            return answer
+            return cleaned_answer

        except Exception as e:
            print(f"An error occurred while querying OpenAI: {e}")
@@ -153,4 +296,4 @@ class Rag:
# query = "Based on attached images, how many new cases were reported during second wave peak"
# imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]

-# rag.get_answer_from_gemini(query, imagesPaths)
+# rag.get_answer_from_gemini(query, imagesPaths)
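
The new _clean_raw_token_response helper strips undecoded special tokens (<unused…>, <bos>, <eos>, <unk>, <mask>, <pad>, [multimodal]) from model output and falls back to an error message when almost nothing readable remains. A small standalone check of that regex logic, pulled out into a plain function for illustration (the real code is a method on Rag):

import re

def clean(text):
    # Same substitutions as Rag._clean_raw_token_response, minus the length-based fallback.
    text = re.sub(r'<unused\d+>', '', text)
    text = re.sub(r'<(bos|eos|unk|mask|pad)>', '', text)
    text = re.sub(r'\[multimodal\]', '', text)
    return re.sub(r'\s+', ' ', text).strip()

print(clean("<bos>Cases peaked during the second wave.<eos><pad>"))
# -> "Cases peaked during the second wave."

Separately, the import switches to "from google import genai" (the newer google-genai SDK) while the Gemini path still calls genai.configure and genai.GenerativeModel, which are entry points of the older google-generativeai package. If those calls are intended, the matching import would look like the sketch below; this is an observation about the two SDKs, not part of the commit:

# Classic google-generativeai package; exposes configure() and GenerativeModel().
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key for illustration only
model = genai.GenerativeModel("gemini-2.0-flash")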