raksama19 committed
Commit 12128c1 · verified · 1 Parent(s): d6f868f

Update app.py

Files changed (1)
  1. app.py +32 -30
app.py CHANGED
@@ -11,9 +11,16 @@ import numpy as np
 from PIL import Image
 from transformers import AutoProcessor, VisionEncoderDecoderModel, Gemma3nForConditionalGeneration, pipeline
 import torch
-from sentence_transformers import SentenceTransformer
-import numpy as np
-from sklearn.metrics.pairwise import cosine_similarity
+try:
+    from sentence_transformers import SentenceTransformer
+    import numpy as np
+    from sklearn.metrics.pairwise import cosine_similarity
+    RAG_DEPENDENCIES_AVAILABLE = True
+except ImportError as e:
+    print(f"RAG dependencies not available: {e}")
+    print("Please install: pip install sentence-transformers scikit-learn")
+    RAG_DEPENDENCIES_AVAILABLE = False
+    SentenceTransformer = None
 import os
 import tempfile
 import uuid
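Reviewer note: the guarded import above exposes a module flag, RAG_DEPENDENCIES_AVAILABLE, that the rest of app.py can check before touching any RAG code path. A minimal sketch of that gating pattern follows; build_retriever is a hypothetical helper for illustration, not part of the commit.

```python
# Sketch only: gating optional RAG features on the flag introduced above.
# build_retriever is hypothetical; app.py gates its embedding-model setup instead.
try:
    from sentence_transformers import SentenceTransformer
    RAG_DEPENDENCIES_AVAILABLE = True
except ImportError:
    RAG_DEPENDENCIES_AVAILABLE = False
    SentenceTransformer = None

def build_retriever():
    """Return an embedding model when RAG deps are installed, else None."""
    if not RAG_DEPENDENCIES_AVAILABLE:
        return None  # callers fall back to non-RAG behaviour
    return SentenceTransformer('all-MiniLM-L6-v2')
```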
@@ -316,13 +323,20 @@ except Exception as e:
     model_status = f"❌ Model failed to load: {str(e)}"
 
 # Initialize embedding model for RAG
-try:
-    print("Loading embedding model...")
-    # Force CPU for embedding model to save GPU memory
-    embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
-    print("✅ Embedding model loaded successfully (CPU)")
-except Exception as e:
-    print(f"❌ Error loading embedding model: {e}")
+if RAG_DEPENDENCIES_AVAILABLE:
+    try:
+        print("Loading embedding model for RAG...")
+        # Use GPU for embedding model with 24GB VRAM
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        print(f"✅ Embedding model loaded successfully ({device})")
+    except Exception as e:
+        print(f"❌ Error loading embedding model: {e}")
+        import traceback
+        traceback.print_exc()
+        embedding_model = None
+else:
+    print("❌ RAG dependencies not available")
     embedding_model = None
 
 # Initialize chatbot model
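Reviewer note: the embedding model is no longer pinned to CPU; it now follows CUDA availability. A quick sanity-check sketch of the device-aware load and a test encode; nothing below comes from the commit beyond the model name and the device logic.

```python
# Sketch: device-aware load plus a test encode. all-MiniLM-L6-v2 produces
# 384-dimensional embeddings; encode() returns a numpy array by default.
import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
vectors = model.encode(["hello world"], show_progress_bar=False)
print(vectors.shape)  # (1, 384)
```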
@@ -369,7 +383,7 @@ embedding_model = None
 # chatbot_model is initialized above
 
 
-def chunk_document(text, chunk_size=300, overlap=30):
+def chunk_document(text, chunk_size=500, overlap=50):
     """Split document into overlapping chunks for RAG"""
     words = text.split()
     chunks = []
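Reviewer note: only the chunking defaults change here (300/30 words up to 500/50 words), i.e. larger chunks with proportionally larger overlap. The hunk cuts off before the loop, so the overlapping word-window chunker sketched below is an assumption about the rest of the function, not the commit's code.

```python
def chunk_document_sketch(text, chunk_size=500, overlap=50):
    """Overlapping word-window chunking consistent with the new defaults."""
    words = text.split()
    chunks = []
    step = chunk_size - overlap          # advance 450 words per chunk
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if chunk:
            chunks.append(chunk)
        if start + chunk_size >= len(words):
            break                        # last window already covers the tail
    return chunks
```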
@@ -387,8 +401,8 @@ def create_embeddings(chunks):
         return None
 
     try:
-        # Process in smaller batches to avoid memory issues
-        batch_size = 32
+        # Process in larger batches with 24GB GPU
+        batch_size = 64
        embeddings = []
 
         for i in range(0, len(chunks), batch_size):
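Reviewer note: only the batch size changes (32 to 64). For context, here is a self-contained sketch of the batched encode loop this tunes; the real create_embeddings uses the module-level embedding_model and its own error handling, so the parametrized version below is an assumption.

```python
import numpy as np

def create_embeddings_sketch(chunks, model, batch_size=64):
    """Encode chunks batch by batch; returns an (n_chunks, dim) array or None."""
    if model is None or not chunks:
        return None
    parts = []
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        parts.append(model.encode(batch, show_progress_bar=False))
    return np.vstack(parts)
```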
@@ -401,10 +415,10 @@ def create_embeddings(chunks):
         print(f"Error creating embeddings: {e}")
         return None
 
-def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
+def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:2]  # Fallback to first 2 chunks (reduced from 3)
+        return chunks[:3]  # Fallback to first 3 chunks
 
     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
@@ -417,7 +431,7 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
         return relevant_chunks
     except Exception as e:
         print(f"Error retrieving chunks: {e}")
-        return chunks[:2]  # Fallback
+        return chunks[:3]  # Fallback
 
 def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
     """Main processing function for uploaded PDF"""
@@ -467,10 +481,6 @@ def clear_all():
     document_chunks = []
     document_embeddings = None
 
-    # Clear GPU memory
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
     return None, "", gr.Tabs(visible=False)
 
 
@@ -676,23 +686,15 @@ with gr.Blocks(
         input_len = inputs["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            # Clear cache before generation
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
             generation = chatbot_model.generate(
                 **inputs,
-                max_new_tokens=200,  # Reduced from 300 to save memory
+                max_new_tokens=400,  # Increased for 24GB GPU
                 do_sample=False,
                 temperature=0.7,
                 pad_token_id=chatbot_processor.tokenizer.pad_token_id,
-                use_cache=False  # Disable KV cache to save memory
+                use_cache=True  # Enable KV cache with more VRAM
             )
             generation = generation[0][input_len:]
-
-            # Clear cache after generation
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
 
         response = chatbot_processor.decode(generation, skip_special_tokens=True)
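Reviewer note: max_new_tokens goes up and the KV cache is re-enabled, trading VRAM for much faster decoding; the manual empty_cache() calls around generation are dropped. One observation: with do_sample=False generation is greedy, so the temperature=0.7 argument has no effect (recent transformers versions log a warning about it). A minimal sketch of an equivalent warning-free call, reusing the objects already defined in app.py:

```python
# Sketch only: same greedy-decoding setup without the unused temperature argument.
# chatbot_model, chatbot_processor and inputs are the objects built earlier in app.py.
with torch.inference_mode():
    output = chatbot_model.generate(
        **inputs,
        max_new_tokens=400,
        do_sample=False,   # greedy decoding; sampling parameters are ignored
        use_cache=True,    # KV cache: more VRAM, much faster token generation
        pad_token_id=chatbot_processor.tokenizer.pad_token_id,
    )
reply = chatbot_processor.decode(output[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
```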
 
 