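"""Gradio app for chatting with uploaded PDF/TXT documents: hierarchical,
auto-merging RAG (llama-index) over a local Llama 3.1 8B Instruct model."""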
import gradio as gr
import os
import PyPDF2
import logging
import torch
import threading
import time
from collections import Counter
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
StorageContext,
VectorStoreIndex,
load_index_from_storage,
Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
HierarchicalNodeParser,
get_leaf_nodes,
get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("HF_TOKEN not found in environment variables")
# --- UI Settings ---
TITLE = "<h1 style='text-align:center; margin-bottom: 20px;'>Local Thinking RAG: Llama 3.1 8B</h1>"
DISCORD_BADGE = """<p style="text-align:center; margin-top: -10px;">
<a href="https://discord.gg/openfreeai" target="_blank">
<img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
</a>
</p>
"""
CSS = """
.upload-section {
max-width: 400px;
margin: 0 auto;
padding: 10px;
border: 2px dashed #ccc;
border-radius: 10px;
}
.upload-button {
background: #34c759 !important;
color: white !important;
border-radius: 25px !important;
}
.chatbot-container {
margin-top: 20px;
}
.status-output {
margin-top: 10px;
font-size: 14px;
}
.processing-info {
margin-top: 5px;
font-size: 12px;
color: #666;
}
.info-container {
margin-top: 10px;
padding: 10px;
border-radius: 5px;
}
.file-list {
margin-top: 0;
max-height: 200px;
overflow-y: auto;
padding: 5px;
border: 1px solid #eee;
border-radius: 5px;
}
.stats-box {
margin-top: 10px;
padding: 10px;
border-radius: 5px;
font-size: 12px;
}
.submit-btn {
background: #1a73e8 !important;
color: white !important;
border-radius: 25px !important;
margin-left: 10px;
padding: 5px 10px;
font-size: 16px;
}
.input-row {
display: flex;
align-items: center;
}
@media (min-width: 768px) {
.main-container {
display: flex;
justify-content: space-between;
gap: 20px;
}
.upload-section {
flex: 1;
max-width: 300px;
}
.chatbot-container {
flex: 2;
margin-top: 0;
}
}
"""
global_model = None
global_tokenizer = None
global_file_info = {}
def initialize_model_and_tokenizer():
global global_model, global_tokenizer
if global_model is None or global_tokenizer is None:
logger.info("Initializing model and tokenizer...")
global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
global_model = AutoModelForCausalLM.from_pretrained(
MODEL,
device_map="auto",
trust_remote_code=True,
token=HF_TOKEN,
torch_dtype=torch.float16
)
logger.info("Model and tokenizer initialized successfully")
def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
global global_model, global_tokenizer
if global_model is None or global_tokenizer is None:
initialize_model_and_tokenizer()
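    # Wrap the already-loaded transformers model in a llama-index LLM so the
    # indexing pipeline reuses the same weights as the chat path.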
return HuggingFaceLLM(
context_window=4096,
max_new_tokens=max_new_tokens,
tokenizer=global_tokenizer,
model=global_model,
generate_kwargs={
"do_sample": True,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p
}
)
def extract_text_from_document(file):
    # Gradio may pass a tempfile-like object or a plain filepath string, so read
    # from the path rather than from the upload object itself.
    file_name = file.name if hasattr(file, "name") else str(file)
    file_extension = os.path.splitext(file_name)[1].lower()
    if file_extension == '.txt':
        with open(file_name, 'r', encoding='utf-8') as f:
            text = f.read()
        return text, len(text.split()), None
    elif file_extension == '.pdf':
        with open(file_name, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            # extract_text() can return None for image-only pages.
            text = "\n\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        return text, len(text.split()), None
    else:
        return None, 0, ValueError(f"Unsupported file format: {file_extension}")
@spaces.GPU()
def create_or_update_index(files, request: gr.Request):
global global_file_info
if not files:
return "Please provide files.", ""
start_time = time.time()
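    # Each browser session gets its own on-disk index, keyed by Gradio's session hash.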
user_id = request.session_hash
save_dir = f"./{user_id}_index"
# Initialize LlamaIndex modules
llm = get_llm()
embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
Settings.llm = llm
Settings.embed_model = embed_model
file_stats = []
new_documents = []
for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(getattr(file, "name", str(file)))
text, word_count, error = extract_text_from_document(file)
if error:
logger.error(f"Error processing file {file_basename}: {str(error)}")
file_stats.append({
"name": file_basename,
"words": 0,
"status": f"error: {str(error)}"
})
continue
doc = LlamaDocument(
text=text,
metadata={
"file_name": file_basename,
"word_count": word_count,
"source": "user_upload"
}
)
new_documents.append(doc)
file_stats.append({
"name": file_basename,
"words": word_count,
"status": "processed"
})
global_file_info[file_basename] = {
"word_count": word_count,
"processed_at": time.time()
}
    if not new_documents:
        return "No files could be processed. Check the file formats and try again.", ""
    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
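    # The hierarchical parser builds a 2048/512/128-token chunk tree. Only leaf
    # chunks are embedded into the vector index; parent chunks live in the
    # docstore so the auto-merging retriever can substitute them at query time.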
logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
new_nodes = node_parser.get_nodes_from_documents(new_documents)
new_leaf_nodes = get_leaf_nodes(new_nodes)
new_root_nodes = get_root_nodes(new_nodes)
logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
if os.path.exists(save_dir):
logger.info(f"Loading existing index from {save_dir}")
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
index = load_index_from_storage(storage_context, settings=Settings)
docstore = storage_context.docstore
docstore.add_documents(new_nodes)
for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
index.insert_nodes([node])
total_docs = len(docstore.docs)
logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
else:
logger.info("Creating new index")
docstore = SimpleDocumentStore()
storage_context = StorageContext.from_defaults(docstore=docstore)
docstore.add_documents(new_nodes)
index = VectorStoreIndex(
new_leaf_nodes,
storage_context=storage_context,
settings=Settings
)
        total_docs = len(docstore.docs)
logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
index.storage_context.persist(persist_dir=save_dir)
    # Build the HTML summary (per-file status list plus index statistics) for the UI.
file_list_html = "<div class='file-list'>"
for stat in file_stats:
status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
file_list_html += "</div>"
processing_time = time.time() - start_time
    stats_output = "<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = "<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
return f"Successfully indexed {len(files)} files.", output_container
@spaces.GPU()
def stream_chat(
message: str,
history: list,
system_prompt: str,
temperature: float,
max_new_tokens: int,
top_p: float,
top_k: int,
penalty: float,
retriever_k: int,
merge_threshold: float,
request: gr.Request
):
if not request:
yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
return
user_id = request.session_hash
index_dir = f"./{user_id}_index"
if not os.path.exists(index_dir):
yield history + [{"role": "assistant", "content": "Please upload documents first."}]
return
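    # Defensively coerce the slider values, falling back to the UI defaults.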
max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
Settings.llm = llm
Settings.embed_model = embed_model
storage_context = StorageContext.from_defaults(persist_dir=index_dir)
index = load_index_from_storage(storage_context, settings=Settings)
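    # Two-stage retrieval: a plain top-k vector search over leaf chunks, then an
    # auto-merging pass that replaces a group of sibling leaves with their parent
    # chunk when enough of the siblings were retrieved (simple_ratio_thresh).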
base_retriever = index.as_retriever(similarity_top_k=retriever_k)
auto_merging_retriever = AutoMergingRetriever(
base_retriever,
storage_context=storage_context,
simple_ratio_thresh=merge_threshold,
verbose=True
)
logger.info(f"Query: {message}")
retrieval_start = time.time()
base_nodes = base_retriever.retrieve(message)
logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
    base_file_sources = Counter(
        n.node.metadata["file_name"]
        for n in base_nodes
        if "file_name" in getattr(n.node, "metadata", {})
    )
    logger.info(f"Base retrieval file distribution: {dict(base_file_sources)}")
merging_start = time.time()
merged_nodes = auto_merging_retriever.retrieve(message)
logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
    merged_file_sources = Counter(
        n.node.metadata["file_name"]
        for n in merged_nodes
        if "file_name" in getattr(n.node, "metadata", {})
    )
    logger.info(f"Merged retrieval file distribution: {dict(merged_file_sources)}")
context = "\n\n".join([n.node.text for n in merged_nodes])
source_info = ""
if merged_file_sources:
source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
messages = [{"role": "system", "content": formatted_system_prompt}]
    messages.extend(history)
messages.append({"role": "user", "content": message})
prompt = global_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
stop_event = threading.Event()
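    # A custom StoppingCriteria polling a threading.Event lets the main thread
    # abort generation if the client disconnects mid-stream.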
class StopOnEvent(StoppingCriteria):
def __init__(self, stop_event):
super().__init__()
self.stop_event = stop_event
def __call__(self, input_ids, scores, **kwargs):
return self.stop_event.is_set()
stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
streamer = TextIteratorStreamer(
global_tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=penalty,
do_sample=True,
stopping_criteria=stopping_criteria
)
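    # generate() blocks until completion, so run it in a background thread and
    # stream partial output to the UI by draining the TextIteratorStreamer here.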
thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
thread.start()
updated_history = history + [
{"role": "user", "content": message},
{"role": "assistant", "content": ""}
]
yield updated_history
    partial_response = ""
    try:
        for new_text in streamer:
            partial_response += new_text
            updated_history[-1]["content"] = partial_response
            yield updated_history
        thread.join()
    except GeneratorExit:
        # The client disconnected: signal the stopping criteria, wait for the
        # generation thread to finish, then propagate the close.
        stop_event.set()
        thread.join()
        raise
def create_demo():
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
# Title
gr.HTML(TITLE)
# Discord badge immediately under the title
gr.HTML(DISCORD_BADGE)
with gr.Row(elem_classes="main-container"):
with gr.Column(elem_classes="upload-section"):
file_upload = gr.File(
file_count="multiple",
label="Drag & Drop PDF/TXT Files Here",
file_types=[".pdf", ".txt"],
elem_id="file-upload"
)
upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
status_output = gr.Textbox(
label="Status",
placeholder="Upload files to start...",
interactive=False
)
file_info_output = gr.HTML(
label="File Information",
elem_classes="processing-info"
)
upload_button.click(
fn=create_or_update_index,
inputs=[file_upload],
outputs=[status_output, file_info_output]
)
with gr.Column(elem_classes="chatbot-container"):
chatbot = gr.Chatbot(
height=500,
placeholder="Chat with your documents...",
show_label=False,
type="messages"
)
with gr.Row(elem_classes="input-row"):
message_input = gr.Textbox(
placeholder="Type your question here...",
show_label=False,
container=False,
lines=1,
scale=8
)
submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
with gr.Accordion("Advanced Settings", open=False):
system_prompt = gr.Textbox(
value="You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. As a knowledgeable assistant, provide detailed answers using the relevant information from all uploaded documents.",
label="System Prompt",
lines=3
)
with gr.Tab("Generation Parameters"):
temperature = gr.Slider(
minimum=0,
maximum=1,
step=0.1,
value=0.9,
label="Temperature"
)
max_new_tokens = gr.Slider(
minimum=128,
maximum=8192,
step=64,
value=1024,
label="Max New Tokens",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.95,
label="Top P"
)
top_k = gr.Slider(
minimum=1,
maximum=100,
step=1,
value=50,
label="Top K"
)
penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.2,
label="Repetition Penalty"
)
with gr.Tab("Retrieval Parameters"):
retriever_k = gr.Slider(
minimum=5,
maximum=30,
step=1,
value=15,
label="Initial Retrieval Size (Top K)"
)
merge_threshold = gr.Slider(
minimum=0.1,
maximum=0.9,
step=0.1,
value=0.5,
label="Merge Threshold (lower = more merging)"
)
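        # The send button and pressing Enter both invoke the same streaming handler;
        # gr.Request is injected automatically because stream_chat declares it.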
submit_button.click(
fn=stream_chat,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold
],
outputs=chatbot
)
message_input.submit(
fn=stream_chat,
inputs=[
message_input,
chatbot,
system_prompt,
temperature,
max_new_tokens,
top_p,
top_k,
penalty,
retriever_k,
merge_threshold
],
outputs=chatbot
)
return demo
if __name__ == "__main__":
initialize_model_and_tokenizer()
demo = create_demo()
demo.launch()