raksama19 committed on
Commit 2b9109f · verified · 1 Parent(s): c7f802d

Update app.py

Files changed (1)
  1. app.py +66 -105
app.py CHANGED
@@ -15,10 +15,11 @@ try:
     from sentence_transformers import SentenceTransformer
     import numpy as np
     from sklearn.metrics.pairwise import cosine_similarity
+    import google.generativeai as genai
     RAG_DEPENDENCIES_AVAILABLE = True
 except ImportError as e:
     print(f"RAG dependencies not available: {e}")
-    print("Please install: pip install sentence-transformers scikit-learn")
+    print("Please install: pip install sentence-transformers scikit-learn google-generativeai")
     RAG_DEPENDENCIES_AVAILABLE = False
     SentenceTransformer = None
 import os
@@ -320,21 +321,32 @@ hf_token = os.getenv('HF_TOKEN')
 # Don't load models initially - load them on demand
 model_status = "✅ Models ready (Dynamic loading)"
 
-# Initialize embedding model for RAG (CPU to save GPU memory)
+# Initialize embedding model and Gemini API
 if RAG_DEPENDENCIES_AVAILABLE:
     try:
         print("Loading embedding model for RAG...")
-        # Use CPU for embedding model to save GPU memory for main models
         embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
         print("✅ Embedding model loaded successfully (CPU)")
+
+        # Initialize Gemini API
+        gemini_api_key = os.getenv('GEMINI_API_KEY')
+        if gemini_api_key:
+            genai.configure(api_key=gemini_api_key)
+            gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
+            print("✅ Gemini API configured successfully")
+        else:
+            print("❌ GEMINI_API_KEY not found in environment")
+            gemini_model = None
     except Exception as e:
-        print(f"❌ Error loading embedding model: {e}")
+        print(f"❌ Error loading models: {e}")
         import traceback
         traceback.print_exc()
         embedding_model = None
+        gemini_model = None
 else:
     print("❌ RAG dependencies not available")
     embedding_model = None
+    gemini_model = None
 
 # Model management functions
 def load_dolphin_model():
@@ -371,59 +383,29 @@ def unload_dolphin_model():
         torch.cuda.empty_cache()
     print("✅ DOLPHIN model unloaded")
 
-def load_chatbot_model():
-    """Load Gemma chatbot model"""
-    global chatbot_model, chatbot_processor, current_model
-
-    if current_model == "chatbot":
-        return chatbot_model, chatbot_processor
+def initialize_gemini_model():
+    """Initialize Gemini API model"""
+    global gemini_model
 
-    # Unload DOLPHIN model if loaded
-    unload_dolphin_model()
+    if gemini_model is not None:
+        return gemini_model
 
     try:
-        print("Loading Gemma chatbot model...")
-        print(f"HF_TOKEN found: {'Yes' if hf_token else 'No'}")
-
-        if hf_token:
-            chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
-                "google/gemma-3n-e4b-it",
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                token=hf_token
-            ).eval()
-
-            chatbot_processor = AutoProcessor.from_pretrained(
-                "google/gemma-3n-e4b-it",
-                token=hf_token
-            )
-
-            current_model = "chatbot"
-            print("✅ Gemma chatbot model loaded")
-            return chatbot_model, chatbot_processor
-        else:
-            print("❌ No HF_TOKEN found")
-            return None, None
+        gemini_api_key = os.getenv('GEMINI_API_KEY')
+        if not gemini_api_key:
+            print("❌ GEMINI_API_KEY not found in environment")
+            return None
+
+        print("Initializing Gemini API...")
+        genai.configure(api_key=gemini_api_key)
+        gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
+        print("✅ Gemini API model ready")
+        return gemini_model
     except Exception as e:
-        print(f"❌ Error loading chatbot model: {e}")
+        print(f"❌ Error initializing Gemini model: {e}")
         import traceback
         traceback.print_exc()
-        return None, None
-
-def unload_chatbot_model():
-    """Unload chatbot model to free memory"""
-    global chatbot_model, chatbot_processor, current_model
-
-    if chatbot_model is not None:
-        print("Unloading Gemma chatbot model...")
-        del chatbot_model, chatbot_processor
-        chatbot_model = None
-        chatbot_processor = None
-        if current_model == "chatbot":
-            current_model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        print("✅ Gemma chatbot model unloaded")
+        return None
 
 
 # Global state for managing tabs
@@ -431,12 +413,10 @@ processed_markdown = ""
 show_results_tab = False
 document_chunks = []
 document_embeddings = None
-embedding_model = None
 
-# Global model state - only one model loaded at a time
+# Global model state
 dolphin_model = None
-chatbot_model = None
-chatbot_processor = None
+gemini_model = None
 current_model = None # Track which model is currently loaded
 
 
@@ -518,9 +498,8 @@ def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
         document_embeddings = create_embeddings(document_chunks)
         print(f"Created {len(document_chunks)} chunks")
 
-        # Unload DOLPHIN model to free memory for chatbot
+        # Keep DOLPHIN model loaded for GPU usage
        progress(0.95, desc="Preparing chatbot...")
-        unload_dolphin_model()
 
        show_results_tab = True
        progress(1.0, desc="PDF processed successfully!")
@@ -549,11 +528,10 @@ def clear_all():
    document_chunks = []
    document_embeddings = None
 
-    # Unload any loaded models
+    # Unload DOLPHIN model
    unload_dolphin_model()
 
-    return None, "✅ Ready to process your PDF", gr.Tabs(visible=False)
+    return None, "", gr.Tabs(visible=False)
 
 
 # Create Gradio interface
@@ -608,12 +586,14 @@ with gr.Blocks(
        # Home Tab
        with gr.TabItem("🏠 Home", id="home"):
            embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
+            gemini_status = "✅ Gemini API ready" if gemini_model else "❌ Gemini API not configured"
            current_status = f"Currently loaded: {current_model or 'None'}"
            gr.Markdown(
                "# Scholar Express\n"
-                "### Upload a research paper to get a web-friendly version, an AI chatbot, and a podcast summary. Models are loaded dynamically to optimize memory usage.\n"
+                "### Upload a research paper to get a web-friendly version and an AI chatbot powered by Gemini API. DOLPHIN model runs on GPU for optimal performance.\n"
                f"**System:** {model_status}\n"
                f"**RAG System:** {embedding_status}\n"
+                f"**Gemini API:** {gemini_status}\n"
                f"**Status:** {current_status}"
            )
 
@@ -648,7 +628,7 @@ with gr.Blocks(
 
            # Status output (hidden during processing)
            status_output = gr.Markdown(
-                "✅ Ready to process your PDF",
+                "",
                elem_classes="status-message"
            )
 
@@ -685,7 +665,7 @@ with gr.Blocks(
                send_btn = gr.Button("Send", variant="primary", scale=1)
 
            gr.Markdown(
-                "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) to find relevant sections and provide accurate answers.*",
+                "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with Gemini API to find relevant sections and provide accurate answers.*",
                elem_id="chat-notice"
            )
 
@@ -714,7 +694,7 @@ with gr.Blocks(
        outputs=[chat_tab]
    )
 
-    # Chatbot functionality
+    # Chatbot functionality with Gemini API
    def chatbot_response(message, history):
        if not message.strip():
            return history
@@ -723,61 +703,42 @@ with gr.Blocks(
            return history + [[message, "❌ Please process a PDF document first before asking questions."]]
 
        try:
-            # Load chatbot model
-            model, processor = load_chatbot_model()
+            # Initialize Gemini model
+            model = initialize_gemini_model()
 
-            if model is None or processor is None:
-                return history + [[message, "❌ Failed to load chatbot model. Please check your HuggingFace token."]]
+            if model is None:
+                return history + [[message, "❌ Failed to initialize Gemini model. Please check your GEMINI_API_KEY."]]
 
-            # Use RAG to get relevant chunks instead of full document
+            # Use RAG to get relevant chunks from markdown
            if document_chunks and len(document_chunks) > 0:
                relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
                context = "\n\n".join(relevant_chunks)
            else:
                # Fallback to truncated document if RAG fails
-                context = processed_markdown[:1200] + "..." if len(processed_markdown) > 1200 else processed_markdown
-
-            # Create chat messages with shorter context
-            messages = [
-                {
-                    "role": "system",
-                    "content": [{"type": "text", "text": "You are a helpful assistant. Answer questions about the document concisely."}]
-                },
-                {
-                    "role": "user",
-                    "content": [{"type": "text", "text": f"Context:\n{context}\n\nQ: {message}"}]
-                }
-            ]
+                context = processed_markdown[:2000] + "..." if len(processed_markdown) > 2000 else processed_markdown
 
-            # Process with the model
-            inputs = processor.apply_chat_template(
-                messages,
-                add_generation_prompt=True,
-                tokenize=True,
-                return_dict=True,
-                return_tensors="pt",
-            ).to(model.device)
-
-            input_len = inputs["input_ids"].shape[-1]
+            # Create prompt for Gemini
+            prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
+
+Context from the document:
+{context}
+
+Question: {message}
+
+Please provide a clear and helpful answer based on the context provided."""
 
-            with torch.inference_mode():
-                generation = model.generate(
-                    **inputs,
-                    max_new_tokens=300, # Can be higher now with single model
-                    do_sample=False,
-                    temperature=0.7,
-                    pad_token_id=processor.tokenizer.pad_token_id,
-                    use_cache=True, # Can enable cache with single model
-                    num_beams=1
-                )
-                generation = generation[0][input_len:]
+            # Generate response using Gemini API
+            response = model.generate_content(prompt)
 
-            response = processor.decode(generation, skip_special_tokens=True)
+            response_text = response.text if hasattr(response, 'text') else str(response)
 
-            return history + [[message, response]]
+            return history + [[message, response_text]]
 
        except Exception as e:
            error_msg = f"❌ Error generating response: {str(e)}"
+            print(f"Full error: {e}")
+            import traceback
+            traceback.print_exc()
            return history + [[message, error_msg]]
 
    send_btn.click(
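
In short, the commit swaps the locally loaded Gemma chatbot for the hosted Gemini API: document chunks are still embedded with all-MiniLM-L6-v2 and ranked by cosine similarity, but the answer now comes from a single generate_content call to gemma-3n-e4b-it. Below is a minimal, self-contained sketch of that flow. The app's retrieve_relevant_chunks/create_embeddings helpers are not shown in this diff, so the inline ranking and the top_k value here are illustrative assumptions, not the app's exact code.

# Minimal sketch of the RAG + Gemini flow this commit switches to.
# Assumes GEMINI_API_KEY is set; chunk ranking and top_k are illustrative stand-ins
# for the app's retrieve_relevant_chunks/create_embeddings helpers (not in this diff).
import os
import numpy as np
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
gemini = genai.GenerativeModel("gemma-3n-e4b-it")
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

def answer(question, chunks, top_k=3):
    # Rank document chunks by cosine similarity to the question embedding.
    chunk_vecs = embedder.encode(chunks)
    query_vec = embedder.encode([question])
    scores = cosine_similarity(query_vec, chunk_vecs)[0]
    best = [chunks[i] for i in np.argsort(scores)[::-1][:top_k]]
    context = "\n\n".join(best)

    # Single-prompt call to the Gemini API, mirroring chatbot_response() above.
    prompt = (
        "You are a helpful assistant that answers questions about documents. "
        "Use the provided context to answer accurately and concisely.\n\n"
        f"Context from the document:\n{context}\n\n"
        f"Question: {question}"
    )
    response = gemini.generate_content(prompt)
    return response.text if hasattr(response, "text") else str(response)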