import gradio as gr
import os
import PyPDF2
import logging
import torch
import threading
import time
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers import logging as hf_logging
import spaces
from llama_index.core import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    Document as LlamaDocument,
)
from llama_index.core import Settings
from llama_index.core.node_parser import (
    HierarchicalNodeParser,
    get_leaf_nodes,
    get_root_nodes,
)
from llama_index.core.retrievers import AutoMergingRetriever
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
hf_logging.set_verbosity_error()
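# --- Model Configuration ---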
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in environment variables")

# --- UI Settings ---
TITLE = "<h1 style='text-align:center; margin-bottom: 20px;'>Local Thinking RAG: Llama 3.1 8B</h1>"

DISCORD_BADGE = """<p style="text-align:center; margin-top: -10px;">
  <a href="https://discord.gg/openfreeai" target="_blank">
    <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge">
  </a>
</p>
"""
CSS = """
.upload-section {
    max-width: 400px;
    margin: 0 auto;
    padding: 10px;
    border: 2px dashed #ccc;
    border-radius: 10px;
}
.upload-button {
    background: #34c759 !important;
    color: white !important;
    border-radius: 25px !important;
}
.chatbot-container {
    margin-top: 20px;
}
.status-output {
    margin-top: 10px;
    font-size: 14px;
}
.processing-info {
    margin-top: 5px;
    font-size: 12px;
    color: #666;
}
.info-container {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
}
.file-list {
    margin-top: 0;
    max-height: 200px;
    overflow-y: auto;
    padding: 5px;
    border: 1px solid #eee;
    border-radius: 5px;
}
.stats-box {
    margin-top: 10px;
    padding: 10px;
    border-radius: 5px;
    font-size: 12px;
}
.submit-btn {
    background: #1a73e8 !important;
    color: white !important;
    border-radius: 25px !important;
    margin-left: 10px;
    padding: 5px 10px;
    font-size: 16px;
}
.input-row {
    display: flex;
    align-items: center;
}
@media (min-width: 768px) {
    .main-container {
        display: flex;
        justify-content: space-between;
        gap: 20px;
    }
    .upload-section {
        flex: 1;
        max-width: 300px;
    }
    .chatbot-container {
        flex: 2;
        margin-top: 0;
    }
}
"""
global_model = None
global_tokenizer = None
global_file_info = {}
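
# Load the Llama 3.1 checkpoint and tokenizer once into module-level globals so
# every Gradio session reuses the same weights instead of reloading them.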
def initialize_model_and_tokenizer():
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        logger.info("Initializing model and tokenizer...")
        global_tokenizer = AutoTokenizer.from_pretrained(MODEL, token=HF_TOKEN)
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            torch_dtype=torch.float16,
        )
        logger.info("Model and tokenizer initialized successfully")
def get_llm(temperature=0.7, max_new_tokens=256, top_p=0.95, top_k=50):
    global global_model, global_tokenizer
    if global_model is None or global_tokenizer is None:
        initialize_model_and_tokenizer()
    return HuggingFaceLLM(
        context_window=4096,
        max_new_tokens=max_new_tokens,
        tokenizer=global_tokenizer,
        model=global_model,
        generate_kwargs={
            "do_sample": True,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        },
    )
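
# Extract plain text from an uploaded file. Returns (text, word_count, error);
# only .txt and .pdf uploads are supported.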
def extract_text_from_document(file):
    # Gradio may hand us a tempfile wrapper (with a .name path) or a plain path string.
    file_path = file if isinstance(file, str) else file.name
    file_extension = os.path.splitext(file_path)[1].lower()
    if file_extension == '.txt':
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        return text, len(text.split()), None
    elif file_extension == '.pdf':
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            # extract_text() can return None for image-only pages
            text = "\n\n".join((page.extract_text() or "") for page in pdf_reader.pages)
        return text, len(text.split()), None
    else:
        return None, 0, ValueError(f"Unsupported file format: {file_extension}")
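
# Build (or extend) the per-session hierarchical index: documents are split into a
# hierarchy of 2048/512/128-sized chunks, leaf chunks are embedded into a vector
# index, and the full node hierarchy is persisted so AutoMergingRetriever can merge
# leaves back into their parents at query time.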
# The Space runs on ZeroGPU, so functions that touch the model are wrapped in
# @spaces.GPU (the `spaces` import is otherwise unused).
@spaces.GPU
def create_or_update_index(files, request: gr.Request):
    global global_file_info
    if not files:
        return "Please provide files.", ""
    start_time = time.time()
    user_id = request.session_hash
    save_dir = f"./{user_id}_index"
    # Initialize LlamaIndex modules
    llm = get_llm()
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    file_stats = []
    new_documents = []
    for file in tqdm(files, desc="Processing files"):
        file_basename = os.path.basename(file if isinstance(file, str) else file.name)
        text, word_count, error = extract_text_from_document(file)
        if error:
            logger.error(f"Error processing file {file_basename}: {str(error)}")
            file_stats.append({
                "name": file_basename,
                "words": 0,
                "status": f"error: {str(error)}"
            })
            continue
        doc = LlamaDocument(
            text=text,
            metadata={
                "file_name": file_basename,
                "word_count": word_count,
                "source": "user_upload"
            }
        )
        new_documents.append(doc)
        file_stats.append({
            "name": file_basename,
            "words": word_count,
            "status": "processed"
        })
        global_file_info[file_basename] = {
            "word_count": word_count,
            "processed_at": time.time()
        }
    if not new_documents:
        return "No files could be processed. Please upload PDF or TXT files.", ""
    node_parser = HierarchicalNodeParser.from_defaults(
        chunk_sizes=[2048, 512, 128],
        chunk_overlap=20
    )
    logger.info(f"Parsing {len(new_documents)} documents into hierarchical nodes")
    new_nodes = node_parser.get_nodes_from_documents(new_documents)
    new_leaf_nodes = get_leaf_nodes(new_nodes)
    new_root_nodes = get_root_nodes(new_nodes)
    logger.info(f"Generated {len(new_nodes)} total nodes ({len(new_root_nodes)} root, {len(new_leaf_nodes)} leaf)")
    if os.path.exists(save_dir):
        logger.info(f"Loading existing index from {save_dir}")
        storage_context = StorageContext.from_defaults(persist_dir=save_dir)
        index = load_index_from_storage(storage_context, settings=Settings)
        docstore = storage_context.docstore
        docstore.add_documents(new_nodes)
        for node in tqdm(new_leaf_nodes, desc="Adding leaf nodes to index"):
            index.insert_nodes([node])
        total_docs = len(docstore.docs)
        logger.info(f"Updated index with {len(new_nodes)} new nodes from {len(new_documents)} files")
    else:
        logger.info("Creating new index")
        docstore = SimpleDocumentStore()
        storage_context = StorageContext.from_defaults(docstore=docstore)
        docstore.add_documents(new_nodes)
        index = VectorStoreIndex(
            new_leaf_nodes,
            storage_context=storage_context,
            settings=Settings
        )
        total_docs = len(new_documents)
        logger.info(f"Created new index with {len(new_nodes)} nodes from {len(new_documents)} files")
    index.storage_context.persist(persist_dir=save_dir)
    # custom outputs after processing files
    file_list_html = "<div class='file-list'>"
    for stat in file_stats:
        status_color = "#4CAF50" if stat["status"] == "processed" else "#f44336"
        file_list_html += f"<div><span style='color:{status_color}'>●</span> {stat['name']} - {stat['words']} words</div>"
    file_list_html += "</div>"
    processing_time = time.time() - start_time
    stats_output = "<div class='stats-box'>"
    stats_output += f"✓ Processed {len(files)} files in {processing_time:.2f} seconds<br>"
    stats_output += f"✓ Created {len(new_nodes)} nodes ({len(new_leaf_nodes)} leaf nodes)<br>"
    stats_output += f"✓ Total documents in index: {total_docs}<br>"
    stats_output += f"✓ Index saved to: {save_dir}<br>"
    stats_output += "</div>"
    output_container = "<div class='info-container'>"
    output_container += file_list_html
    output_container += stats_output
    output_container += "</div>"
    return f"Successfully indexed {len(files)} files.", output_container
@spaces.GPU  # ZeroGPU: generation must run inside a GPU-allocated context
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float,
    max_new_tokens: int,
    top_p: float,
    top_k: int,
    penalty: float,
    retriever_k: int,
    merge_threshold: float,
    request: gr.Request,
):
    if not request:
        yield history + [{"role": "assistant", "content": "Session initialization failed. Please refresh the page."}]
        return
    user_id = request.session_hash
    index_dir = f"./{user_id}_index"
    if not os.path.exists(index_dir):
        yield history + [{"role": "assistant", "content": "Please upload documents first."}]
        return
    max_new_tokens = int(max_new_tokens) if isinstance(max_new_tokens, (int, float)) else 1024
    temperature = float(temperature) if isinstance(temperature, (int, float)) else 0.9
    top_p = float(top_p) if isinstance(top_p, (int, float)) else 0.95
    top_k = int(top_k) if isinstance(top_k, (int, float)) else 50
    penalty = float(penalty) if isinstance(penalty, (int, float)) else 1.2
    retriever_k = int(retriever_k) if isinstance(retriever_k, (int, float)) else 15
    merge_threshold = float(merge_threshold) if isinstance(merge_threshold, (int, float)) else 0.5
    llm = get_llm(temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k)
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, token=HF_TOKEN)
    Settings.llm = llm
    Settings.embed_model = embed_model
    storage_context = StorageContext.from_defaults(persist_dir=index_dir)
    index = load_index_from_storage(storage_context, settings=Settings)
    base_retriever = index.as_retriever(similarity_top_k=retriever_k)
    auto_merging_retriever = AutoMergingRetriever(
        base_retriever,
        storage_context=storage_context,
        simple_ratio_thresh=merge_threshold,
        verbose=True
    )
    logger.info(f"Query: {message}")
    retrieval_start = time.time()
    base_nodes = base_retriever.retrieve(message)
    logger.info(f"Retrieved {len(base_nodes)} base nodes in {time.time() - retrieval_start:.2f}s")
    base_file_sources = {}
    for node in base_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in base_file_sources:
                base_file_sources[file_name] = 0
            base_file_sources[file_name] += 1
    logger.info(f"Base retrieval file distribution: {base_file_sources}")
    merging_start = time.time()
    merged_nodes = auto_merging_retriever.retrieve(message)
    logger.info(f"Retrieved {len(merged_nodes)} merged nodes in {time.time() - merging_start:.2f}s")
    merged_file_sources = {}
    for node in merged_nodes:
        if hasattr(node.node, 'metadata') and 'file_name' in node.node.metadata:
            file_name = node.node.metadata['file_name']
            if file_name not in merged_file_sources:
                merged_file_sources[file_name] = 0
            merged_file_sources[file_name] += 1
    logger.info(f"Merged retrieval file distribution: {merged_file_sources}")
    context = "\n\n".join([n.node.text for n in merged_nodes])
    source_info = ""
    if merged_file_sources:
        source_info = "\n\nRetrieved information from files: " + ", ".join(merged_file_sources.keys())
    formatted_system_prompt = f"{system_prompt}\n\nDocument Context:\n{context}{source_info}"
    messages = [{"role": "system", "content": formatted_system_prompt}]
    for entry in history:
        messages.append(entry)
    messages.append({"role": "user", "content": message})
    prompt = global_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    stop_event = threading.Event()

    class StopOnEvent(StoppingCriteria):
        def __init__(self, stop_event):
            super().__init__()
            self.stop_event = stop_event

        def __call__(self, input_ids, scores, **kwargs):
            return self.stop_event.is_set()

    stopping_criteria = StoppingCriteriaList([StopOnEvent(stop_event)])
    streamer = TextIteratorStreamer(
        global_tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty,
        do_sample=True,
        stopping_criteria=stopping_criteria
    )
    thread = threading.Thread(target=global_model.generate, kwargs=generation_kwargs)
    thread.start()
    updated_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""}
    ]
    yield updated_history
    partial_response = ""
    try:
        for new_text in streamer:
            partial_response += new_text
            updated_history[-1]["content"] = partial_response
            yield updated_history
        yield updated_history
    except GeneratorExit:
        # Client disconnected: trip the stopping criteria and wait for generation to finish
        stop_event.set()
        thread.join()
        raise
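
# Gradio UI: an upload/index column on the left and the chat column on the right,
# with generation and retrieval parameters tucked into an "Advanced Settings" accordion.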
def create_demo():
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        # Title
        gr.HTML(TITLE)
        # Discord badge immediately under the title
        gr.HTML(DISCORD_BADGE)
        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag & Drop PDF/TXT Files Here",
                    file_types=[".pdf", ".txt"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )
            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with your documents...",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=8
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
                with gr.Accordion("Advanced Settings", open=False):
                    system_prompt = gr.Textbox(
                        value="You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside tags, and then provide your solution or response to the problem. As a knowledgeable assistant, provide detailed answers using the relevant information from all uploaded documents.",
                        label="System Prompt",
                        lines=3
                    )
                    with gr.Tab("Generation Parameters"):
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            step=0.1,
                            value=0.9,
                            label="Temperature"
                        )
                        max_new_tokens = gr.Slider(
                            minimum=128,
                            maximum=8192,
                            step=64,
                            value=1024,
                            label="Max New Tokens",
                        )
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            value=0.95,
                            label="Top P"
                        )
                        top_k = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top K"
                        )
                        penalty = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.2,
                            label="Repetition Penalty"
                        )
                    with gr.Tab("Retrieval Parameters"):
                        retriever_k = gr.Slider(
                            minimum=5,
                            maximum=30,
                            step=1,
                            value=15,
                            label="Initial Retrieval Size (Top K)"
                        )
                        merge_threshold = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            step=0.1,
                            value=0.5,
                            label="Merge Threshold (lower = more merging)"
                        )
        submit_button.click(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )
        message_input.submit(
            fn=stream_chat,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold
            ],
            outputs=chatbot
        )
    return demo
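
# Warm up the model at startup so the first request does not pay the load cost.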
if __name__ == "__main__":
    initialize_model_and_tokenizer()
    demo = create_demo()
    demo.launch()