pradeepsengarr committed on
Commit 8b78b3b · verified · 1 Parent(s): 611ac83

Update app.py

Files changed (1):
  1. app.py +76 -141
app.py CHANGED
@@ -92,11 +92,11 @@ class DocumentRAG:
         print("⚠️ Using context-only mode (no text generation)")
 
     def simple_context_answer(self, query: str, context: str) -> str:
-        """Simple context-based answering when model is not available"""
+        """Improved context-based answering when model is not available"""
         if not context:
             return "No relevant information found in the documents."
 
-        # Simple keyword matching approach
+        # Improved keyword matching approach
         query_words = set(query.lower().split())
         context_sentences = context.split('.')
 
@@ -108,16 +108,20 @@ class DocumentRAG:
                 continue
 
             sentence_words = set(sentence.lower().split())
-            # Check if sentence contains at least 2 query words or important keywords
+            # Check if sentence contains query keywords
            common_words = query_words.intersection(sentence_words)
-            if len(common_words) >= 2 or any(word in sentence.lower() for word in ['name', 'education', 'experience', 'skill', 'project']):
+            if len(common_words) >= 1:  # Lowered threshold
                relevant_sentences.append(sentence)
 
         if relevant_sentences:
             # Return the most relevant sentences
             return '. '.join(relevant_sentences[:3]) + '.'
         else:
-            return "The information needed to answer this question is not available in the provided documents."
+            # If no exact matches, return first few sentences of context
+            first_sentences = context_sentences[:2]
+            if first_sentences:
+                return '. '.join([s.strip() for s in first_sentences if s.strip()]) + '.'
+            return "Based on the document content, I found some information but cannot provide a specific answer to your question."
 
     def extract_text_from_file(self, file_path: str) -> str:
         """Extract text from various file formats"""
@@ -171,7 +175,7 @@ class DocumentRAG:
         except Exception as e2:
             return f"Error reading TXT: {str(e2)}"
 
-    def chunk_text(self, text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
+    def chunk_text(self, text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
         """Split text into overlapping chunks with better sentence preservation"""
         if not text.strip():
             return []
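The commit only shrinks the chunking defaults (300 → 200, 50 → 30); the body of `chunk_text` is outside the diff. For orientation, a minimal sketch of what word-based overlapping chunking typically looks like, assuming the parameters count words as the defaults suggest:

    from typing import List

    def chunk_words(text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
        """Sketch: consecutive windows share `overlap` words, so text spanning a boundary appears in both chunks."""
        words = text.split()
        step = chunk_size - overlap  # each window starts `step` words after the previous one
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), step)]

Smaller chunks raise the chance that a single chunk is tightly on-topic, at the cost of splitting related sentences across more chunks.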
@@ -256,7 +260,7 @@ class DocumentRAG:
             return f"❌ Error processing documents: {str(e)}"
 
     def retrieve_context(self, query: str, k: int = 5) -> str:
-        """Retrieve relevant context for the query"""
+        """Retrieve relevant context for the query with improved retrieval"""
         if not self.is_indexed:
             return ""
 
@@ -268,12 +272,20 @@ class DocumentRAG:
             # Search for similar chunks
             scores, indices = self.index.search(query_embedding.astype('float32'), k)
 
-            # Get relevant documents with higher threshold
+            # Get relevant documents with MUCH LOWER threshold
             relevant_docs = []
             for i, idx in enumerate(indices[0]):
-                if idx < len(self.documents) and scores[0][i] > 0.2:  # Higher similarity threshold
+                if idx < len(self.documents) and scores[0][i] > 0.05:  # Much lower threshold
                     relevant_docs.append(self.documents[idx])
 
+            # If no high-similarity matches, take the top results anyway
+            if not relevant_docs:
+                for i, idx in enumerate(indices[0]):
+                    if idx < len(self.documents):
+                        relevant_docs.append(self.documents[idx])
+                        if len(relevant_docs) >= 3:  # Take at least 3 chunks
+                            break
+
             return "\n\n".join(relevant_docs)
 
         except Exception as e:
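How permissive 0.05 really is depends on how the index was built, which is outside this diff. If the sentence embeddings are L2-normalised and stored in an inner-product index, the scores are cosine similarities in [-1, 1] and 0.05 accepts nearly everything, which matches the commit's recall-over-precision intent. A minimal sketch of that assumed setup, with an illustrative dimension:

    import faiss
    import numpy as np

    dim = 384                                 # e.g. a MiniLM-sized embedding (assumption)
    index = faiss.IndexFlatIP(dim)            # inner-product index
    emb = np.random.rand(100, dim).astype('float32')
    faiss.normalize_L2(emb)                   # after normalisation, inner product == cosine
    index.add(emb)

    q = np.random.rand(1, dim).astype('float32')
    faiss.normalize_L2(q)
    scores, indices = index.search(q, 5)      # scores now fall in [-1, 1]
    keep = [int(i) for s, i in zip(scores[0], indices[0]) if s > 0.05]

With an unnormalised L2 index, by contrast, `search` returns squared distances where smaller means closer, and a `> 0.05` test would not mean "more similar" at all.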
@@ -281,9 +293,9 @@ class DocumentRAG:
             return ""
 
     def generate_answer(self, query: str, context: str) -> str:
-        """Generate answer using the LLM with anti-hallucination techniques"""
+        """Generate answer using the LLM with improved prompting"""
         if self.model is None or self.tokenizer is None:
-            return "❌ Model not available. Please try again."
+            return self.simple_context_answer(query, context)
 
         try:
             # Check if using Mistral (has specific prompt format) or fallback model
@@ -291,37 +303,31 @@ class DocumentRAG:
             is_mistral = 'mistral' in model_name
 
             if is_mistral:
-                # Anti-hallucination prompt for Mistral
-                prompt = f"""<s>[INST] You are a document analysis assistant. You must ONLY answer based on the provided context. Do NOT use any external knowledge.
-
-STRICT RULES:
-1. Answer ONLY using information from the context below
-2. If the answer is not in the context, respond: "The information needed to answer this question is not available in the provided documents."
-3. Do NOT make assumptions or add information not in the context
-4. Quote relevant parts from the context when possible
-
-CONTEXT:
-{context[:1200]}
-
-QUESTION: {query}
-
-Remember: Use ONLY the context above. No external knowledge allowed. [/INST]"""
+                # Improved prompt for Mistral - more flexible
+                prompt = f"""<s>[INST] You are a helpful document assistant. Answer the question based on the provided context. If the exact answer isn't in the context, provide the most relevant information available.
+
+Context:
+{context[:1500]}
+
+Question: {query}
+
+Please provide a helpful answer based on the available information. [/INST]"""
             else:
-                # Anti-hallucination prompt for fallback models
-                prompt = f"""INSTRUCTIONS: Answer the question using ONLY the information provided in the context. Do not use external knowledge.
+                # Improved prompt for fallback models
+                prompt = f"""Based on the following information, please answer the question:
 
-CONTEXT:
-{context[:800]}
+Context:
+{context[:1000]}
 
-QUESTION: {query}
+Question: {query}
 
-ANSWER (using only the context above):"""
+Answer:"""
 
             # Tokenize with proper handling
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
-                max_length=600,  # Further reduced to prevent truncation issues
+                max_length=800,
                 truncation=True,
                 padding=True
             )
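A note on the hand-rolled `<s>[INST] ... [/INST]` wrapper: it must match the checkpoint's training format exactly, and the tokenizer usually inserts `<s>` itself, so hard-coding it can produce a doubled BOS token. On recent transformers versions the wrapper can instead be built from the model's own chat template; a sketch, assuming an instruct checkpoint that ships one (the model id and placeholder strings are illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    context, query = "retrieved chunks go here", "What is the candidate's education?"
    messages = [{"role": "user",
                 "content": f"Answer from the context below.\n\nContext:\n{context}\n\nQuestion: {query}"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a string rather than token ids
        add_generation_prompt=True,  # open the assistant turn for generation
    )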
@@ -330,17 +336,17 @@ ANSWER (using only the context above):"""
             if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
                 inputs = {k: v.cuda() for k, v in inputs.items()}
 
-            # Generate with anti-hallucination parameters
+            # Generate with more flexible parameters
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
-                    max_new_tokens=100,  # Shorter responses to reduce hallucination
-                    temperature=0.1,  # Very low temperature for factual responses
-                    do_sample=False,  # Use greedy decoding for consistency
-                    num_beams=3,  # Beam search for better quality
+                    max_new_tokens=150,
+                    temperature=0.3,  # Slightly higher for more natural responses
+                    do_sample=True,
+                    top_p=0.9,
+                    num_beams=2,
                     early_stopping=True,
-                    repetition_penalty=1.2,
-                    no_repeat_ngram_size=3,
+                    repetition_penalty=1.1,
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id
                 )
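One subtlety in the new parameter set: because `num_beams=2` is kept alongside `do_sample=True`, transformers runs beam-search multinomial sampling rather than plain sampling, and `early_stopping` only has meaning in that beam setting. If ordinary nucleus sampling is the intent, a leaner call would drop the beam arguments; a sketch reusing the method's `inputs`, model, and tokenizer:

    outputs = self.model.generate(
        **inputs,
        max_new_tokens=150,
        do_sample=True,              # plain multinomial sampling
        temperature=0.3,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=self.tokenizer.pad_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
    )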
@@ -353,64 +359,48 @@ ANSWER (using only the context above):"""
                 answer = full_response.split("[/INST]")[-1].strip()
             else:
                 # For other models, remove the prompt
-                if "ANSWER (using only the context above):" in full_response:
-                    answer = full_response.split("ANSWER (using only the context above):")[-1].strip()
+                if "Answer:" in full_response:
+                    answer = full_response.split("Answer:")[-1].strip()
                 else:
                     answer = full_response[len(prompt):].strip()
 
-            # Post-process to remove hallucinations
-            answer = self.post_process_answer(answer, context, query)
+            # Clean up the answer
+            answer = self.clean_answer(answer)
 
-            return answer if answer else "I couldn't generate a proper response based on the context."
+            return answer if answer else self.simple_context_answer(query, context)
 
         except Exception as e:
-            return f"❌ Error generating answer: {str(e)}"
+            print(f"Error in generation: {e}")
+            return self.simple_context_answer(query, context)
 
-    def post_process_answer(self, answer: str, context: str, query: str) -> str:
-        """Post-process answer to reduce hallucinations"""
+    def clean_answer(self, answer: str) -> str:
+        """Clean up the generated answer"""
         if not answer or len(answer) < 5:
-            return "The information needed to answer this question is not available in the provided documents."
+            return ""
 
-        # Remove common hallucination patterns
-        hallucination_patterns = [
-            "what are you doing",
-            "what do you think",
-            "in my opinion",
-            "i believe",
-            "personally",
-            "from my experience",
-            "generally speaking",
-            "it is known that",
-            "everyone knows"
-        ]
+        # Remove obvious problematic patterns
+        lines = answer.split('\n')
+        cleaned_lines = []
 
-        answer_lower = answer.lower()
-        for pattern in hallucination_patterns:
-            if pattern in answer_lower:
-                return "The information needed to answer this question is not available in the provided documents."
+        for line in lines:
+            line = line.strip()
+            if line and not any(pattern in line.lower() for pattern in [
+                'what are you doing', 'what do you think', 'how are you',
+                'i am an ai', 'i cannot', 'i don\'t know'
+            ]):
+                cleaned_lines.append(line)
 
-        # Check if answer contains information that's not in context
-        # Simple check: if answer is much longer than query and doesn't reference context
-        if len(answer) > len(query) * 3 and not any(word in answer.lower() for word in context.lower().split()[:20]):
-            return "The information needed to answer this question is not available in the provided documents."
+        cleaned_answer = ' '.join(cleaned_lines)
 
-        # Clean up the answer
-        answer = answer.strip()
+        # Limit length to prevent rambling
+        if len(cleaned_answer) > 500:
+            sentences = cleaned_answer.split('.')
+            cleaned_answer = '. '.join(sentences[:3]) + '.'
 
-        # Remove repetitive parts
-        sentences = answer.split('.')
-        unique_sentences = []
-        for sentence in sentences:
-            sentence = sentence.strip()
-            if sentence and sentence not in unique_sentences:
-                unique_sentences.append(sentence)
-
-        cleaned_answer = '. '.join(unique_sentences)
-
-        return cleaned_answer if cleaned_answer else "The information needed to answer this question is not available in the provided documents."
+        return cleaned_answer.strip()
 
     def answer_question(self, query: str) -> str:
-        """Main function to answer questions with anti-hallucination measures"""
+        """Main function to answer questions with improved handling"""
         if not query.strip():
             return "❓ Please ask a question!"
 
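A quick illustration of what the new `clean_answer` does, assuming it behaves as written above (calling it unbound avoids constructing the class, whose setup may load models):

    noisy = "The report covers Q3 revenue.\nI am an AI language model.\nRevenue grew 12%."
    print(DocumentRAG.clean_answer(None, noisy))
    # -> "The report covers Q3 revenue. Revenue grew 12%."

The filler line is dropped because it matches the 'i am an ai' pattern; note that surviving lines are joined with spaces, so any paragraph breaks in the generation are flattened.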
@@ -419,77 +409,22 @@ ANSWER (using only the context above):"""
 
         try:
             # Retrieve relevant context
-            context = self.retrieve_context(query)
+            context = self.retrieve_context(query, k=7)  # Get more chunks
 
             if not context:
                 return "🔍 No relevant information found in the uploaded documents for your question."
 
-            # If no model available, use simple context-based answering
-            if self.model is None:
-                answer = self.simple_context_answer(query, context)
-                return f"💡 **Answer:** {answer}\n\n📄 **Source:** {context[:300]}..."
-
-            # Generate answer using the model
+            # Generate answer
             answer = self.generate_answer(query, context)
 
-            # Additional validation to prevent hallucinations
-            if answer and not answer.startswith("❌"):
-                # Check if answer seems to be hallucinated
-                if self.is_likely_hallucination(answer, context):
-                    answer = "The information needed to answer this question is not available in the provided documents."
-
-                return f"💡 **Answer:** {answer}\n\n📄 **Relevant Context:**\n{context[:400]}..."
+            if answer and len(answer) > 10:
+                return f"💡 **Answer:** {answer}\n\n📄 **Source Context:**\n{context[:300]}..."
             else:
-                return answer
+                # Fallback to simple context display
+                return f"📄 **Based on the document content:**\n{context[:500]}..."
 
         except Exception as e:
             return f"❌ Error answering question: {str(e)}"
-
-    def is_likely_hallucination(self, answer: str, context: str) -> bool:
-        """Check if the answer is likely a hallucination"""
-        # Convert to lowercase for comparison
-        answer_lower = answer.lower()
-        context_lower = context.lower()
-
-        # Check for obvious hallucination patterns
-        hallucination_indicators = [
-            "what are you doing",
-            "what do you think",
-            "how are you",
-            "i think",
-            "in my opinion",
-            "from my experience",
-            "generally speaking",
-            "it is well known",
-            "everyone knows",
-            "obviously",
-            "clearly",
-            "of course"
-        ]
-
-        for indicator in hallucination_indicators:
-            if indicator in answer_lower:
-                return True
-
-        # Check if answer contains words that are not in context
-        answer_words = set(answer_lower.split())
-        context_words = set(context_lower.split())
-
-        # Remove common words
-        common_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'will', 'would', 'could', 'should', 'this', 'that', 'these', 'those'}
-
-        answer_content_words = answer_words - common_words
-        context_content_words = context_words - common_words
-
-        # If more than 70% of content words in answer are not in context, likely hallucination
-        if len(answer_content_words) > 0:
-            overlap = len(answer_content_words.intersection(context_content_words))
-            overlap_ratio = overlap / len(answer_content_words)
-
-            if overlap_ratio < 0.3:  # Less than 30% overlap
-                return True
-
-        return False
 
 # Initialize the RAG system
 print("Initializing Document RAG System...")
 