raksama19 commited on
Commit
d6f868f
·
verified ·
1 Parent(s): d3e09b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -13
app.py CHANGED
@@ -318,8 +318,9 @@ except Exception as e:
318
  # Initialize embedding model for RAG
319
  try:
320
  print("Loading embedding model...")
321
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
322
- print("✅ Embedding model loaded successfully")
 
323
  except Exception as e:
324
  print(f"❌ Error loading embedding model: {e}")
325
  embedding_model = None
@@ -368,7 +369,7 @@ embedding_model = None
368
  # chatbot_model is initialized above
369
 
370
 
371
- def chunk_document(text, chunk_size=500, overlap=50):
372
  """Split document into overlapping chunks for RAG"""
373
  words = text.split()
374
  chunks = []
@@ -386,19 +387,27 @@ def create_embeddings(chunks):
386
  return None
387
 
388
  try:
389
- embeddings = embedding_model.encode(chunks)
390
- return embeddings
 
 
 
 
 
 
 
 
391
  except Exception as e:
392
  print(f"Error creating embeddings: {e}")
393
  return None
394
 
395
- def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
396
  """Retrieve most relevant chunks for a question"""
397
  if embedding_model is None or embeddings is None:
398
- return chunks[:3] # Fallback to first 3 chunks
399
 
400
  try:
401
- question_embedding = embedding_model.encode([question])
402
  similarities = cosine_similarity(question_embedding, embeddings)[0]
403
 
404
  # Get top-k most similar chunks
@@ -408,7 +417,7 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
408
  return relevant_chunks
409
  except Exception as e:
410
  print(f"Error retrieving chunks: {e}")
411
- return chunks[:3] # Fallback
412
 
413
  def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
414
  """Main processing function for uploaded PDF"""
@@ -452,10 +461,17 @@ def get_processed_markdown():
452
 
453
  def clear_all():
454
  """Clear all data and hide results tab"""
455
- global processed_markdown, show_results_tab
456
  processed_markdown = ""
457
  show_results_tab = False
458
- return None, "✅ Ready to process your PDF", gr.Tabs(visible=False)
 
 
 
 
 
 
 
459
 
460
 
461
  # Create Gradio interface
@@ -660,14 +676,23 @@ with gr.Blocks(
660
  input_len = inputs["input_ids"].shape[-1]
661
 
662
  with torch.inference_mode():
 
 
 
 
663
  generation = chatbot_model.generate(
664
  **inputs,
665
- max_new_tokens=300,
666
  do_sample=False,
667
  temperature=0.7,
668
- pad_token_id=chatbot_processor.tokenizer.pad_token_id
 
669
  )
670
  generation = generation[0][input_len:]
 
 
 
 
671
 
672
  response = chatbot_processor.decode(generation, skip_special_tokens=True)
673
 
 
318
  # Initialize embedding model for RAG
319
  try:
320
  print("Loading embedding model...")
321
+ # Force CPU for embedding model to save GPU memory
322
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
323
+ print("✅ Embedding model loaded successfully (CPU)")
324
  except Exception as e:
325
  print(f"❌ Error loading embedding model: {e}")
326
  embedding_model = None
 
369
  # chatbot_model is initialized above
370
 
371
 
372
+ def chunk_document(text, chunk_size=300, overlap=30):
373
  """Split document into overlapping chunks for RAG"""
374
  words = text.split()
375
  chunks = []
 
387
  return None
388
 
389
  try:
390
+ # Process in smaller batches to avoid memory issues
391
+ batch_size = 32
392
+ embeddings = []
393
+
394
+ for i in range(0, len(chunks), batch_size):
395
+ batch = chunks[i:i + batch_size]
396
+ batch_embeddings = embedding_model.encode(batch, show_progress_bar=False)
397
+ embeddings.extend(batch_embeddings)
398
+
399
+ return np.array(embeddings)
400
  except Exception as e:
401
  print(f"Error creating embeddings: {e}")
402
  return None
403
 
404
+ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
405
  """Retrieve most relevant chunks for a question"""
406
  if embedding_model is None or embeddings is None:
407
+ return chunks[:2] # Fallback to first 2 chunks (reduced from 3)
408
 
409
  try:
410
+ question_embedding = embedding_model.encode([question], show_progress_bar=False)
411
  similarities = cosine_similarity(question_embedding, embeddings)[0]
412
 
413
  # Get top-k most similar chunks
 
417
  return relevant_chunks
418
  except Exception as e:
419
  print(f"Error retrieving chunks: {e}")
420
+ return chunks[:2] # Fallback
421
 
422
  def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
423
  """Main processing function for uploaded PDF"""
 
461
 
462
  def clear_all():
463
  """Clear all data and hide results tab"""
464
+ global processed_markdown, show_results_tab, document_chunks, document_embeddings
465
  processed_markdown = ""
466
  show_results_tab = False
467
+ document_chunks = []
468
+ document_embeddings = None
469
+
470
+ # Clear GPU memory
471
+ if torch.cuda.is_available():
472
+ torch.cuda.empty_cache()
473
+
474
+ return None, "", gr.Tabs(visible=False)
475
 
476
 
477
  # Create Gradio interface
 
676
  input_len = inputs["input_ids"].shape[-1]
677
 
678
  with torch.inference_mode():
679
+ # Clear cache before generation
680
+ if torch.cuda.is_available():
681
+ torch.cuda.empty_cache()
682
+
683
  generation = chatbot_model.generate(
684
  **inputs,
685
+ max_new_tokens=200, # Reduced from 300 to save memory
686
  do_sample=False,
687
  temperature=0.7,
688
+ pad_token_id=chatbot_processor.tokenizer.pad_token_id,
689
+ use_cache=False # Disable KV cache to save memory
690
  )
691
  generation = generation[0][input_len:]
692
+
693
+ # Clear cache after generation
694
+ if torch.cuda.is_available():
695
+ torch.cuda.empty_cache()
696
 
697
  response = chatbot_processor.decode(generation, skip_special_tokens=True)
698