pradeepsengarr committed
Commit a8283c8 Β· verified Β· Parent(s): 3529e03

Update app.py

Files changed (1):
  1. app.py (+120 -74)
app.py CHANGED
@@ -27,56 +27,62 @@ class DocumentRAG:
         self.is_indexed = False
 
     def setup_llm(self):
-        """Setup quantized Mistral model"""
-        try:
-            # Check if CUDA is available
-            if not torch.cuda.is_available():
-                print("⚠️ CUDA not available, falling back to CPU or alternative model")
-                self.setup_fallback_model()
-                return
-
-            quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4"
-            )
-
-            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
-
-            # Load tokenizer first
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                trust_remote_code=True
-            )
-
-            # Fix padding token issue
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            # Load model with quantization
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                quantization_config=quantization_config,
-                device_map="auto",
-                torch_dtype=torch.float16,
-                trust_remote_code=True,
-                low_cpu_mem_usage=True # Added for better memory management
-            )
-
-            print("βœ… Quantized Mistral model loaded successfully")
-
-        except Exception as e:
-            print(f"❌ Error loading model: {e}")
-            print("πŸ”„ Falling back to alternative model...")
-            self.setup_fallback_model()
-
+        """Setup quantized Mistral model"""
+        try:
+            # Check if CUDA is available
+            if not torch.cuda.is_available():
+                print("⚠️ CUDA not available, falling back to CPU or alternative model")
+                self.setup_fallback_model()
+                return
+
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4"
+            )
+
+            model_name = "mistralai/Mistral-7B-Instruct-v0.1"
+
+            # Load tokenizer first
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                trust_remote_code=True
+            )
+
+            # Fix padding token issue
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Load model with quantization
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                quantization_config=quantization_config,
+                device_map="auto",
+                torch_dtype=torch.float16,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+
+            print("βœ… Quantized Mistral model loaded successfully")
+
+        except Exception as e:
+            print(f"❌ Error loading model: {e}")
+            print("πŸ”„ Falling back to alternative model...")
+            self.setup_fallback_model()
+
     def setup_fallback_model(self):
         """Fallback to smaller model if Mistral fails"""
         try:
-            model_name = "microsoft/DialoGPT-small"
+            # Use a better fallback model for Q&A
+            model_name = "distilgpt2"
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.model = AutoModelForCausalLM.from_pretrained(model_name)
+
+            # Fix padding token for fallback model too
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
             print("βœ… Fallback model loaded")
         except Exception as e:
             print(f"❌ Fallback model failed: {e}")
@@ -135,21 +141,35 @@ class DocumentRAG:
         except Exception as e2:
             return f"Error reading TXT: {str(e2)}"
 
-    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
-        """Split text into overlapping chunks"""
+    def chunk_text(self, text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
+        """Split text into overlapping chunks with better sentence preservation"""
         if not text.strip():
             return []
 
-        words = text.split()
+        # Split by sentences first, then group into chunks
+        sentences = text.replace('\n', ' ').split('. ')
         chunks = []
+        current_chunk = ""
 
-        for i in range(0, len(words), chunk_size - overlap):
-            chunk = ' '.join(words[i:i + chunk_size])
-            if chunk.strip():
-                chunks.append(chunk.strip())
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if not sentence:
+                continue
+
+            # Add sentence to current chunk
+            test_chunk = current_chunk + ". " + sentence if current_chunk else sentence
 
-            if i + chunk_size >= len(words):
-                break
+            # If chunk gets too long, save it and start new one
+            if len(test_chunk.split()) > chunk_size:
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                current_chunk = sentence
+            else:
+                current_chunk = test_chunk
+
+        # Add the last chunk
+        if current_chunk:
+            chunks.append(current_chunk.strip())
 
         return chunks
 
@@ -205,7 +225,7 @@ class DocumentRAG:
         except Exception as e:
             return f"❌ Error processing documents: {str(e)}"
 
-    def retrieve_context(self, query: str, k: int = 3) -> str:
+    def retrieve_context(self, query: str, k: int = 5) -> str:
         """Retrieve relevant context for the query"""
         if not self.is_indexed:
             return ""
@@ -218,10 +238,10 @@ class DocumentRAG:
         # Search for similar chunks
         scores, indices = self.index.search(query_embedding.astype('float32'), k)
 
-        # Get relevant documents
+        # Get relevant documents with higher threshold
        relevant_docs = []
         for i, idx in enumerate(indices[0]):
-            if idx < len(self.documents) and scores[0][i] > 0.1: # Similarity threshold
+            if idx < len(self.documents) and scores[0][i] > 0.2: # Higher similarity threshold
                 relevant_docs.append(self.documents[idx])
 
         return "\n\n".join(relevant_docs)
@@ -231,52 +251,73 @@ class DocumentRAG:
             return ""
 
     def generate_answer(self, query: str, context: str) -> str:
-        """Generate answer using the LLM"""
+        """Generate answer using the LLM with improved prompting"""
         if self.model is None or self.tokenizer is None:
             return "❌ Model not available. Please try again."
 
         try:
-            # Create prompt
-            prompt = f"""<s>[INST] Based on the following context, answer the question. If the answer is not in the context, say "I don't have enough information to answer this question."
+            # Check if using Mistral (has specific prompt format) or fallback model
+            model_name = getattr(self.model.config, '_name_or_path', '').lower()
+            is_mistral = 'mistral' in model_name
+
+            if is_mistral:
+                # Mistral-specific prompt format
+                prompt = f"""<s>[INST] You are a helpful assistant that answers questions based on the provided context. Use only the information from the context to answer. If the information is not in the context, say "I don't have enough information to answer this question."
 
 Context:
-{context[:2000]} # Limit context length
+{context[:1500]}
+
+Question: {query}
+
+Provide a clear and concise answer based only on the context above. [/INST]"""
+            else:
+                # Generic prompt for fallback models
+                prompt = f"""Context: {context[:1000]}
 
 Question: {query}
 
-Answer: [/INST]"""
+Answer based on the context:"""
 
-            # Tokenize
+            # Tokenize with proper handling
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
-                max_length=1024,
+                max_length=800, # Reduced to fit in memory
                 truncation=True,
                 padding=True
             )
 
-            # Generate
+            # Move to same device as model
+            if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+
+            # Generate with better parameters
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
-                    max_new_tokens=256,
-                    temperature=0.7,
+                    max_new_tokens=150, # Reduced for more focused answers
+                    temperature=0.3, # Lower temperature for more consistent answers
                     do_sample=True,
-                    top_p=0.9,
-                    pad_token_id=self.tokenizer.eos_token_id,
+                    top_p=0.8,
+                    repetition_penalty=1.1,
+                    pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id
                 )
 
             # Decode response
             full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            # Extract answer (remove the prompt part)
-            if "[/INST]" in full_response:
+            # Extract answer based on model type
+            if is_mistral and "[/INST]" in full_response:
                 answer = full_response.split("[/INST]")[-1].strip()
             else:
+                # For other models, remove the prompt
                 answer = full_response[len(prompt):].strip()
 
-            return answer if answer else "I couldn't generate a proper response."
+            # Clean up the answer
+            answer = answer.replace(prompt, "").strip()
+
+            return answer if answer else "I couldn't generate a proper response based on the context."
 
         except Exception as e:
             return f"❌ Error generating answer: {str(e)}"
@@ -294,12 +335,16 @@ Answer: [/INST]"""
             context = self.retrieve_context(query)
 
             if not context:
-                return "πŸ” No relevant information found in the uploaded documents."
+                return "πŸ” No relevant information found in the uploaded documents for your question."
 
             # Generate answer
             answer = self.generate_answer(query, context)
 
-            return f"πŸ’‘ **Answer:** {answer}\n\nπŸ“„ **Source Context:** {context[:500]}..."
+            # Format the response
+            if answer and not answer.startswith("❌"):
+                return f"πŸ’‘ **Answer:** {answer}\n\nπŸ“„ **Relevant Context:**\n{context[:400]}..."
+            else:
+                return answer
 
         except Exception as e:
             return f"❌ Error answering question: {str(e)}"
@@ -355,7 +400,7 @@ def create_interface():
         with gr.Column():
             answer_output = gr.Textbox(
                 label="Answer",
-                lines=10,
+                lines=12,
                 interactive=False
             )
 
@@ -372,6 +417,7 @@ def create_interface():
     - Can you summarize the key points?
    - What are the conclusions mentioned?
     - Are there any specific numbers or statistics?
+    - Who are the main people or organizations mentioned?
     """)
 
     return demo
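
A minimal standalone sketch of the sentence-based chunking that this commit swaps in for the old word-window chunker. The function body mirrors the new DocumentRAG.chunk_text from the diff above; the `overlap` parameter is no longer used by the new logic, so it is omitted here, and the script name, sample text, and small `chunk_size` are illustrative only, not part of app.py.

# chunk_sketch.py -- illustrative sketch, not part of the commit
from typing import List


def chunk_text(text: str, chunk_size: int = 300) -> List[str]:
    """Group sentences into chunks of at most roughly chunk_size words."""
    if not text.strip():
        return []

    # Split by sentences first, then group them into chunks (as in the commit).
    sentences = text.replace('\n', ' ').split('. ')
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        # Tentatively extend the current chunk with this sentence.
        test_chunk = current_chunk + ". " + sentence if current_chunk else sentence

        # If the chunk gets too long, save it and start a new one.
        if len(test_chunk.split()) > chunk_size:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk = test_chunk

    # Flush whatever is left after the loop.
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


if __name__ == "__main__":
    sample = (
        "The report covers revenue for 2021. Growth was driven by new customers. "
        "The conclusion recommends further investment."
    )
    for i, chunk in enumerate(chunk_text(sample, chunk_size=12)):
        print(i, "->", chunk)

With the tiny chunk_size above, the first two sentences land in one chunk and the third starts a new one, which is the grouping behaviour the commit relies on when it indexes documents.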