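"""Gradio demo: retrieval-augmented question answering over PDFs fetched by URL,
using the PleIAs/Pleias-RAG-350M model through pleias_rag_interface."""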
import gradio as gr
import requests
import fitz # PyMuPDF
import os
import time
import traceback
from huggingface_hub import snapshot_download  # only used by the optional cache path below
from pleias_rag_interface import RAGWithCitations
from dotenv import load_dotenv

# Debugging setup
DEBUG = True
debug_messages = []

def log_debug(message):
    """Log a debug message, keep only the last 20 entries, and return the joined log."""
    if DEBUG:
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
        full_message = f"[{timestamp}] {message}"
        debug_messages.append(full_message)
        print(full_message)  # Also print to console
        # Keep only the last 20 messages
        if len(debug_messages) > 20:
            debug_messages.pop(0)
        return "\n".join(debug_messages)
    return ""
# Initialize debug logging
log_debug("Application starting...")
# Download and initialize the model manually (optional local-cache path, kept for reference):
# MODEL_REPO = "PleIAs/Pleias-RAG-350M"
# MODEL_CACHE_DIR = "pleias_model"
# if not os.path.exists(MODEL_CACHE_DIR):
#     log_debug("Downloading model...")
#     snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_CACHE_DIR)
# Load environment variables
load_dotenv()
log_debug("Initializing RAG model...")
try:
    # rag = RAGWithCitations(model_path_or_name=MODEL_CACHE_DIR)
    rag = RAGWithCitations(
        model_path_or_name="PleIAs/Pleias-RAG-350M"
    )
    # Optional constructor arguments, kept for reference:
    #     model_path_or_name="1b_rag",
    #     max_tokens=2048,          # Maximum tokens to generate (default: 2048)
    #     temperature=0.0,          # Sampling temperature (default: 0.0)
    #     top_p=0.95,               # Nucleus sampling parameter (default: 0.95)
    #     repetition_penalty=1.0,   # Penalty to reduce repetition (default: 1.0)
    #     trust_remote_code=True,   # Whether to trust remote code (default: True)
    #     hf_token=os.getenv("HF_TOKEN"),  # Required for downloading predefined models
    #     models_dir=MODEL_CACHE_DIR       # Custom directory for downloaded models
    # )

    # Fix generation warnings by configuring the tokenizer and model (kept for reference):
    # if hasattr(rag, "model"):
    #     # Configure tokenizer
    #     if hasattr(rag, "tokenizer"):
    #         if rag.tokenizer.pad_token is None:
    #             rag.tokenizer.pad_token = rag.tokenizer.eos_token
    #         rag.tokenizer.padding_side = "left"  # For batch generation
    #     # Configure model generation settings
    #     rag.model.config.pad_token_id = rag.tokenizer.pad_token_id
    #     rag.model.generation_config.pad_token_id = rag.tokenizer.pad_token_id
    #     # Fix the do_sample/top_p warning
    #     rag.model.generation_config.do_sample = True
    #     rag.model.generation_config.top_p = 0.95  # Explicitly set to match the warning
    #     # Configure attention mask handling
    #     rag.model.config.use_cache = True
    #     log_debug("✅ Model loaded successfully with configuration:")
    #     log_debug(f"  - Pad token: {rag.tokenizer.pad_token} (ID: {rag.tokenizer.pad_token_id})")
    #     log_debug(f"  - Generation config: {rag.model.generation_config}")
    log_debug("✅ RAG model initialized successfully.")
except Exception as e:
    log_debug(f"❌ Model initialization failed: {str(e)}")
    raise
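# Note: passing a Hub repo id here means the library fetches the weights on first
# run (an assumption, based on the commented-out snapshot_download alternative above,
# which instead caches the model in a local directory).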
## Let's do a simple test from the docs --
# Define query and sources
query = "What is the capital of France?"
log_debug(f"🔍 Test Query: {query}")

sources = [
    {
        "text": "Paris is the capital and most populous city of France.",
        "metadata": {"source": "Geographic Encyclopedia", "reliability": "high"}
    },
    {
        "text": "The Eiffel Tower is located in Paris, France.",
        "metadata": {"source": "Travel Guide", "year": 2020}
    }
]
log_debug("📄 Test Sources loaded successfully.")
# Generate a response
try:
    log_debug("🧠 Testing the RAG model on a simple example...")
    # rag1 = RAGWithCitations(
    #     model_path_or_name="PleIAs/Pleias-RAG-350M"
    # )
    response = rag.generate(
        query,
        sources
        # do_sample=True,      # Enable sampling
        # top_p=0.95,          # Set top_p for nucleus sampling
        # pad_token_id=rag.tokenizer.eos_token_id,  # Set pad_token_id to eos_token_id
        # attention_mask=None  # Ensure an attention mask is passed if needed
    )
    log_debug("✅ Test answer generated successfully.")
    log_debug(response["processed"]["clean_answer"])
except Exception as e:
    log_debug(f"❌ Test answer generation failed: {str(e)}")
    raise
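# generate() returns a dict; this app relies only on response["processed"]["clean_answer"].
# Given the RAGWithCitations name, the processed payload likely also carries citation
# data, but nothing here depends on it.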

def extract_text_from_pdf_url(url, debug_state):
    """Download a PDF from a URL and extract its text, with debug logging."""
    debug_state = log_debug(f"📄 Fetching PDF: {url[:60]}...")
    try:
        start_time = time.time()
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        load_time = time.time() - start_time
        debug_state = log_debug(f"⏳ PDF downloaded in {load_time:.2f}s (size: {len(response.content)/1024:.1f}KB)")

        doc = fitz.open(stream=response.content, filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        doc.close()  # Release the PyMuPDF document

        debug_state = log_debug(f"✅ Extracted {len(text)} characters from PDF")
        return text.strip(), debug_state
    except Exception as e:
        error_msg = f"❌ PDF Error: {str(e)}"
        debug_state = log_debug(error_msg)
        return f"[Error loading PDF: {str(e)}]", debug_state

def generate_answer(query, pdf_urls_str, debug_state=""):
    """Main processing function: load PDFs, run the RAG model, and return the answer plus debug output."""
    try:
        debug_state = log_debug(f"🔍 New query: {query}")
        if not query or not pdf_urls_str:
            debug_state = log_debug("❌ Missing question or PDF URLs")
            return "Please provide both inputs", debug_state

        pdf_urls = [url.strip() for url in pdf_urls_str.strip().split("\n") if url.strip()]
        sources = []
        feedback = "### PDF Load Report:\n"
        debug_state = log_debug(f"Processing {len(pdf_urls)} PDF URLs...")

        for url in pdf_urls:
            text, debug_state = extract_text_from_pdf_url(url, debug_state)
            if not text.startswith("[Error"):
                sources.append({"text": text, "metadata": {"source": url}})
                feedback += f"- ✅ Loaded: {url[:80]}\n"
            else:
                feedback += f"- ❌ Failed: {url[:80]}\n"

        if not sources:
            debug_state = log_debug("❌ No valid PDFs processed")
            return feedback + "\nNo valid PDFs processed", debug_state

        debug_state = log_debug(f"🧠 Generating answer using {len(sources)} sources...")
        start_time = time.time()
        try:
            response = rag.generate(query, sources)
            gen_time = time.time() - start_time
            debug_state = log_debug(f"⚡ Generation completed in {gen_time:.2f}s")
            answer = response["processed"]["clean_answer"]
            debug_state = log_debug(f"💡 Answer preview: {answer[:200]}...")
            return answer, debug_state
        except Exception as e:
            error_msg = f"❌ Generation error: {str(e)}"
            debug_state = log_debug(error_msg)
            debug_state = log_debug(traceback.format_exc())
            return feedback + f"\n\n❌ Error: {str(e)}", debug_state
    except Exception as e:
        error_msg = f"❌ System error: {str(e)}"
        debug_state = log_debug(error_msg)
        debug_state = log_debug(traceback.format_exc())
        return error_msg, debug_state
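# Usage sketch (hypothetical URL, matching the UI placeholders below):
#     answer, debug_log = generate_answer(
#         "What is this document about?",
#         "https://example.com/doc1.pdf",
#     )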

# Create the Gradio interface
with gr.Blocks(title="Pleias RAG QA", css="""
    .debug-console {
        font-family: monospace;
        max-height: 400px;
        overflow-y: auto !important;
        background: #f5f5f5;
        padding: 10px;
        border-radius: 5px;
    }
    .debug-title {
        font-weight: bold;
        margin-bottom: 5px;
    }
""") as demo:
gr.Markdown("# Retrieval Generation from PDF files with a 350MB Pocket Size Model from Pleias")
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your Question", placeholder="What is this document about?")
            pdf_urls = gr.Textbox(lines=5, label="PDF URLs (one per line)",
                                  placeholder="https://example.com/doc1.pdf")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Markdown(label="Model Response")
            if DEBUG:
                gr.Markdown("### Debug Console", elem_classes=["debug-title"])
                debug_console = gr.Textbox(
                    label="",
                    interactive=False,
                    lines=15,
                    elem_classes=["debug-console"]
                )
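    # Note: generate_answer always returns (answer, debug_log), so the wiring below
    # assumes DEBUG is True and debug_console exists; with DEBUG False the second
    # return value would have no matching output component.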
    # Handle submission
    submit_btn.click(
        fn=generate_answer,
        inputs=[question, pdf_urls] + ([debug_console] if DEBUG else []),
        outputs=[output, debug_console] if DEBUG else [output],
    )

if __name__ == "__main__":
    log_debug("🚀 Launching interface...")
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
        show_error=True,
        debug=DEBUG
    )
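# To run locally (assuming the usual dependencies: gradio, requests, PyMuPDF,
# python-dotenv, huggingface_hub, and the pleias_rag_interface package):
#     python app.py
# then open http://localhost:7860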