pradeepsengarr committed (verified)
Commit 611ac83 · 1 Parent(s): a8283c8

Update app.py

Files changed (1): app.py (+168 -26)
app.py CHANGED
@@ -74,8 +74,8 @@ class DocumentRAG:
    def setup_fallback_model(self):
        """Fallback to smaller model if Mistral fails"""
        try:
-            # Use a better fallback model for Q&A
-            model_name = "distilgpt2"
+            # Use a model that's better for factual Q&A and less prone to hallucination
+            model_name = "microsoft/DialoGPT-small"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)

@@ -86,8 +86,38 @@ class DocumentRAG:
            print("✅ Fallback model loaded")
        except Exception as e:
            print(f"❌ Fallback model failed: {e}")
+            # Try an even simpler approach - return context-based answers without generation
            self.model = None
            self.tokenizer = None
+            print("⚠️ Using context-only mode (no text generation)")
+
+    def simple_context_answer(self, query: str, context: str) -> str:
+        """Simple context-based answering when model is not available"""
+        if not context:
+            return "No relevant information found in the documents."
+
+        # Simple keyword matching approach
+        query_words = set(query.lower().split())
+        context_sentences = context.split('.')
+
+        # Find sentences that contain query keywords
+        relevant_sentences = []
+        for sentence in context_sentences:
+            sentence = sentence.strip()
+            if len(sentence) < 10:  # Skip very short sentences
+                continue
+
+            sentence_words = set(sentence.lower().split())
+            # Check if sentence contains at least 2 query words or important keywords
+            common_words = query_words.intersection(sentence_words)
+            if len(common_words) >= 2 or any(word in sentence.lower() for word in ['name', 'education', 'experience', 'skill', 'project']):
+                relevant_sentences.append(sentence)
+
+        if relevant_sentences:
+            # Return the most relevant sentences
+            return '. '.join(relevant_sentences[:3]) + '.'
+        else:
+            return "The information needed to answer this question is not available in the provided documents."

    def extract_text_from_file(self, file_path: str) -> str:
        """Extract text from various file formats"""
@@ -251,7 +281,7 @@ class DocumentRAG:
            return ""

    def generate_answer(self, query: str, context: str) -> str:
-        """Generate answer using the LLM with improved prompting"""
+        """Generate answer using the LLM with anti-hallucination techniques"""
        if self.model is None or self.tokenizer is None:
            return "❌ Model not available. Please try again."

@@ -261,28 +291,37 @@
            is_mistral = 'mistral' in model_name

            if is_mistral:
-                # Mistral-specific prompt format
-                prompt = f"""<s>[INST] You are a helpful assistant that answers questions based on the provided context. Use only the information from the context to answer. If the information is not in the context, say "I don't have enough information to answer this question."
+                # Anti-hallucination prompt for Mistral
+                prompt = f"""<s>[INST] You are a document analysis assistant. You must ONLY answer based on the provided context. Do NOT use any external knowledge.
+
+STRICT RULES:
+1. Answer ONLY using information from the context below
+2. If the answer is not in the context, respond: "The information needed to answer this question is not available in the provided documents."
+3. Do NOT make assumptions or add information not in the context
+4. Quote relevant parts from the context when possible

-Context:
-{context[:1500]}
+CONTEXT:
+{context[:1200]}

-Question: {query}
+QUESTION: {query}

-Provide a clear and concise answer based only on the context above. [/INST]"""
+Remember: Use ONLY the context above. No external knowledge allowed. [/INST]"""
            else:
-                # Generic prompt for fallback models
-                prompt = f"""Context: {context[:1000]}
+                # Anti-hallucination prompt for fallback models
+                prompt = f"""INSTRUCTIONS: Answer the question using ONLY the information provided in the context. Do not use external knowledge.

-Question: {query}
+CONTEXT:
+{context[:800]}

-Answer based on the context:"""
+QUESTION: {query}
+
+ANSWER (using only the context above):"""

            # Tokenize with proper handling
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
-                max_length=800,  # Reduced to fit in memory
+                max_length=600,  # Further reduced to prevent truncation issues
                truncation=True,
                padding=True
            )
@@ -291,15 +330,17 @@ Answer based on the context:"""
            if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
                inputs = {k: v.cuda() for k, v in inputs.items()}

-            # Generate with better parameters
+            # Generate with anti-hallucination parameters
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
-                    max_new_tokens=150,  # Reduced for more focused answers
-                    temperature=0.3,  # Lower temperature for more consistent answers
-                    do_sample=True,
-                    top_p=0.8,
-                    repetition_penalty=1.1,
+                    max_new_tokens=100,  # Shorter responses to reduce hallucination
+                    temperature=0.1,  # Very low temperature for factual responses
+                    do_sample=False,  # Use greedy decoding for consistency
+                    num_beams=3,  # Beam search for better quality
+                    early_stopping=True,
+                    repetition_penalty=1.2,
+                    no_repeat_ngram_size=3,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
@@ -312,18 +353,64 @@ Answer based on the context:"""
                answer = full_response.split("[/INST]")[-1].strip()
            else:
                # For other models, remove the prompt
-                answer = full_response[len(prompt):].strip()
+                if "ANSWER (using only the context above):" in full_response:
+                    answer = full_response.split("ANSWER (using only the context above):")[-1].strip()
+                else:
+                    answer = full_response[len(prompt):].strip()

-            # Clean up the answer
-            answer = answer.replace(prompt, "").strip()
+            # Post-process to remove hallucinations
+            answer = self.post_process_answer(answer, context, query)

            return answer if answer else "I couldn't generate a proper response based on the context."

        except Exception as e:
            return f"❌ Error generating answer: {str(e)}"

+    def post_process_answer(self, answer: str, context: str, query: str) -> str:
+        """Post-process answer to reduce hallucinations"""
+        if not answer or len(answer) < 5:
+            return "The information needed to answer this question is not available in the provided documents."
+
+        # Remove common hallucination patterns
+        hallucination_patterns = [
+            "what are you doing",
+            "what do you think",
+            "in my opinion",
+            "i believe",
+            "personally",
+            "from my experience",
+            "generally speaking",
+            "it is known that",
+            "everyone knows"
+        ]
+
+        answer_lower = answer.lower()
+        for pattern in hallucination_patterns:
+            if pattern in answer_lower:
+                return "The information needed to answer this question is not available in the provided documents."
+
+        # Check if answer contains information that's not in context
+        # Simple check: if answer is much longer than query and doesn't reference context
+        if len(answer) > len(query) * 3 and not any(word in answer.lower() for word in context.lower().split()[:20]):
+            return "The information needed to answer this question is not available in the provided documents."
+
+        # Clean up the answer
+        answer = answer.strip()
+
+        # Remove repetitive parts
+        sentences = answer.split('.')
+        unique_sentences = []
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if sentence and sentence not in unique_sentences:
+                unique_sentences.append(sentence)
+
+        cleaned_answer = '. '.join(unique_sentences)
+
+        return cleaned_answer if cleaned_answer else "The information needed to answer this question is not available in the provided documents."
+
    def answer_question(self, query: str) -> str:
-        """Main function to answer questions"""
+        """Main function to answer questions with anti-hallucination measures"""
        if not query.strip():
            return "❓ Please ask a question!"

@@ -337,17 +424,72 @@ Answer based on the context:"""
            if not context:
                return "🔍 No relevant information found in the uploaded documents for your question."

-            # Generate answer
+            # If no model available, use simple context-based answering
+            if self.model is None:
+                answer = self.simple_context_answer(query, context)
+                return f"💡 **Answer:** {answer}\n\n📄 **Source:** {context[:300]}..."
+
+            # Generate answer using the model
            answer = self.generate_answer(query, context)

-            # Format the response
+            # Additional validation to prevent hallucinations
            if answer and not answer.startswith("❌"):
+                # Check if answer seems to be hallucinated
+                if self.is_likely_hallucination(answer, context):
+                    answer = "The information needed to answer this question is not available in the provided documents."
+
                return f"💡 **Answer:** {answer}\n\n📄 **Relevant Context:**\n{context[:400]}..."
            else:
                return answer

        except Exception as e:
            return f"❌ Error answering question: {str(e)}"
+
+    def is_likely_hallucination(self, answer: str, context: str) -> bool:
+        """Check if the answer is likely a hallucination"""
+        # Convert to lowercase for comparison
+        answer_lower = answer.lower()
+        context_lower = context.lower()
+
+        # Check for obvious hallucination patterns
+        hallucination_indicators = [
+            "what are you doing",
+            "what do you think",
+            "how are you",
+            "i think",
+            "in my opinion",
+            "from my experience",
+            "generally speaking",
+            "it is well known",
+            "everyone knows",
+            "obviously",
+            "clearly",
+            "of course"
+        ]
+
+        for indicator in hallucination_indicators:
+            if indicator in answer_lower:
+                return True
+
+        # Check if answer contains words that are not in context
+        answer_words = set(answer_lower.split())
+        context_words = set(context_lower.split())
+
+        # Remove common words
+        common_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'will', 'would', 'could', 'should', 'this', 'that', 'these', 'those'}
+
+        answer_content_words = answer_words - common_words
+        context_content_words = context_words - common_words
+
+        # If more than 70% of content words in answer are not in context, likely hallucination
+        if len(answer_content_words) > 0:
+            overlap = len(answer_content_words.intersection(context_content_words))
+            overlap_ratio = overlap / len(answer_content_words)
+
+            if overlap_ratio < 0.3:  # Less than 30% overlap
+                return True
+
+        return False

# Initialize the RAG system
print("Initializing Document RAG System...")
 