raksama19 committed on
Commit 42f533f · verified · 1 Parent(s): 12128c1

Update app.py

Files changed (1)
  1. app.py +154 -78
app.py CHANGED
@@ -313,23 +313,20 @@ model_path = "./hf_model"
 if not os.path.exists(model_path):
     model_path = "ByteDance/DOLPHIN"

-try:
-    dolphin_model = DOLPHIN(model_path)
-    print(f"Model loaded successfully from {model_path}")
-    model_status = f"✅ Model ready (Device: {dolphin_model.device})"
-except Exception as e:
-    print(f"Error loading model: {e}")
-    dolphin_model = None
-    model_status = f"❌ Model failed to load: {str(e)}"
-
-# Initialize embedding model for RAG
+# Model paths and configuration
+model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
+hf_token = os.getenv('HF_TOKEN')
+
+# Don't load models initially - load them on demand
+model_status = "✅ Models ready (Dynamic loading)"
+
+# Initialize embedding model for RAG (CPU to save GPU memory)
 if RAG_DEPENDENCIES_AVAILABLE:
     try:
         print("Loading embedding model for RAG...")
-        # Use GPU for embedding model with 24GB VRAM
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
-        print(f"✅ Embedding model loaded successfully ({device})")
+        # Use CPU for embedding model to save GPU memory for main models
+        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
+        print("✅ Embedding model loaded successfully (CPU)")
     except Exception as e:
         print(f"❌ Error loading embedding model: {e}")
         import traceback
@@ -339,39 +336,94 @@ else:
     print("❌ RAG dependencies not available")
     embedding_model = None

-# Initialize chatbot model
-try:
-    import os
-    # Get HuggingFace token from environment/secrets
-    hf_token = os.getenv('HF_TOKEN')
-    print(f"HF_TOKEN found: {'Yes' if hf_token else 'No'}")
-
-    if hf_token:
-        print("Loading chatbot model with token...")
-        chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
-            "google/gemma-3n-e4b-it",
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            token=hf_token # Use 'token' instead of 'use_auth_token'
-        ).eval()
-
-        chatbot_processor = AutoProcessor.from_pretrained(
-            "google/gemma-3n-e4b-it",
-            token=hf_token # Use 'token' instead of 'use_auth_token'
-        )
-
-        print("✅ Chatbot model loaded successfully")
-    else:
-        print("❌ No HF_TOKEN found in environment")
+# Model management functions
+def load_dolphin_model():
+    """Load DOLPHIN model for PDF processing"""
+    global dolphin_model, current_model
+
+    if current_model == "dolphin":
+        return dolphin_model
+
+    # Unload chatbot model if loaded
+    unload_chatbot_model()
+
+    try:
+        print("Loading DOLPHIN model...")
+        dolphin_model = DOLPHIN(model_path)
+        current_model = "dolphin"
+        print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")
+        return dolphin_model
+    except Exception as e:
+        print(f"❌ Error loading DOLPHIN model: {e}")
+        return None
+
+def unload_dolphin_model():
+    """Unload DOLPHIN model to free memory"""
+    global dolphin_model, current_model
+
+    if dolphin_model is not None:
+        print("Unloading DOLPHIN model...")
+        del dolphin_model
+        dolphin_model = None
+        if current_model == "dolphin":
+            current_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        print("✅ DOLPHIN model unloaded")
+
+def load_chatbot_model():
+    """Load Gemma chatbot model"""
+    global chatbot_model, chatbot_processor, current_model
+
+    if current_model == "chatbot":
+        return chatbot_model, chatbot_processor
+
+    # Unload DOLPHIN model if loaded
+    unload_dolphin_model()
+
+    try:
+        print("Loading Gemma chatbot model...")
+        print(f"HF_TOKEN found: {'Yes' if hf_token else 'No'}")
+
+        if hf_token:
+            chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
+                "google/gemma-3n-e4b-it",
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                token=hf_token
+            ).eval()
+
+            chatbot_processor = AutoProcessor.from_pretrained(
+                "google/gemma-3n-e4b-it",
+                token=hf_token
+            )
+
+            current_model = "chatbot"
+            print("✅ Gemma chatbot model loaded")
+            return chatbot_model, chatbot_processor
+        else:
+            print("❌ No HF_TOKEN found")
+            return None, None
+    except Exception as e:
+        print(f"❌ Error loading chatbot model: {e}")
+        import traceback
+        traceback.print_exc()
+        return None, None
+
+def unload_chatbot_model():
+    """Unload chatbot model to free memory"""
+    global chatbot_model, chatbot_processor, current_model
+
+    if chatbot_model is not None:
+        print("Unloading Gemma chatbot model...")
+        del chatbot_model, chatbot_processor
         chatbot_model = None
         chatbot_processor = None
-
-except Exception as e:
-    print(f"❌ Error loading chatbot model: {e}")
-    import traceback
-    traceback.print_exc()
-    chatbot_model = None
-    chatbot_processor = None
+        if current_model == "chatbot":
+            current_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        print("✅ Gemma chatbot model unloaded")


 # Global state for managing tabs
@@ -380,10 +432,15 @@ show_results_tab = False
 document_chunks = []
 document_embeddings = None
 embedding_model = None
-# chatbot_model is initialized above
+
+# Global model state - only one model loaded at a time
+dolphin_model = None
+chatbot_model = None
+chatbot_processor = None
+current_model = None # Track which model is currently loaded


-def chunk_document(text, chunk_size=500, overlap=50):
+def chunk_document(text, chunk_size=400, overlap=40):
     """Split document into overlapping chunks for RAG"""
     words = text.split()
     chunks = []
@@ -401,8 +458,8 @@ def create_embeddings(chunks):
         return None

     try:
-        # Process in larger batches with 24GB GPU
-        batch_size = 64
+        # Process in smaller batches on CPU
+        batch_size = 32
         embeddings = []

         for i in range(0, len(chunks), batch_size):
@@ -415,10 +472,10 @@ def create_embeddings(chunks):
         print(f"Error creating embeddings: {e}")
         return None

-def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
+def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:3] # Fallback to first 3 chunks
+        return chunks[:2] # Fallback to first 2 chunks

     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
@@ -431,32 +488,43 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
         return relevant_chunks
     except Exception as e:
         print(f"Error retrieving chunks: {e}")
-        return chunks[:3] # Fallback
+        return chunks[:2] # Fallback

 def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
     """Main processing function for uploaded PDF"""
     global processed_markdown, show_results_tab, document_chunks, document_embeddings

-    if dolphin_model is None:
-        return "❌ Model not loaded", gr.Tabs(visible=False)
-
     if pdf_file is None:
         return "❌ No PDF uploaded", gr.Tabs(visible=False)

     try:
-        combined_markdown, status = process_pdf_document(pdf_file, dolphin_model, progress)
+        # Load DOLPHIN model for PDF processing
+        progress(0.1, desc="Loading DOLPHIN model...")
+        dolphin = load_dolphin_model()
+
+        if dolphin is None:
+            return "❌ Failed to load DOLPHIN model", gr.Tabs(visible=False)
+
+        # Process PDF
+        progress(0.2, desc="Processing PDF...")
+        combined_markdown, status = process_pdf_document(pdf_file, dolphin, progress)

         if status == "processing_complete":
             processed_markdown = combined_markdown

             # Create chunks and embeddings for RAG
-            print("Creating document chunks for RAG...")
+            progress(0.9, desc="Creating document chunks for RAG...")
             document_chunks = chunk_document(processed_markdown)
             document_embeddings = create_embeddings(document_chunks)
             print(f"Created {len(document_chunks)} chunks")

+            # Unload DOLPHIN model to free memory for chatbot
+            progress(0.95, desc="Preparing chatbot...")
+            unload_dolphin_model()
+
             show_results_tab = True
-            return "✅ PDF processed successfully! Check the 'Document' tab above.", gr.Tabs(visible=True)
+            progress(1.0, desc="PDF processed successfully!")
+            return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
         else:
             show_results_tab = False
             return combined_markdown, gr.Tabs(visible=False)
@@ -481,7 +549,11 @@ def clear_all():
     document_chunks = []
     document_embeddings = None

-    return None, "", gr.Tabs(visible=False)
+    # Unload any loaded models
+    unload_dolphin_model()
+    unload_chatbot_model()
+
+    return None, "✅ Ready to process your PDF", gr.Tabs(visible=False)


 # Create Gradio interface
@@ -535,14 +607,14 @@ with gr.Blocks(
     with gr.Tabs() as main_tabs:
         # Home Tab
         with gr.TabItem("🏠 Home", id="home"):
-            chatbot_status = "✅ Chatbot ready" if chatbot_model else "❌ Chatbot not loaded"
             embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
+            current_status = f"Currently loaded: {current_model or 'None'}"
             gr.Markdown(
                 "# Scholar Express\n"
-                "### Upload a research paper to get a web-friendly version, an AI chatbot, and a podcast summary. Because of our reliance on Generative AI, some errors are inevitable.\n"
-                f"**PDF Processing:** {model_status}\n"
-                f"**Chatbot:** {chatbot_status}\n"
-                f"**RAG System:** {embedding_status}"
+                "### Upload a research paper to get a web-friendly version, an AI chatbot, and a podcast summary. Models are loaded dynamically to optimize memory usage.\n"
+                f"**System:** {model_status}\n"
+                f"**RAG System:** {embedding_status}\n"
+                f"**Status:** {current_status}"
             )

             with gr.Column(elem_classes="upload-container"):
@@ -647,56 +719,60 @@ with gr.Blocks(
         if not message.strip():
             return history

-        if chatbot_model is None:
-            return history + [[message, "❌ Chatbot model not loaded. Please check your HuggingFace token."]]
-
         if not processed_markdown:
             return history + [[message, "❌ Please process a PDF document first before asking questions."]]

         try:
+            # Load chatbot model
+            model, processor = load_chatbot_model()
+
+            if model is None or processor is None:
+                return history + [[message, "❌ Failed to load chatbot model. Please check your HuggingFace token."]]
+
             # Use RAG to get relevant chunks instead of full document
             if document_chunks and len(document_chunks) > 0:
                 relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
                 context = "\n\n".join(relevant_chunks)
             else:
                 # Fallback to truncated document if RAG fails
-                context = processed_markdown[:1500] + "..." if len(processed_markdown) > 1500 else processed_markdown
+                context = processed_markdown[:1200] + "..." if len(processed_markdown) > 1200 else processed_markdown

-            # Create chat messages
+            # Create chat messages with shorter context
             messages = [
                 {
                     "role": "system",
-                    "content": [{"type": "text", "text": "You are a helpful assistant that answers questions about documents. Use the provided document content to answer questions accurately and concisely."}]
+                    "content": [{"type": "text", "text": "You are a helpful assistant. Answer questions about the document concisely."}]
                 },
                 {
                     "role": "user",
-                    "content": [{"type": "text", "text": f"Document content:\n{context}\n\nQuestion: {message}"}]
+                    "content": [{"type": "text", "text": f"Context:\n{context}\n\nQ: {message}"}]
                 }
             ]

             # Process with the model
-            inputs = chatbot_processor.apply_chat_template(
+            inputs = processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt",
-            ).to(chatbot_model.device)
+            ).to(model.device)

             input_len = inputs["input_ids"].shape[-1]

             with torch.inference_mode():
-                generation = chatbot_model.generate(
+                generation = model.generate(
                     **inputs,
-                    max_new_tokens=400, # Increased for 24GB GPU
+                    max_new_tokens=300, # Can be higher now with single model
                     do_sample=False,
                     temperature=0.7,
-                    pad_token_id=chatbot_processor.tokenizer.pad_token_id,
-                    use_cache=True # Enable KV cache with more VRAM
+                    pad_token_id=processor.tokenizer.pad_token_id,
+                    use_cache=True, # Can enable cache with single model
+                    num_beams=1
                 )
                 generation = generation[0][input_len:]

-            response = chatbot_processor.decode(generation, skip_special_tokens=True)
+            response = processor.decode(generation, skip_special_tokens=True)

             return history + [[message, response]]