Update app.py
app.py CHANGED
@@ -27,56 +27,62 @@ class DocumentRAG:
         self.is_indexed = False
 
     def setup_llm(self):
+        """Setup quantized Mistral model"""
+        try:
+            # Check if CUDA is available
+            if not torch.cuda.is_available():
+                print("⚠️ CUDA not available, falling back to CPU or alternative model")
+                self.setup_fallback_model()
+                return
+
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4"
             )
 
             model_name = "mistralai/Mistral-7B-Instruct-v0.1"
 
             # Load tokenizer first
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_name,
                 trust_remote_code=True
             )
 
             # Fix padding token issue
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
 
             # Load model with quantization
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_name,
                 quantization_config=quantization_config,
                 device_map="auto",
                 torch_dtype=torch.float16,
                 trust_remote_code=True,
-                low_cpu_mem_usage=True  # Added for better memory management
+                low_cpu_mem_usage=True
             )
 
             print("✅ Quantized Mistral model loaded successfully")
 
         except Exception as e:
             print(f"❌ Error loading model: {e}")
             print("🔄 Falling back to alternative model...")
             self.setup_fallback_model()
 
     def setup_fallback_model(self):
         """Fallback to smaller model if Mistral fails"""
         try:
+            # Use a better fallback model for Q&A
+            model_name = "distilgpt2"
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.model = AutoModelForCausalLM.from_pretrained(model_name)
+
+            # Fix padding token for fallback model too
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
             print("✅ Fallback model loaded")
         except Exception as e:
             print(f"❌ Fallback model failed: {e}")
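
The hunk above is the heart of the change: Mistral-7B-Instruct is loaded in 4-bit NF4 only when CUDA is available. The loading path can be exercised outside the Space with a minimal standalone sketch; it assumes a CUDA GPU and that bitsandbytes and accelerate are installed, and reuses only names that appear in the diff.

# Minimal standalone sketch of the 4-bit loading path added above.
# Assumes a CUDA GPU and that bitsandbytes/accelerate are installed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.1"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights as 4-bit NF4
    bnb_4bit_compute_dtype=torch.float16,  # compute in fp16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same pad-token fix as the diff

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the layers
)
print("model loaded:", model.config._name_or_path)
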
@@ -135,21 +141,35 @@ class DocumentRAG:
         except Exception as e2:
             return f"Error reading TXT: {str(e2)}"
 
-    def chunk_text(self, text: str, chunk_size: int =
-        """Split text into overlapping chunks"""
+    def chunk_text(self, text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
+        """Split text into overlapping chunks with better sentence preservation"""
         if not text.strip():
             return []
 
+        # Split by sentences first, then group into chunks
+        sentences = text.replace('\n', ' ').split('. ')
         chunks = []
+        current_chunk = ""
 
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if not sentence:
+                continue
+
+            # Add sentence to current chunk
+            test_chunk = current_chunk + ". " + sentence if current_chunk else sentence
 
+            # If chunk gets too long, save it and start new one
+            if len(test_chunk.split()) > chunk_size:
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                current_chunk = sentence
+            else:
+                current_chunk = test_chunk
+
+        # Add the last chunk
+        if current_chunk:
+            chunks.append(current_chunk.strip())
 
         return chunks
 
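
The new chunk_text groups sentences by word count instead of slicing raw characters (note that the overlap argument is accepted but not used in this version). Below is a standalone copy of the same logic, handy for quick testing; the sample text and the small chunk_size are made up for illustration.

# Standalone copy of the sentence-based chunking added above, for quick testing.
from typing import List

def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    if not text.strip():
        return []
    sentences = text.replace('\n', ' ').split('. ')
    chunks, current_chunk = [], ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        test_chunk = current_chunk + ". " + sentence if current_chunk else sentence
        if len(test_chunk.split()) > chunk_size:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk = test_chunk
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks

# Example: with chunk_size=8 (words) the two sentences land in separate chunks.
sample = "FAISS indexes dense vectors for fast similarity search. Sentence grouping keeps each chunk readable."
print(chunk_text(sample, chunk_size=8))
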
@@ -205,7 +225,7 @@ class DocumentRAG:
         except Exception as e:
             return f"❌ Error processing documents: {str(e)}"
 
-    def retrieve_context(self, query: str, k: int =
+    def retrieve_context(self, query: str, k: int = 5) -> str:
         """Retrieve relevant context for the query"""
         if not self.is_indexed:
             return ""
@@ -218,10 +238,10 @@ class DocumentRAG:
         # Search for similar chunks
         scores, indices = self.index.search(query_embedding.astype('float32'), k)
 
-        # Get relevant documents
+        # Get relevant documents with higher threshold
         relevant_docs = []
         for i, idx in enumerate(indices[0]):
-            if idx < len(self.documents) and scores[0][i] > 0.
+            if idx < len(self.documents) and scores[0][i] > 0.2:  # Higher similarity threshold
                 relevant_docs.append(self.documents[idx])
 
         return "\n\n".join(relevant_docs)
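
retrieve_context relies on an index and query embedding built elsewhere in app.py, which this diff does not show. Below is a hedged sketch of the usual pattern behind that call: the all-MiniLM-L6-v2 embedder, normalized vectors, and IndexFlatIP are assumptions; only the search call and the 0.2 score filter come from the diff.

# Hedged sketch of the index/query pattern retrieve_context relies on.
# The embedding model and cosine-via-normalized-vectors setup are assumptions,
# since the indexing code is outside this diff.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model
documents = ["Chunk one about revenue.", "Chunk two about staffing."]

embeddings = embedder.encode(documents, normalize_embeddings=True)
index = faiss.IndexFlatIP(embeddings.shape[1])  # inner product == cosine on unit vectors
index.add(np.asarray(embeddings, dtype="float32"))

query_embedding = embedder.encode(["What was the revenue?"], normalize_embeddings=True)
scores, indices = index.search(np.asarray(query_embedding, dtype="float32"), 2)

# Same filtering rule as the hunk above: keep hits scoring above 0.2.
relevant = [documents[i] for s, i in zip(scores[0], indices[0])
            if i < len(documents) and s > 0.2]
print(relevant)
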
@@ -231,52 +251,73 @@ class DocumentRAG:
             return ""
 
     def generate_answer(self, query: str, context: str) -> str:
-        """Generate answer using the LLM"""
+        """Generate answer using the LLM with improved prompting"""
         if self.model is None or self.tokenizer is None:
             return "❌ Model not available. Please try again."
 
         try:
+            # Check if using Mistral (has specific prompt format) or fallback model
+            model_name = getattr(self.model.config, '_name_or_path', '').lower()
+            is_mistral = 'mistral' in model_name
+
+            if is_mistral:
+                # Mistral-specific prompt format
+                prompt = f"""<s>[INST] You are a helpful assistant that answers questions based on the provided context. Use only the information from the context to answer. If the information is not in the context, say "I don't have enough information to answer this question."
 
 Context:
-{context[:
+{context[:1500]}
+
+Question: {query}
+
+Provide a clear and concise answer based only on the context above. [/INST]"""
+            else:
+                # Generic prompt for fallback models
+                prompt = f"""Context: {context[:1000]}
 
 Question: {query}
 
-Answer:
+Answer based on the context:"""
 
-            # Tokenize
+            # Tokenize with proper handling
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
-                max_length=
+                max_length=800,  # Reduced to fit in memory
                 truncation=True,
                 padding=True
             )
 
+            # Move to same device as model
+            if torch.cuda.is_available() and next(self.model.parameters()).is_cuda:
+                inputs = {k: v.cuda() for k, v in inputs.items()}
+
+            # Generate with better parameters
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
-                    max_new_tokens=
-                    temperature=0.
+                    max_new_tokens=150,  # Reduced for more focused answers
+                    temperature=0.3,  # Lower temperature for more consistent answers
                     do_sample=True,
-                    top_p=0.
+                    top_p=0.8,
+                    repetition_penalty=1.1,
+                    pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id
                 )
 
             # Decode response
             full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            # Extract answer
-            if "[/INST]" in full_response:
+            # Extract answer based on model type
+            if is_mistral and "[/INST]" in full_response:
                 answer = full_response.split("[/INST]")[-1].strip()
             else:
+                # For other models, remove the prompt
                 answer = full_response[len(prompt):].strip()
 
+            # Clean up the answer
+            answer = answer.replace(prompt, "").strip()
+
+            return answer if answer else "I couldn't generate a proper response based on the context."
 
         except Exception as e:
             return f"❌ Error generating answer: {str(e)}"
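
The fallback branch of generate_answer can be checked on CPU with distilgpt2, the fallback model this diff selects. This is a rough sketch only: the sample context and question are invented, and distilgpt2 answers will be poor; the point is the tokenize / generate / prompt-stripping sequence.

# Quick CPU check of the fallback generation path, using distilgpt2 (the
# fallback model from this diff). Sample context and question are made up.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

context = "The report states that revenue grew 12% in 2023."
query = "How much did revenue grow?"
prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer based on the context:"

inputs = tokenizer(prompt, return_tensors="pt", max_length=800, truncation=True, padding=True)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.3,
        do_sample=True,
        top_p=0.8,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = full_response[len(prompt):].strip()  # same prompt-stripping as the non-Mistral branch
print(answer)
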
@@ -294,12 +335,16 @@ Answer: [/INST]"""
             context = self.retrieve_context(query)
 
             if not context:
-                return "🔍 No relevant information found in the uploaded documents."
+                return "🔍 No relevant information found in the uploaded documents for your question."
 
             # Generate answer
             answer = self.generate_answer(query, context)
 
+            # Format the response
+            if answer and not answer.startswith("❌"):
+                return f"💡 **Answer:** {answer}\n\n📖 **Relevant Context:**\n{context[:400]}..."
+            else:
+                return answer
 
         except Exception as e:
             return f"❌ Error answering question: {str(e)}"
@@ -355,7 +400,7 @@ def create_interface():
         with gr.Column():
             answer_output = gr.Textbox(
                 label="Answer",
-                lines=
+                lines=12,
                 interactive=False
             )
 
@@ -372,6 +417,7 @@ def create_interface():
     - Can you summarize the key points?
     - What are the conclusions mentioned?
     - Are there any specific numbers or statistics?
+    - Who are the main people or organizations mentioned?
     """)
 
     return demo