Update app.py
app.py
CHANGED
@@ -92,11 +92,11 @@ class DocumentRAG:
             print("⚠️ Using context-only mode (no text generation)")

     def simple_context_answer(self, query: str, context: str) -> str:
-        """
+        """Improved context-based answering when model is not available"""
         if not context:
             return "No relevant information found in the documents."

-        #
+        # Improved keyword matching approach
         query_words = set(query.lower().split())
         context_sentences = context.split('.')

@@ -108,16 +108,20 @@ class DocumentRAG:
                 continue

             sentence_words = set(sentence.lower().split())
-            # Check if sentence contains
+            # Check if sentence contains query keywords
             common_words = query_words.intersection(sentence_words)
-            if len(common_words) >=
+            if len(common_words) >= 1:  # Lowered threshold
                 relevant_sentences.append(sentence)

         if relevant_sentences:
             # Return the most relevant sentences
             return '. '.join(relevant_sentences[:3]) + '.'
         else:
+            # If no exact matches, return first few sentences of context
+            first_sentences = context_sentences[:2]
+            if first_sentences:
+                return '. '.join([s.strip() for s in first_sentences if s.strip()]) + '.'
+            return "Based on the document content, I found some information but cannot provide a specific answer to your question."

     def extract_text_from_file(self, file_path: str) -> str:
         """Extract text from various file formats"""
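For reference, the matcher above simply intersects word sets and keeps sentences sharing at least one word with the query. A standalone run of the same idea, with invented query/context strings, looks like this:

# Standalone illustration of the keyword-overlap matching in simple_context_answer
# (example strings are made up for the demo).
query = "What is the refund policy?"
context = "The refund policy allows returns within 30 days. Orders ship in two days."

query_words = set(query.lower().split())
relevant = [s.strip() for s in context.split('.')
            if s.strip() and len(query_words & set(s.lower().split())) >= 1]
print('. '.join(relevant[:3]) + '.')
# -> "The refund policy allows returns within 30 days."

Note that with the threshold lowered to 1, even a stop-word such as "the" is enough to match a sentence, which is why the method still caps its output at three sentences.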
@@ -171,7 +175,7 @@ class DocumentRAG:
             except Exception as e2:
                 return f"Error reading TXT: {str(e2)}"

-    def chunk_text(self, text: str, chunk_size: int =
+    def chunk_text(self, text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
         """Split text into overlapping chunks with better sentence preservation"""
         if not text.strip():
             return []
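Only the signature of chunk_text changes in this diff; its body is not shown. The word-window sketch below is therefore an assumption about how overlapping chunking with the new defaults (chunk_size=200, overlap=30) behaves, not a copy of the real implementation:

from typing import List

def chunk_text_sketch(text: str, chunk_size: int = 200, overlap: int = 30) -> List[str]:
    # Hypothetical helper: slide a window of `chunk_size` words, stepping back
    # `overlap` words each time so neighbouring chunks share some text.
    words = text.split()
    chunks, start = [], 0
    while start < len(words):
        chunks.append(" ".join(words[start:start + chunk_size]))
        start += chunk_size - overlap
    return chunks

print(len(chunk_text_sketch("word " * 500)))  # 3 overlapping chunks for a 500-word input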
@@ -256,7 +260,7 @@ class DocumentRAG:
             return f"❌ Error processing documents: {str(e)}"

     def retrieve_context(self, query: str, k: int = 5) -> str:
-        """Retrieve relevant context for the query"""
+        """Retrieve relevant context for the query with improved retrieval"""
         if not self.is_indexed:
             return ""

@@ -268,12 +272,20 @@ class DocumentRAG:
             # Search for similar chunks
             scores, indices = self.index.search(query_embedding.astype('float32'), k)

-            # Get relevant documents with
+            # Get relevant documents with MUCH LOWER threshold
             relevant_docs = []
             for i, idx in enumerate(indices[0]):
-                if idx < len(self.documents) and scores[0][i] > 0.
+                if idx < len(self.documents) and scores[0][i] > 0.05:  # Much lower threshold
                     relevant_docs.append(self.documents[idx])

+            # If no high-similarity matches, take the top results anyway
+            if not relevant_docs:
+                for i, idx in enumerate(indices[0]):
+                    if idx < len(self.documents):
+                        relevant_docs.append(self.documents[idx])
+                        if len(relevant_docs) >= 3:  # Take at least 3 chunks
+                            break
+
             return "\n\n".join(relevant_docs)

         except Exception as e:
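The new selection logic, a much lower score cut-off plus a top-k fallback, can be exercised on its own with stub arrays shaped like faiss.Index.search output. The helper name and sample data below are hypothetical:

import numpy as np

def select_chunks(documents, scores, indices, threshold=0.05, min_chunks=3):
    # Keep hits above the (low) similarity threshold first...
    picked = [documents[idx] for rank, idx in enumerate(indices[0])
              if idx < len(documents) and scores[0][rank] > threshold]
    # ...otherwise fall back to the top-ranked chunks regardless of score.
    if not picked:
        picked = [documents[idx] for idx in indices[0] if idx < len(documents)][:min_chunks]
    return "\n\n".join(picked)

docs = ["Chunk on refunds.", "Chunk on shipping.", "Chunk on returns."]
scores = np.array([[0.04, 0.02, 0.01]])   # every score below 0.05 -> fallback path
indices = np.array([[2, 0, 1]])
print(select_chunks(docs, scores, indices))

The effect is that retrieval now always returns something when the index is populated, and the decision about whether the material actually answers the question is pushed down to the generation and fallback steps.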
@@ -281,9 +293,9 @@ class DocumentRAG:
             return ""

     def generate_answer(self, query: str, context: str) -> str:
-        """Generate answer using the LLM with
+        """Generate answer using the LLM with improved prompting"""
         if self.model is None or self.tokenizer is None:
-            return
+            return self.simple_context_answer(query, context)

         try:
             # Check if using Mistral (has specific prompt format) or fallback model
@@ -291,37 +303,31 @@ class DocumentRAG:
             is_mistral = 'mistral' in model_name

             if is_mistral:
-                #
-                prompt = f"""<s>[INST] You are a document
-
-STRICT RULES:
-1. Answer ONLY using information from the context below
-2. If the answer is not in the context, respond: "The information needed to answer this question is not available in the provided documents."
-3. Do NOT make assumptions or add information not in the context
-4. Quote relevant parts from the context when possible
-
-{context[:
+                # Improved prompt for Mistral - more flexible
+                prompt = f"""<s>[INST] You are a helpful document assistant. Answer the question based on the provided context. If the exact answer isn't in the context, provide the most relevant information available.
+
+Context:
+{context[:1500]}
+
+Question: {query}
+
+Please provide a helpful answer based on the available information. [/INST]"""
             else:
-                #
-                prompt = f"""
-
-{context[:
+                # Improved prompt for fallback models
+                prompt = f"""Based on the following information, please answer the question:
+
+Context:
+{context[:1000]}
+
+Question: {query}
+
+Answer:"""

             # Tokenize with proper handling
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
-                max_length=
+                max_length=800,
                 truncation=True,
                 padding=True
             )
@@ -330,17 +336,17 @@ ANSWER (using only the context above):"""
|
|
330 |
if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
|
331 |
inputs = {k: v.cuda() for k, v in inputs.items()}
|
332 |
|
333 |
-
# Generate with
|
334 |
with torch.no_grad():
|
335 |
outputs = self.model.generate(
|
336 |
**inputs,
|
337 |
-
max_new_tokens=
|
338 |
-
temperature=0.
|
339 |
-
do_sample=
|
340 |
-
|
|
|
341 |
early_stopping=True,
|
342 |
-
repetition_penalty=1.
|
343 |
-
no_repeat_ngram_size=3,
|
344 |
pad_token_id=self.tokenizer.pad_token_id,
|
345 |
eos_token_id=self.tokenizer.eos_token_id
|
346 |
)
|
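To see the new decoding settings in isolation, here is a minimal generate() call. The tiny model name is an assumption chosen only so the snippet runs quickly; the Space itself loads Mistral or its fallback model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # hypothetical stand-in, not the Space's model
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tok("Context: Refunds take 5 days.\nQuestion: How long do refunds take?\nAnswer:",
             return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.3,           # mildly stochastic, as in the diff
        do_sample=True,
        top_p=0.9,
        num_beams=2,
        early_stopping=True,
        repetition_penalty=1.1,
        pad_token_id=tok.eos_token_id,   # tiny-gpt2 has no dedicated pad token
        eos_token_id=tok.eos_token_id,
    )
print(tok.decode(out[0], skip_special_tokens=True))

With do_sample=True and num_beams=2, transformers runs beam-search multinomial sampling; the previous no_repeat_ngram_size constraint is no longer set.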
@@ -353,64 +359,48 @@ ANSWER (using only the context above):"""
|
|
353 |
answer = full_response.split("[/INST]")[-1].strip()
|
354 |
else:
|
355 |
# For other models, remove the prompt
|
356 |
-
if "
|
357 |
-
answer = full_response.split("
|
358 |
else:
|
359 |
answer = full_response[len(prompt):].strip()
|
360 |
|
361 |
-
#
|
362 |
-
answer = self.
|
363 |
|
364 |
-
return answer if answer else
|
365 |
|
366 |
except Exception as e:
|
367 |
-
|
|
|
368 |
|
369 |
-
def
|
370 |
-
"""
|
371 |
if not answer or len(answer) < 5:
|
372 |
-
return "
|
373 |
|
374 |
-
# Remove
|
375 |
-
|
376 |
-
|
377 |
-
"what do you think",
|
378 |
-
"in my opinion",
|
379 |
-
"i believe",
|
380 |
-
"personally",
|
381 |
-
"from my experience",
|
382 |
-
"generally speaking",
|
383 |
-
"it is known that",
|
384 |
-
"everyone knows"
|
385 |
-
]
|
386 |
|
387 |
-
|
388 |
-
|
389 |
-
if pattern in
|
390 |
-
|
|
|
|
|
|
|
391 |
|
392 |
-
|
393 |
-
# Simple check: if answer is much longer than query and doesn't reference context
|
394 |
-
if len(answer) > len(query) * 3 and not any(word in answer.lower() for word in context.lower().split()[:20]):
|
395 |
-
return "The information needed to answer this question is not available in the provided documents."
|
396 |
|
397 |
-
#
|
398 |
-
|
|
|
|
|
399 |
|
400 |
-
|
401 |
-
sentences = answer.split('.')
|
402 |
-
unique_sentences = []
|
403 |
-
for sentence in sentences:
|
404 |
-
sentence = sentence.strip()
|
405 |
-
if sentence and sentence not in unique_sentences:
|
406 |
-
unique_sentences.append(sentence)
|
407 |
-
|
408 |
-
cleaned_answer = '. '.join(unique_sentences)
|
409 |
-
|
410 |
-
return cleaned_answer if cleaned_answer else "The information needed to answer this question is not available in the provided documents."
|
411 |
|
412 |
def answer_question(self, query: str) -> str:
|
413 |
-
"""Main function to answer questions with
|
414 |
if not query.strip():
|
415 |
return "β Please ask a question!"
|
416 |
|
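The line-level filtering in the new clean_answer can be reproduced standalone; the sample answer string is invented for the example:

answer = "Refunds are accepted within 30 days.\nI am an AI, what do you think?"
blocked = ['what are you doing', 'what do you think', 'how are you',
           'i am an ai', 'i cannot', "i don't know"]
kept = [line.strip() for line in answer.split('\n')
        if line.strip() and not any(p in line.lower() for p in blocked)]
print(' '.join(kept))
# -> "Refunds are accepted within 30 days."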
@@ -419,77 +409,22 @@ ANSWER (using only the context above):"""
|
|
419 |
|
420 |
try:
|
421 |
# Retrieve relevant context
|
422 |
-
context = self.retrieve_context(query)
|
423 |
|
424 |
if not context:
|
425 |
return "π No relevant information found in the uploaded documents for your question."
|
426 |
|
427 |
-
#
|
428 |
-
if self.model is None:
|
429 |
-
answer = self.simple_context_answer(query, context)
|
430 |
-
return f"π‘ **Answer:** {answer}\n\nπ **Source:** {context[:300]}..."
|
431 |
-
|
432 |
-
# Generate answer using the model
|
433 |
answer = self.generate_answer(query, context)
|
434 |
|
435 |
-
|
436 |
-
|
437 |
-
# Check if answer seems to be hallucinated
|
438 |
-
if self.is_likely_hallucination(answer, context):
|
439 |
-
answer = "The information needed to answer this question is not available in the provided documents."
|
440 |
-
|
441 |
-
return f"π‘ **Answer:** {answer}\n\nπ **Relevant Context:**\n{context[:400]}..."
|
442 |
else:
|
443 |
-
|
|
|
444 |
|
445 |
except Exception as e:
|
446 |
return f"β Error answering question: {str(e)}"
|
447 |
-
|
448 |
-
def is_likely_hallucination(self, answer: str, context: str) -> bool:
|
449 |
-
"""Check if the answer is likely a hallucination"""
|
450 |
-
# Convert to lowercase for comparison
|
451 |
-
answer_lower = answer.lower()
|
452 |
-
context_lower = context.lower()
|
453 |
-
|
454 |
-
# Check for obvious hallucination patterns
|
455 |
-
hallucination_indicators = [
|
456 |
-
"what are you doing",
|
457 |
-
"what do you think",
|
458 |
-
"how are you",
|
459 |
-
"i think",
|
460 |
-
"in my opinion",
|
461 |
-
"from my experience",
|
462 |
-
"generally speaking",
|
463 |
-
"it is well known",
|
464 |
-
"everyone knows",
|
465 |
-
"obviously",
|
466 |
-
"clearly",
|
467 |
-
"of course"
|
468 |
-
]
|
469 |
-
|
470 |
-
for indicator in hallucination_indicators:
|
471 |
-
if indicator in answer_lower:
|
472 |
-
return True
|
473 |
-
|
474 |
-
# Check if answer contains words that are not in context
|
475 |
-
answer_words = set(answer_lower.split())
|
476 |
-
context_words = set(context_lower.split())
|
477 |
-
|
478 |
-
# Remove common words
|
479 |
-
common_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'will', 'would', 'could', 'should', 'this', 'that', 'these', 'those'}
|
480 |
-
|
481 |
-
answer_content_words = answer_words - common_words
|
482 |
-
context_content_words = context_words - common_words
|
483 |
-
|
484 |
-
# If more than 70% of content words in answer are not in context, likely hallucination
|
485 |
-
if len(answer_content_words) > 0:
|
486 |
-
overlap = len(answer_content_words.intersection(context_content_words))
|
487 |
-
overlap_ratio = overlap / len(answer_content_words)
|
488 |
-
|
489 |
-
if overlap_ratio < 0.3: # Less than 30% overlap
|
490 |
-
return True
|
491 |
-
|
492 |
-
return False
|
493 |
|
494 |
# Initialize the RAG system
|
495 |
print("Initializing Document RAG System...")
|