Update app.py

app.py CHANGED
@@ -29,7 +29,7 @@ from utils import (
 )
 
 # Initialize the model and tokenizer.
-api_token = os.getenv("
+api_token = os.getenv("HF_TOKEN")
 model_name = "meta-llama/Llama-3.1-8B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
 model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
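The fix reads the standard `HF_TOKEN` Space secret. Llama 3.1 is a gated repository, so both `from_pretrained` calls need a valid token; below is a minimal sketch of the loading pattern this hunk arrives at, with an added guard (the `RuntimeError` check is illustrative, not part of the commit):

```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The Space must define HF_TOKEN as a secret; os.getenv returns None otherwise,
# and the gated meta-llama download would then fail at startup.
api_token = os.getenv("HF_TOKEN")
if api_token is None:
    raise RuntimeError("HF_TOKEN is not set; it is required for the gated Llama 3.1 repo.")

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=api_token, torch_dtype=torch.float16)
```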
@@ -456,7 +456,9 @@ def run_naive_rag_query(collection_name, query, rag_token_size, prefix, task, fe
 def prepare_compression_and_rag(combined_text, retrieval_slider_value, global_local_value, task_description, few_shot, state, progress=gr.Progress()):
     progress(0, desc="Starting compression process")
 
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     progress(0.1, desc="Tokenizing text and preparing task")
     question_text = task_description + "\n" + few_shot
     context_encoding = tokenizer(combined_text, return_tensors="pt").to(device)
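This hunk (and the matching one in `update_token_breakdown` below) replaces the old percentage-string parsing with a binary mode switch. A sketch of the mapping as a standalone helper (the function name is illustrative; the commit inlines the expression at both call sites):

```python
def mode_to_percentage(global_local_value: str) -> int:
    # Map the new Radio choices onto the old 0-100 scale:
    #   "RAG"        -> 0:   the whole token budget goes to RAG retrieval.
    #   "KVCompress" -> 100: the whole budget goes to the compressed KV cache.
    return 0 if global_local_value == "RAG" else 100
```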
@@ -538,6 +540,7 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
     percentage = state["global_local"]
     rag_retrieval_size = int(retrieval_slider_value * (1.0 - (percentage / 100)))
     print("RAG retrieval size: ", rag_retrieval_size)
+    print("Compressed cache: ", compressed_length)
     if percentage == 0:
         rag_prefix = prefix
         rag_task = state["task_description"]
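Review note: the added print assumes `compressed_length` is already defined earlier in `chat_response_stream`; the hunk only adds debug output and does not introduce that variable.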
@@ -583,7 +586,9 @@ def chat_response_stream(message: str, history: list, state: dict, compression_d
 
 def update_token_breakdown(token_count, retrieval_slider, global_local_value):
     retrieval_context_length = int(token_count / retrieval_slider)
-    percentage = int(global_local_value.replace('%', ''))
+    # percentage = int(global_local_value.replace('%', ''))
+    percentage = 0 if global_local_value == "RAG" else 100
+
     rag_tokens = int(retrieval_context_length * (1.0 - (percentage / 100)))
     kv_tokens = retrieval_context_length - rag_tokens
     return f"Token Breakdown: {kv_tokens} tokens (KV compression), {rag_tokens} tokens (RAG retrieval)"
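With the binary mapping, the breakdown collapses to an all-or-nothing split. A quick worked check of the arithmetic (illustrative values):

```python
# 8192 input tokens at compression rate 2 -> a 4096-token budget.
retrieval_context_length = int(8192 / 2)                          # 4096

# Mode "RAG" (percentage = 0): everything goes to retrieval.
assert int(retrieval_context_length * (1.0 - 0 / 100)) == 4096   # rag_tokens
assert retrieval_context_length - 4096 == 0                      # kv_tokens

# Mode "KVCompress" (percentage = 100): everything goes to the KV cache.
assert int(retrieval_context_length * (1.0 - 100 / 100)) == 0    # rag_tokens
assert retrieval_context_length - 0 == 4096                      # kv_tokens
```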
@@ -592,36 +597,51 @@ def update_token_breakdown(token_count, retrieval_slider, global_local_value):
 # Gradio Interface
 ##########################################################################
 CSS = """
-body {
-
+.main-container {
+    display: flex;
+    align-items: stretch;
+}
+
+.upload-section, .chatbot-container {
+    display: flex;
+    flex-direction: column;
+    height: 100%;
+    overflow-y: auto;
 }
+
 .upload-section {
     padding: 10px;
     border: 2px dashed #ccc;
     border-radius: 10px;
 }
+
 .upload-button {
     background: #34c759 !important;
     color: white !important;
     border-radius: 25px !important;
 }
+
 .chatbot-container {
-    margin-top:
+    margin-top: 0;
 }
+
 .status-output {
     margin-top: 10px;
     font-size: 14px;
 }
+
 .processing-info {
     margin-top: 5px;
     font-size: 12px;
     color: #666;
 }
+
 .info-container {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
 }
+
 .file-list {
     margin-top: 0;
     max-height: 200px;
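The CSS rewrite swaps the old `body` rule for a flex layout: `.main-container` stretches its children, and `.upload-section` / `.chatbot-container` become equal-height flex columns that scroll independently via `overflow-y: auto`. The blank lines added between the remaining rules are cosmetic.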
@@ -630,12 +650,14 @@ body {
     border: 1px solid #eee;
     border-radius: 5px;
 }
+
 .stats-box {
     margin-top: 10px;
     padding: 10px;
     border-radius: 5px;
     font-size: 12px;
 }
+
 .submit-btn {
     background: #1a73e8 !important;
     color: white !important;
@@ -644,18 +666,18 @@ body {
     padding: 5px 10px;
     font-size: 16px;
 }
+
 .input-row {
     display: flex;
     align-items: center;
 }
-
 """
 def reset_chat_state():
     return gr.update(value="Document not compressed yet. Please compress the document to enable chat."), False
 
-with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
-    gr.HTML("<center
+with gr.Blocks(css=CSS, theme=gr.themes.Soft(font=["Arial", gr.themes.GoogleFont("Inconsolata"), "sans-serif"])) as demo:
+    # gr.HTML("<h1><center>Beyond RAG with LLama 3.1-8B-Instruct Model</center></h1>")
+    gr.HTML("<h1><center>Beyond RAG: Compress your document and chat with it.</center></h1>")
 
     # Define chat_status_text as a Textbox with a set elem_id for custom styling.
     chat_status_text = gr.Textbox(value="Document not compressed yet. Please compress the document to enable chat.", interactive=False, show_label=False, render=False, lines=5)
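Both `gr.themes.Soft(font=[...])` and `gr.themes.GoogleFont("Inconsolata")` are standard Gradio theming APIs; the list is a font fallback stack (Arial first, then Inconsolata, then generic sans-serif). The old `<h1>` banner is kept as a comment and replaced with the new tagline.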
@@ -666,13 +688,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
 
     with gr.Row(elem_classes="main-container"):
         with gr.Column(elem_classes="upload-section"):
-            gr.Markdown("
+            gr.Markdown("### Document Preprocessing")
             with gr.Row():
-                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area")
-                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf")
+                file_input = gr.File(label="Drop file here or upload", file_count="multiple", elem_id="file-upload-area", height=120)
+                url_input = gr.Textbox(label="or enter a URL", placeholder="https://example.com/document.pdf", lines=2)
             with gr.Row():
-                do_ocr = gr.Checkbox(label="Do OCR", value=False)
-                do_table = gr.Checkbox(label="
+                do_ocr = gr.Checkbox(label="Do OCR on Images", value=True, visible=False)
+                do_table = gr.Checkbox(label="Parse Tables", value=True, visible=False)
             with gr.Accordion("Prompt Designer", open=False):
                 task_description_input = gr.Textbox(label="Task Description", value=default_task_description, lines=3, elem_id="task-description")
                 few_shot_input = gr.Textbox(label="Few-Shot Examples", value=default_few_shot, lines=10, elem_id="few-shot")
@@ -682,9 +704,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             retrieval_slider = gr.Slider(label="Select Compression Rate", minimum=1, maximum=32, step=1, value=2)
             retrieval_info_text = gr.Markdown("Number of tokens after compression: ")
             tokens_breakdown_text = gr.Markdown("Token breakdown will appear here.")
-            global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
-                                           choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
-
+            # global_local_slider = gr.Radio(label="Hybrid Retrieval (0 is all RAG, 100 is all global)",
+            #                                choices=["0%", "25%", "50%", "75%", "100%"], value="100%")
+            global_local_slider = gr.Radio(
+                label="Retrieval Mode",
+                choices=["RAG", "KVCompress"],
+                value="KVCompress"
+            )
+
+            compress_button = gr.Button("Compress Document", interactive=False, size="md", elem_classes="upload-button")
 
             # File input: Run auto_convert then chain reset_chat_state.
             file_input.change(
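The new two-choice Radio feeds the same handlers that previously received a percentage string. The actual event wiring sits outside this hunk's context lines; a hypothetical sketch of how the mode change would refresh the breakdown (`token_count_state` is an assumed name):

```python
# Hypothetical wiring: recompute the token breakdown when the mode flips.
global_local_slider.change(
    fn=update_token_breakdown,
    inputs=[token_count_state, retrieval_slider, global_local_slider],
    outputs=[tokens_breakdown_text],
)
```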
@@ -785,7 +813,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             )
         with gr.Column(elem_classes="chatbot-container"):
             chat_status_text.render()
-            gr.Markdown("## Chat")
+            gr.Markdown("## Chat (LLama 3.1-8B-Instruct)")
             chat_interface = gr.ChatInterface(
                 fn=chat_response_stream,
                 additional_inputs=[compressed_doc_state, compression_done],
@@ -793,5 +821,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 fill_height=True
             )
 
-demo.queue(max_size=16).launch()
-
+demo.queue(max_size=16).launch()
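Net effect of the last hunk: `demo.queue(max_size=16).launch()` itself is unchanged; the commit only drops the trailing blank line at the end of the file.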