Update app.py
app.py
CHANGED
@@ -313,23 +313,20 @@ model_path = "./hf_model"
 if not os.path.exists(model_path):
     model_path = "ByteDance/DOLPHIN"

-
-
-
-
-
-
-
-
-
-# Initialize embedding model for RAG
+# Model paths and configuration
+model_path = "./hf_model" if os.path.exists("./hf_model") else "ByteDance/DOLPHIN"
+hf_token = os.getenv('HF_TOKEN')
+
+# Don't load models initially - load them on demand
+model_status = "✅ Models ready (Dynamic loading)"
+
+# Initialize embedding model for RAG (CPU to save GPU memory)
 if RAG_DEPENDENCIES_AVAILABLE:
     try:
         print("Loading embedding model for RAG...")
-        # Use
-
-
-        print(f"✅ Embedding model loaded successfully ({device})")
+        # Use CPU for embedding model to save GPU memory for main models
+        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
+        print("✅ Embedding model loaded successfully (CPU)")
     except Exception as e:
         print(f"❌ Error loading embedding model: {e}")
         import traceback
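Note on the hunk above: the SentenceTransformer embedder is now pinned to the CPU so GPU memory stays free for the DOLPHIN and Gemma models. A minimal sketch of how such a CPU-pinned encoder behaves (names below are illustrative, not from app.py):

from sentence_transformers import SentenceTransformer

# Runs entirely on the CPU; the GPU stays free for the large models.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
vectors = embedder.encode(["What is the main contribution?"], show_progress_bar=False)
print(vectors.shape)  # (1, 384) for this model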
@@ -339,39 +336,94 @@ else:
     print("❌ RAG dependencies not available")
     embedding_model = None

-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Model management functions
+def load_dolphin_model():
+    """Load DOLPHIN model for PDF processing"""
+    global dolphin_model, current_model
+
+    if current_model == "dolphin":
+        return dolphin_model
+
+    # Unload chatbot model if loaded
+    unload_chatbot_model()
+
+    try:
+        print("Loading DOLPHIN model...")
+        dolphin_model = DOLPHIN(model_path)
+        current_model = "dolphin"
+        print(f"✅ DOLPHIN model loaded (Device: {dolphin_model.device})")
+        return dolphin_model
+    except Exception as e:
+        print(f"❌ Error loading DOLPHIN model: {e}")
+        return None
+
+def unload_dolphin_model():
+    """Unload DOLPHIN model to free memory"""
+    global dolphin_model, current_model
+
+    if dolphin_model is not None:
+        print("Unloading DOLPHIN model...")
+        del dolphin_model
+        dolphin_model = None
+        if current_model == "dolphin":
+            current_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        print("✅ DOLPHIN model unloaded")
+
+def load_chatbot_model():
+    """Load Gemma chatbot model"""
+    global chatbot_model, chatbot_processor, current_model
+
+    if current_model == "chatbot":
+        return chatbot_model, chatbot_processor
+
+    # Unload DOLPHIN model if loaded
+    unload_dolphin_model()
+
+    try:
+        print("Loading Gemma chatbot model...")
+        print(f"HF_TOKEN found: {'Yes' if hf_token else 'No'}")
+
+        if hf_token:
+            chatbot_model = Gemma3nForConditionalGeneration.from_pretrained(
+                "google/gemma-3n-e4b-it",
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                token=hf_token
+            ).eval()
+
+            chatbot_processor = AutoProcessor.from_pretrained(
+                "google/gemma-3n-e4b-it",
+                token=hf_token
+            )
+
+            current_model = "chatbot"
+            print("✅ Gemma chatbot model loaded")
+            return chatbot_model, chatbot_processor
+        else:
+            print("❌ No HF_TOKEN found")
+            return None, None
+    except Exception as e:
+        print(f"❌ Error loading chatbot model: {e}")
+        import traceback
+        traceback.print_exc()
+        return None, None
+
+def unload_chatbot_model():
+    """Unload chatbot model to free memory"""
+    global chatbot_model, chatbot_processor, current_model
+
+    if chatbot_model is not None:
+        print("Unloading Gemma chatbot model...")
+        del chatbot_model, chatbot_processor
         chatbot_model = None
         chatbot_processor = None
-
-
-
-
-
-    chatbot_model = None
-    chatbot_processor = None
+        if current_model == "chatbot":
+            current_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        print("✅ Gemma chatbot model unloaded")


 # Global state for managing tabs
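Note: the four functions added above keep at most one large model resident at a time, deleting the old one and emptying the CUDA cache before the next load. A compact sketch of that swap pattern with hypothetical names (swap_in, _resident), for reference only:

import gc
import torch

_resident = {"name": None, "model": None}

def swap_in(name, loader):
    """Load `name` via loader() after evicting whatever is currently resident."""
    if _resident["name"] == name:
        return _resident["model"]
    if _resident["model"] is not None:
        _resident["model"] = None          # drop the last reference
        gc.collect()                       # let Python release the weights
        if torch.cuda.is_available():
            torch.cuda.empty_cache()       # hand freed blocks back to the driver
    _resident["name"], _resident["model"] = name, loader()
    return _resident["model"]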
@@ -380,10 +432,15 @@ show_results_tab = False
 document_chunks = []
 document_embeddings = None
 embedding_model = None
-
+
+# Global model state - only one model loaded at a time
+dolphin_model = None
+chatbot_model = None
+chatbot_processor = None
+current_model = None  # Track which model is currently loaded


-def chunk_document(text, chunk_size=
+def chunk_document(text, chunk_size=400, overlap=40):
     """Split document into overlapping chunks for RAG"""
     words = text.split()
     chunks = []
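Note: only the new signature of chunk_document (chunk_size=400, overlap=40) and its first lines appear in this hunk. Below is a minimal sketch of a word-window chunker consistent with those parameters; the real body is outside the diff, so the loop is an assumption:

def chunk_document_sketch(text, chunk_size=400, overlap=40):
    """Split text into word windows of chunk_size words, with `overlap` words shared between neighbours."""
    words = text.split()
    chunks = []
    step = max(chunk_size - overlap, 1)
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if chunk:
            chunks.append(chunk)
    return chunks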
@@ -401,8 +458,8 @@ def create_embeddings(chunks):
         return None

     try:
-        # Process in
-        batch_size =
+        # Process in smaller batches on CPU
+        batch_size = 32
         embeddings = []

         for i in range(0, len(chunks), batch_size):
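Note: the hunk above sets batch_size = 32 but the body of the batching loop is not shown. A sketch of what such a loop typically looks like with a SentenceTransformer encoder; np.vstack and the encode() call are assumptions, not lines from app.py:

import numpy as np

def create_embeddings_sketch(chunks, embedder, batch_size=32):
    embeddings = []
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        # encode() returns a (len(batch), dim) numpy array by default
        embeddings.append(embedder.encode(batch, show_progress_bar=False))
    return np.vstack(embeddings) if embeddings else None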
@@ -415,10 +472,10 @@ def create_embeddings(chunks):
         print(f"Error creating embeddings: {e}")
         return None

-def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
+def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:
+        return chunks[:2]  # Fallback to first 2 chunks

     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
@@ -431,32 +488,43 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
         return relevant_chunks
     except Exception as e:
         print(f"Error retrieving chunks: {e}")
-        return chunks[:
+        return chunks[:2]  # Fallback

 def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
     """Main processing function for uploaded PDF"""
     global processed_markdown, show_results_tab, document_chunks, document_embeddings

-    if dolphin_model is None:
-        return "❌ Model not loaded", gr.Tabs(visible=False)
-
     if pdf_file is None:
         return "❌ No PDF uploaded", gr.Tabs(visible=False)

     try:
-
+        # Load DOLPHIN model for PDF processing
+        progress(0.1, desc="Loading DOLPHIN model...")
+        dolphin = load_dolphin_model()
+
+        if dolphin is None:
+            return "❌ Failed to load DOLPHIN model", gr.Tabs(visible=False)
+
+        # Process PDF
+        progress(0.2, desc="Processing PDF...")
+        combined_markdown, status = process_pdf_document(pdf_file, dolphin, progress)

         if status == "processing_complete":
             processed_markdown = combined_markdown

             # Create chunks and embeddings for RAG
-
+            progress(0.9, desc="Creating document chunks for RAG...")
             document_chunks = chunk_document(processed_markdown)
             document_embeddings = create_embeddings(document_chunks)
             print(f"Created {len(document_chunks)} chunks")

+            # Unload DOLPHIN model to free memory for chatbot
+            progress(0.95, desc="Preparing chatbot...")
+            unload_dolphin_model()
+
             show_results_tab = True
-
+            progress(1.0, desc="PDF processed successfully!")
+            return "✅ PDF processed successfully! Chatbot is ready in the Chat tab.", gr.Tabs(visible=True)
         else:
             show_results_tab = False
             return combined_markdown, gr.Tabs(visible=False)
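Note: retrieve_relevant_chunks now returns the top 2 chunks, but the ranking step between encoding the question and returning relevant_chunks is not visible in the diff. A hedged sketch of a cosine-similarity ranking that fits the surrounding code; the scikit-learn import is an assumption:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def rank_chunks_sketch(question_embedding, embeddings, chunks, top_k=2):
    scores = cosine_similarity(question_embedding, embeddings)[0]
    top_idx = np.argsort(scores)[::-1][:top_k]   # highest similarity first
    return [chunks[i] for i in top_idx]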
@@ -481,7 +549,11 @@ def clear_all():
     document_chunks = []
     document_embeddings = None

-
+    # Unload any loaded models
+    unload_dolphin_model()
+    unload_chatbot_model()
+
+    return None, "✅ Ready to process your PDF", gr.Tabs(visible=False)


 # Create Gradio interface
@@ -535,14 +607,14 @@ with gr.Blocks(
     with gr.Tabs() as main_tabs:
         # Home Tab
         with gr.TabItem("π Home", id="home"):
-            chatbot_status = "✅ Chatbot ready" if chatbot_model else "❌ Chatbot not loaded"
             embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
+            current_status = f"Currently loaded: {current_model or 'None'}"
             gr.Markdown(
                 "# Scholar Express\n"
-                "### Upload a research paper to get a web-friendly version, an AI chatbot, and a podcast summary.
-                f"**
-                f"**
-                f"**
+                "### Upload a research paper to get a web-friendly version, an AI chatbot, and a podcast summary. Models are loaded dynamically to optimize memory usage.\n"
+                f"**System:** {model_status}\n"
+                f"**RAG System:** {embedding_status}\n"
+                f"**Status:** {current_status}"
             )

             with gr.Column(elem_classes="upload-container"):
@@ -647,56 +719,60 @@
         if not message.strip():
             return history

-        if chatbot_model is None:
-            return history + [[message, "❌ Chatbot model not loaded. Please check your HuggingFace token."]]
-
         if not processed_markdown:
             return history + [[message, "❌ Please process a PDF document first before asking questions."]]

         try:
+            # Load chatbot model
+            model, processor = load_chatbot_model()
+
+            if model is None or processor is None:
+                return history + [[message, "❌ Failed to load chatbot model. Please check your HuggingFace token."]]
+
             # Use RAG to get relevant chunks instead of full document
             if document_chunks and len(document_chunks) > 0:
                 relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
                 context = "\n\n".join(relevant_chunks)
             else:
                 # Fallback to truncated document if RAG fails
-                context = processed_markdown[:
+                context = processed_markdown[:1200] + "..." if len(processed_markdown) > 1200 else processed_markdown

-            # Create chat messages
+            # Create chat messages with shorter context
            messages = [
                 {
                     "role": "system",
-                    "content": [{"type": "text", "text": "You are a helpful assistant
+                    "content": [{"type": "text", "text": "You are a helpful assistant. Answer questions about the document concisely."}]
                 },
                 {
                     "role": "user",
-                    "content": [{"type": "text", "text": f"
+                    "content": [{"type": "text", "text": f"Context:\n{context}\n\nQ: {message}"}]
                 }
             ]

             # Process with the model
-            inputs =
+            inputs = processor.apply_chat_template(
                 messages,
                 add_generation_prompt=True,
                 tokenize=True,
                 return_dict=True,
                 return_tensors="pt",
-            ).to(
+            ).to(model.device)

             input_len = inputs["input_ids"].shape[-1]

             with torch.inference_mode():
-                generation =
+                generation = model.generate(
                     **inputs,
-                    max_new_tokens=
+                    max_new_tokens=300,  # Can be higher now with single model
                     do_sample=False,
                     temperature=0.7,
-                    pad_token_id=
-                    use_cache=True #
+                    pad_token_id=processor.tokenizer.pad_token_id,
+                    use_cache=True,  # Can enable cache with single model
+                    num_beams=1
                 )
                 generation = generation[0][input_len:]

-            response =
+            response = processor.decode(generation, skip_special_tokens=True)

             return history + [[message, response]]

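Note on the generation call above: with do_sample=False the temperature=0.7 argument has no effect on greedy decoding (recent transformers versions warn about this combination). A hedged sketch of the two internally consistent configurations, wrapped as a standalone helper rather than the app's actual code:

import torch

def generate_reply(model, processor, inputs, max_new_tokens=300, sample=False):
    kwargs = dict(
        max_new_tokens=max_new_tokens,
        pad_token_id=processor.tokenizer.pad_token_id,
        use_cache=True,
    )
    if sample:
        kwargs.update(do_sample=True, temperature=0.7)   # temperature applies only when sampling
    else:
        kwargs.update(do_sample=False)                   # greedy decoding, no temperature needed
    with torch.inference_mode():
        return model.generate(**inputs, **kwargs)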