Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -318,8 +318,9 @@ except Exception as e:
|
|
318 |
# Initialize embedding model for RAG
|
319 |
try:
|
320 |
print("Loading embedding model...")
|
321 |
-
|
322 |
-
|
|
|
323 |
except Exception as e:
|
324 |
print(f"β Error loading embedding model: {e}")
|
325 |
embedding_model = None
|
@@ -368,7 +369,7 @@ embedding_model = None
|
|
368 |
# chatbot_model is initialized above
|
369 |
|
370 |
|
371 |
-
def chunk_document(text, chunk_size=
|
372 |
"""Split document into overlapping chunks for RAG"""
|
373 |
words = text.split()
|
374 |
chunks = []
|
@@ -386,19 +387,27 @@ def create_embeddings(chunks):
|
|
386 |
return None
|
387 |
|
388 |
try:
|
389 |
-
|
390 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
except Exception as e:
|
392 |
print(f"Error creating embeddings: {e}")
|
393 |
return None
|
394 |
|
395 |
-
def retrieve_relevant_chunks(question, chunks, embeddings, top_k=
|
396 |
"""Retrieve most relevant chunks for a question"""
|
397 |
if embedding_model is None or embeddings is None:
|
398 |
-
return chunks[:
|
399 |
|
400 |
try:
|
401 |
-
question_embedding = embedding_model.encode([question])
|
402 |
similarities = cosine_similarity(question_embedding, embeddings)[0]
|
403 |
|
404 |
# Get top-k most similar chunks
|
@@ -408,7 +417,7 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
|
|
408 |
return relevant_chunks
|
409 |
except Exception as e:
|
410 |
print(f"Error retrieving chunks: {e}")
|
411 |
-
return chunks[:
|
412 |
|
413 |
def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
|
414 |
"""Main processing function for uploaded PDF"""
|
@@ -452,10 +461,17 @@ def get_processed_markdown():
|
|
452 |
|
453 |
def clear_all():
|
454 |
"""Clear all data and hide results tab"""
|
455 |
-
global processed_markdown, show_results_tab
|
456 |
processed_markdown = ""
|
457 |
show_results_tab = False
|
458 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
459 |
|
460 |
|
461 |
# Create Gradio interface
|
@@ -660,14 +676,23 @@ with gr.Blocks(
|
|
660 |
input_len = inputs["input_ids"].shape[-1]
|
661 |
|
662 |
with torch.inference_mode():
|
|
|
|
|
|
|
|
|
663 |
generation = chatbot_model.generate(
|
664 |
**inputs,
|
665 |
-
max_new_tokens=300
|
666 |
do_sample=False,
|
667 |
temperature=0.7,
|
668 |
-
pad_token_id=chatbot_processor.tokenizer.pad_token_id
|
|
|
669 |
)
|
670 |
generation = generation[0][input_len:]
|
|
|
|
|
|
|
|
|
671 |
|
672 |
response = chatbot_processor.decode(generation, skip_special_tokens=True)
|
673 |
|
|
|
318 |
# Initialize embedding model for RAG
|
319 |
try:
|
320 |
print("Loading embedding model...")
|
321 |
+
# Force CPU for embedding model to save GPU memory
|
322 |
+
embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
|
323 |
+
print("β
Embedding model loaded successfully (CPU)")
|
324 |
except Exception as e:
|
325 |
print(f"β Error loading embedding model: {e}")
|
326 |
embedding_model = None
|
|
|
369 |
# chatbot_model is initialized above
|
370 |
|
371 |
|
372 |
+
def chunk_document(text, chunk_size=300, overlap=30):
|
373 |
"""Split document into overlapping chunks for RAG"""
|
374 |
words = text.split()
|
375 |
chunks = []
|
|
|
387 |
return None
|
388 |
|
389 |
try:
|
390 |
+
# Process in smaller batches to avoid memory issues
|
391 |
+
batch_size = 32
|
392 |
+
embeddings = []
|
393 |
+
|
394 |
+
for i in range(0, len(chunks), batch_size):
|
395 |
+
batch = chunks[i:i + batch_size]
|
396 |
+
batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
|
397 |
+
embeddings.extend(batch_embeddings)
|
398 |
+
|
399 |
+
return np.array(embeddings)
|
400 |
except Exception as e:
|
401 |
print(f"Error creating embeddings: {e}")
|
402 |
return None
|
403 |
|
404 |
+
def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
|
405 |
"""Retrieve most relevant chunks for a question"""
|
406 |
if embedding_model is None or embeddings is None:
|
407 |
+
return chunks[:2] # Fallback to first 2 chunks (reduced from 3)
|
408 |
|
409 |
try:
|
410 |
+
question_embedding = embedding_model.encode([question], show_progress_bar=False)
|
411 |
similarities = cosine_similarity(question_embedding, embeddings)[0]
|
412 |
|
413 |
# Get top-k most similar chunks
|
|
|
417 |
return relevant_chunks
|
418 |
except Exception as e:
|
419 |
print(f"Error retrieving chunks: {e}")
|
420 |
+
return chunks[:2] # Fallback
|
421 |
|
422 |
def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
|
423 |
"""Main processing function for uploaded PDF"""
|
|
|
461 |
|
462 |
def clear_all():
|
463 |
"""Clear all data and hide results tab"""
|
464 |
+
global processed_markdown, show_results_tab, document_chunks, document_embeddings
|
465 |
processed_markdown = ""
|
466 |
show_results_tab = False
|
467 |
+
document_chunks = []
|
468 |
+
document_embeddings = None
|
469 |
+
|
470 |
+
# Clear GPU memory
|
471 |
+
if torch.cuda.is_available():
|
472 |
+
torch.cuda.empty_cache()
|
473 |
+
|
474 |
+
return None, "", gr.Tabs(visible=False)
|
475 |
|
476 |
|
477 |
# Create Gradio interface
|
|
|
676 |
input_len = inputs["input_ids"].shape[-1]
|
677 |
|
678 |
with torch.inference_mode():
|
679 |
+
# Clear cache before generation
|
680 |
+
if torch.cuda.is_available():
|
681 |
+
torch.cuda.empty_cache()
|
682 |
+
|
683 |
generation = chatbot_model.generate(
|
684 |
**inputs,
|
685 |
+
max_new_tokens=200, # Reduced from 300 to save memory
|
686 |
do_sample=False,
|
687 |
temperature=0.7,
|
688 |
+
pad_token_id=chatbot_processor.tokenizer.pad_token_id,
|
689 |
+
use_cache=False # Disable KV cache to save memory
|
690 |
)
|
691 |
generation = generation[0][input_len:]
|
692 |
+
|
693 |
+
# Clear cache after generation
|
694 |
+
if torch.cuda.is_available():
|
695 |
+
torch.cuda.empty_cache()
|
696 |
|
697 |
response = chatbot_processor.decode(generation, skip_special_tokens=True)
|
698 |
|