Update app.py
app.py
CHANGED
@@ -74,8 +74,8 @@ class DocumentRAG:
     def setup_fallback_model(self):
         """Fallback to smaller model if Mistral fails"""
         try:
+            # Use a model that's better for factual Q&A and less prone to hallucination
+            model_name = "microsoft/DialoGPT-small"
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.model = AutoModelForCausalLM.from_pretrained(model_name)
 
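For reference, a minimal standalone sketch (not part of this commit) of loading the fallback model named above. The pad-token assignment is an extra step assumed here, since GPT-2-based tokenizers such as DialoGPT do not define one and the tokenizer call later in this file passes padding=True:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the padding token
model = AutoModelForCausalLM.from_pretrained(model_name)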
@@ -86,8 +86,38 @@ class DocumentRAG:
             print("✅ Fallback model loaded")
         except Exception as e:
             print(f"❌ Fallback model failed: {e}")
+            # Try an even simpler approach - return context-based answers without generation
             self.model = None
             self.tokenizer = None
+            print("⚠️ Using context-only mode (no text generation)")
+
+    def simple_context_answer(self, query: str, context: str) -> str:
+        """Simple context-based answering when model is not available"""
+        if not context:
+            return "No relevant information found in the documents."
+
+        # Simple keyword matching approach
+        query_words = set(query.lower().split())
+        context_sentences = context.split('.')
+
+        # Find sentences that contain query keywords
+        relevant_sentences = []
+        for sentence in context_sentences:
+            sentence = sentence.strip()
+            if len(sentence) < 10:  # Skip very short sentences
+                continue
+
+            sentence_words = set(sentence.lower().split())
+            # Check if sentence contains at least 2 query words or important keywords
+            common_words = query_words.intersection(sentence_words)
+            if len(common_words) >= 2 or any(word in sentence.lower() for word in ['name', 'education', 'experience', 'skill', 'project']):
+                relevant_sentences.append(sentence)
+
+        if relevant_sentences:
+            # Return the most relevant sentences
+            return '. '.join(relevant_sentences[:3]) + '.'
+        else:
+            return "The information needed to answer this question is not available in the provided documents."
 
     def extract_text_from_file(self, file_path: str) -> str:
         """Extract text from various file formats"""
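The keyword fallback introduced above can be exercised on its own. A short sketch (not part of the commit, made-up resume text) that mirrors the sentence-selection logic of simple_context_answer:

def keyword_answer(query: str, context: str) -> str:
    # Keep sentences sharing at least two words with the query, or containing
    # one of the resume-style keywords, as simple_context_answer does.
    query_words = set(query.lower().split())
    keywords = ['name', 'education', 'experience', 'skill', 'project']
    picked = []
    for sentence in context.split('.'):
        sentence = sentence.strip()
        if len(sentence) < 10:
            continue
        overlap = query_words & set(sentence.lower().split())
        if len(overlap) >= 2 or any(k in sentence.lower() for k in keywords):
            picked.append(sentence)
    if picked:
        return '. '.join(picked[:3]) + '.'
    return "The information needed to answer this question is not available in the provided documents."

sample = "Jane Doe studied physics at MIT. She has five years of experience in data engineering. She enjoys hiking."
print(keyword_answer("What experience does Jane have", sample))
# -> "She has five years of experience in data engineering."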
@@ -251,7 +281,7 @@ class DocumentRAG:
             return ""
 
     def generate_answer(self, query: str, context: str) -> str:
+        """Generate answer using the LLM with anti-hallucination techniques"""
         if self.model is None or self.tokenizer is None:
             return "❌ Model not available. Please try again."
 
@@ -261,28 +291,37 @@ class DocumentRAG:
             is_mistral = 'mistral' in model_name
 
             if is_mistral:
+                # Anti-hallucination prompt for Mistral
+                prompt = f"""<s>[INST] You are a document analysis assistant. You must ONLY answer based on the provided context. Do NOT use any external knowledge.
+
+STRICT RULES:
+1. Answer ONLY using information from the context below
+2. If the answer is not in the context, respond: "The information needed to answer this question is not available in the provided documents."
+3. Do NOT make assumptions or add information not in the context
+4. Quote relevant parts from the context when possible
 
+CONTEXT:
+{context[:1200]}
 
+QUESTION: {query}
 
+Remember: Use ONLY the context above. No external knowledge allowed. [/INST]"""
             else:
+                # Anti-hallucination prompt for fallback models
+                prompt = f"""INSTRUCTIONS: Answer the question using ONLY the information provided in the context. Do not use external knowledge.
 
+CONTEXT:
+{context[:800]}
 
+QUESTION: {query}
+
+ANSWER (using only the context above):"""
 
             # Tokenize with proper handling
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
+                max_length=600,  # Further reduced to prevent truncation issues
                 truncation=True,
                 padding=True
             )
 
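The tokenizer call above caps the encoded prompt at 600 tokens with truncation enabled, so anything past that limit is dropped before generation. A small sketch (not part of the commit, toy strings only) that makes the effect visible:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
long_prompt = "CONTEXT: " + "lorem ipsum " * 400 + "\nQUESTION: Where did Jane study?"

enc = tok(long_prompt, return_tensors="pt", max_length=600, truncation=True)
print(enc["input_ids"].shape)                # at most 600 tokens survive truncation
print(tok.decode(enc["input_ids"][0][-15:])) # tail of what was kept after the cut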
@@ -291,15 +330,17 @@ Answer based on the context:"""
             if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
                 inputs = {k: v.cuda() for k, v in inputs.items()}
 
+            # Generate with anti-hallucination parameters
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
+                    max_new_tokens=100,  # Shorter responses to reduce hallucination
+                    temperature=0.1,  # Very low temperature for factual responses
+                    do_sample=False,  # Use greedy decoding for consistency
+                    num_beams=3,  # Beam search for better quality
+                    early_stopping=True,
+                    repetition_penalty=1.2,
+                    no_repeat_ngram_size=3,
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id
                 )
 
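A standalone sketch (not part of the commit) that runs these generation settings on the small fallback model and then strips the echoed prompt, as the non-Mistral branch below does. With do_sample=False the decoder is deterministic (beam search here), so the temperature value is effectively unused; the pad token is an assumption, since DialoGPT does not define one:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

prompt = ("CONTEXT: Jane Doe studied physics at MIT.\n"
          "QUESTION: Where did Jane study?\n"
          "ANSWER (using only the context above):")
inputs = tok(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False,               # deterministic; temperature only applies when sampling
        num_beams=3,
        early_stopping=True,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        pad_token_id=tok.eos_token_id,  # assumption: reuse EOS since DialoGPT has no pad token
        eos_token_id=tok.eos_token_id,
    )

full_response = tok.decode(outputs[0], skip_special_tokens=True)
print(full_response[len(prompt):].strip())  # strip the echoed prompt, as the fallback branch does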
@@ -312,18 +353,64 @@ Answer based on the context:"""
                 answer = full_response.split("[/INST]")[-1].strip()
             else:
                 # For other models, remove the prompt
+                if "ANSWER (using only the context above):" in full_response:
+                    answer = full_response.split("ANSWER (using only the context above):")[-1].strip()
+                else:
+                    answer = full_response[len(prompt):].strip()
 
+            # Post-process to remove hallucinations
+            answer = self.post_process_answer(answer, context, query)
 
             return answer if answer else "I couldn't generate a proper response based on the context."
 
         except Exception as e:
             return f"❌ Error generating answer: {str(e)}"
 
+    def post_process_answer(self, answer: str, context: str, query: str) -> str:
+        """Post-process answer to reduce hallucinations"""
+        if not answer or len(answer) < 5:
+            return "The information needed to answer this question is not available in the provided documents."
+
+        # Remove common hallucination patterns
+        hallucination_patterns = [
+            "what are you doing",
+            "what do you think",
+            "in my opinion",
+            "i believe",
+            "personally",
+            "from my experience",
+            "generally speaking",
+            "it is known that",
+            "everyone knows"
+        ]
+
+        answer_lower = answer.lower()
+        for pattern in hallucination_patterns:
+            if pattern in answer_lower:
+                return "The information needed to answer this question is not available in the provided documents."
+
+        # Check if answer contains information that's not in context
+        # Simple check: if answer is much longer than query and doesn't reference context
+        if len(answer) > len(query) * 3 and not any(word in answer.lower() for word in context.lower().split()[:20]):
+            return "The information needed to answer this question is not available in the provided documents."
+
+        # Clean up the answer
+        answer = answer.strip()
+
+        # Remove repetitive parts
+        sentences = answer.split('.')
+        unique_sentences = []
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if sentence and sentence not in unique_sentences:
+                unique_sentences.append(sentence)
+
+        cleaned_answer = '. '.join(unique_sentences)
+
+        return cleaned_answer if cleaned_answer else "The information needed to answer this question is not available in the provided documents."
+
     def answer_question(self, query: str) -> str:
-        """Main function to answer questions"""
+        """Main function to answer questions with anti-hallucination measures"""
         if not query.strip():
             return "❌ Please ask a question!"
 
@@ -337,17 +424,72 @@ Answer based on the context:"""
             if not context:
                 return "📄 No relevant information found in the uploaded documents for your question."
 
+            # If no model available, use simple context-based answering
+            if self.model is None:
+                answer = self.simple_context_answer(query, context)
+                return f"💡 **Answer:** {answer}\n\n📄 **Source:** {context[:300]}..."
+
+            # Generate answer using the model
             answer = self.generate_answer(query, context)
 
+            # Additional validation to prevent hallucinations
             if answer and not answer.startswith("❌"):
+                # Check if answer seems to be hallucinated
+                if self.is_likely_hallucination(answer, context):
+                    answer = "The information needed to answer this question is not available in the provided documents."
+
                 return f"💡 **Answer:** {answer}\n\n📄 **Relevant Context:**\n{context[:400]}..."
             else:
                 return answer
 
         except Exception as e:
             return f"❌ Error answering question: {str(e)}"
+
+    def is_likely_hallucination(self, answer: str, context: str) -> bool:
+        """Check if the answer is likely a hallucination"""
+        # Convert to lowercase for comparison
+        answer_lower = answer.lower()
+        context_lower = context.lower()
+
+        # Check for obvious hallucination patterns
+        hallucination_indicators = [
+            "what are you doing",
+            "what do you think",
+            "how are you",
+            "i think",
+            "in my opinion",
+            "from my experience",
+            "generally speaking",
+            "it is well known",
+            "everyone knows",
+            "obviously",
+            "clearly",
+            "of course"
+        ]
+
+        for indicator in hallucination_indicators:
+            if indicator in answer_lower:
+                return True
+
+        # Check if answer contains words that are not in context
+        answer_words = set(answer_lower.split())
+        context_words = set(context_lower.split())
+
+        # Remove common words
+        common_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'will', 'would', 'could', 'should', 'this', 'that', 'these', 'those'}
+
+        answer_content_words = answer_words - common_words
+        context_content_words = context_words - common_words
+
+        # If more than 70% of content words in answer are not in context, likely hallucination
+        if len(answer_content_words) > 0:
+            overlap = len(answer_content_words.intersection(context_content_words))
+            overlap_ratio = overlap / len(answer_content_words)
+
+            if overlap_ratio < 0.3:  # Less than 30% overlap
+                return True
+
+        return False
 
 # Initialize the RAG system
 print("Initializing Document RAG System...")
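The word-overlap heuristic in is_likely_hallucination can be illustrated in isolation. A small sketch (not part of the commit) with made-up answer/context pairs and a trimmed stop-word list:

STOPWORDS = {'the', 'a', 'an', 'and', 'or', 'in', 'on', 'at', 'to', 'of', 'is', 'was'}

def overlap_ratio(answer: str, context: str) -> float:
    # Share of non-stop-word answer words that also appear in the context.
    answer_words = set(answer.lower().split()) - STOPWORDS
    context_words = set(context.lower().split()) - STOPWORDS
    if not answer_words:
        return 1.0  # nothing to check
    return len(answer_words & context_words) / len(answer_words)

context = "Jane Doe studied physics at MIT and later worked as a data engineer."
grounded = "Jane Doe studied physics at MIT"
ungrounded = "She is a famous violinist who performs in Vienna"

print(overlap_ratio(grounded, context))    # 1.0, well above the 0.3 threshold, so kept
print(overlap_ratio(ungrounded, context))  # 0.0, below 0.3, so flagged as likely hallucination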