# Hugging Face Space status when this file was captured: Runtime error
import gradio as gr | |
import requests | |
import fitz # PyMuPDF | |
import os | |
import time | |
import traceback | |
from huggingface_hub import snapshot_download | |
from pleias_rag_interface import RAGWithCitations | |
from dotenv import load_dotenv | |
# Debugging setup
DEBUG = True
MAX_DEBUG_MESSAGES = 20  # how many recent entries the debug console keeps
debug_messages = []


def log_debug(message):
    """Record a timestamped debug message and return the recent log as text.

    When ``DEBUG`` is true the message is prefixed with the local time,
    appended to ``debug_messages``, echoed to stdout, and the buffer is
    trimmed to the newest ``MAX_DEBUG_MESSAGES`` entries.

    Args:
        message: Text to log; it is interpolated into an f-string.

    Returns:
        The retained messages joined by newlines, or ``""`` when ``DEBUG``
        is false (in which case nothing is recorded or printed).
    """
    if not DEBUG:
        return ""

    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    full_message = f"[{timestamp}] {message}"
    debug_messages.append(full_message)
    print(full_message)  # Print to console

    # Trim in place so the module-level list object stays the same while
    # never holding more than the newest MAX_DEBUG_MESSAGES entries.
    if len(debug_messages) > MAX_DEBUG_MESSAGES:
        del debug_messages[:-MAX_DEBUG_MESSAGES]

    return "\n".join(debug_messages)
# Initialize debug logging
log_debug("Application starting...")

# Explicit model download, currently disabled; the model is instead loaded
# by name through RAGWithCitations below.
# MODEL_REPO = "PleIAs/Pleias-RAG-350M"
# MODEL_CACHE_DIR = "pleias_model"
# if not os.path.exists(MODEL_CACHE_DIR):
#     log_debug("Downloading model...")
#     snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_CACHE_DIR)

# Load environment variables
load_dotenv()

log_debug("Initializing RAG model...")
try:
    # rag = RAGWithCitations(model_path_or_name=MODEL_CACHE_DIR)
    rag = RAGWithCitations(
        model_path_or_name="PleIAs/Pleias-RAG-350M"
    )

    # Alternative constructor call kept for reference (disabled):
    # rag = RAGWithCitations(
    #     model_path_or_name="1b_rag",
    #     max_tokens=2048,            # Maximum tokens to generate (default: 2048)
    #     temperature=0.0,            # Sampling temperature (default: 0.0)
    #     top_p=0.95,                 # Nucleus sampling parameter (default: 0.95)
    #     repetition_penalty=1.0,     # Penalty to reduce repetition (default: 1.0)
    #     trust_remote_code=True,     # Whether to trust remote code (default: True)
    #     hf_token=os.getenv("HF_TOKEN"),  # Required for downloading predefined models
    #     models_dir=MODEL_CACHE_DIR,      # Custom directory for downloaded models
    # )

    # Disabled attempt to silence generation warnings by configuring the
    # tokenizer / generation parameters explicitly:
    # if hasattr(rag, "model"):
    #     # Configure tokenizer
    #     if hasattr(rag, "tokenizer"):
    #         if rag.tokenizer.pad_token is None:
    #             rag.tokenizer.pad_token = rag.tokenizer.eos_token
    #         rag.tokenizer.padding_side = "left"  # For batch generation
    #     # Configure model generation settings
    #     rag.model.config.pad_token_id = rag.tokenizer.pad_token_id
    #     rag.model.generation_config.pad_token_id = rag.tokenizer.pad_token_id
    #     # Fix the do_sample/top_p warning
    #     rag.model.generation_config.do_sample = True
    #     rag.model.generation_config.top_p = 0.95  # Explicitly set to match warning
    #     # Configure attention mask handling
    #     rag.model.config.use_cache = True
    #     log_debug("β Model loaded successfully with configuration:")
    #     log_debug(f" - Pad token: {rag.tokenizer.pad_token} (ID: {rag.tokenizer.pad_token_id})")
    #     log_debug(f" - Generation config: {rag.model.generation_config}")
except Exception as e:
    # Log the failure, then re-raise: the app cannot work without the model.
    log_debug(f"β Model initialization failed: {e}")
    raise
## Let's do a simple test from the docs --
# NOTE(review): this runs at import time and re-raises on failure, so a
# failing test generation stops the module before the UI is built.

# Define query and sources
query = "What is the capital of France?"
log_debug(f"π Test Query: {query}")

sources = [
    {
        "text": "Paris is the capital and most populous city of France.",
        "metadata": {"source": "Geographic Encyclopedia", "reliability": "high"},
    },
    {
        "text": "The Eiffel Tower is located in Paris, France.",
        "metadata": {"source": "Travel Guide", "year": 2020},
    },
]
log_debug("π Test Sources loaded successfully.")

# Generate a response
try:
    log_debug("π§ Test rag model on simple example...")
    # Extra generation arguments that were tried and are currently disabled:
    #   do_sample=True,                           # Enable sampling
    #   top_p=0.95,                               # Set top_p for nucleus sampling
    #   pad_token_id=rag.tokenizer.eos_token_id,  # Set pad_token_id to eos_token_id
    #   attention_mask=None                       # Ensure attention_mask is passed if needed
    response = rag.generate(query, sources)
    log_debug("β Test Answer generated successfully.")
    log_debug(response["processed"]["clean_answer"])
except Exception as e:
    log_debug(f"β Test Answer generation failed: {e}")
    raise
def extract_text_from_pdf_url(url, debug_state):
    """Download a PDF and return its extracted text plus the debug log.

    Args:
        url: Address of the PDF; fetched with a 30 second timeout.
        debug_state: Incoming debug text. It is not read; it is replaced by
            the value returned from each ``log_debug`` call.

    Returns:
        A ``(text, debug_state)`` tuple. On success ``text`` is the stripped
        concatenation of every page's text. On any failure ``text`` is
        ``"[Error loading PDF: <reason>]"``; exceptions are logged, not
        propagated.
    """
    debug_state = log_debug(f"π Fetching PDF: {url[:60]}...")
    try:
        start_time = time.time()
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        load_time = time.time() - start_time
        debug_state = log_debug(f"β³ PDF downloaded in {load_time:.2f}s (size: {len(response.content)/1024:.1f}KB)")

        # The context manager closes the document even if a page fails to
        # parse; previously the handle was never released.
        with fitz.open(stream=response.content, filetype="pdf") as doc:
            text = "".join(page.get_text() for page in doc)

        debug_state = log_debug(f"β Extracted {len(text)} characters from PDF")
        return text.strip(), debug_state
    except Exception as e:
        error_msg = f"β PDF Error: {e}"
        debug_state = log_debug(error_msg)
        return f"[Error loading PDF: {e}]", debug_state
def generate_answer(query, pdf_urls_str, debug_state=""):
    """Answer *query* from the PDFs listed in *pdf_urls_str*.

    Args:
        query: The user's question. Empty or whitespace-only counts as missing.
        pdf_urls_str: PDF URLs, one per line; blank lines are ignored.
        debug_state: Incoming debug text; replaced by ``log_debug`` output.

    Returns:
        A ``(markdown, debug_state)`` tuple. ``markdown`` is the model's
        clean answer on success; otherwise it is a prompt for missing input,
        the PDF load report with an error note, or a system error message.
        Exceptions are logged and reported in the result, not raised.
    """
    try:
        debug_state = log_debug(f"π New query: {query}")

        # Guard: both inputs are required (a blank question is not a question).
        if not query or not query.strip() or not pdf_urls_str:
            debug_state = log_debug("β Missing question or PDF URLs")
            return "Please provide both inputs", debug_state

        pdf_urls = [url.strip() for url in pdf_urls_str.strip().split("\n") if url.strip()]
        sources = []
        feedback = "### PDF Load Report:\n"
        debug_state = log_debug(f"Processing {len(pdf_urls)} PDF URLs...")

        for url in pdf_urls:
            text, debug_state = extract_text_from_pdf_url(url, debug_state)
            # The extractor signals failure with a "[Error ..." sentinel string.
            if not text.startswith("[Error"):
                sources.append({"text": text, "metadata": {"source": url}})
                feedback += f"- β Loaded: {url[:80]}\n"
            else:
                feedback += f"- β Failed: {url[:80]}\n"

        if not sources:
            debug_state = log_debug("β No valid PDFs processed")
            return feedback + "\nNo valid PDFs processed", debug_state

        debug_state = log_debug(f"π§ Generating answer using {len(sources)} sources...")
        start_time = time.time()
        try:
            response = rag.generate(query, sources)
            gen_time = time.time() - start_time
            debug_state = log_debug(f"β‘ Generation completed in {gen_time:.2f}s")
            answer = response["processed"]["clean_answer"]
            debug_state = log_debug(f"π‘ Answer preview: {answer[:200]}...")
            debug_state = log_debug(f"π οΈ Generated in {gen_time:.2f}s")
            return answer, debug_state
        except Exception as e:
            # Generation failed: keep the load report so the user can see
            # which PDFs were used, and append the error.
            error_msg = f"β Generation error: {e}"
            debug_state = log_debug(error_msg)
            debug_state = log_debug(traceback.format_exc())
            return feedback + f"\n\nβ Error: {e}", debug_state
    except Exception as e:
        # Last-resort boundary so the UI always receives a result tuple.
        error_msg = f"β System error: {e}"
        debug_state = log_debug(error_msg)
        debug_state = log_debug(traceback.format_exc())
        return error_msg, debug_state
# Create the Gradio interface
# Layout: inputs and the submit button in the left column, the model answer
# and (when DEBUG is on) the debug console in the right column.
with gr.Blocks(title="Pleias RAG QA", css="""
.debug-console {
font-family: monospace;
max-height: 400px;
overflow-y: auto !important;
background: #f5f5f5;
padding: 10px;
border-radius: 5px;
}
.debug-title {
font-weight: bold;
margin-bottom: 5px;
}
""") as demo:
    gr.Markdown("# Retrieval Generation from PDF files with a 350MB Pocket Size Model from Pleias")
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your Question", placeholder="What is this document about?")
            pdf_urls = gr.Textbox(lines=5, label="PDF URLs (one per line)",
                                  placeholder="https://example.com/doc1.pdf")
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Markdown(label="Model Response")
            if DEBUG:
                gr.Markdown("### Debug Console", elem_classes=["debug-title"])
                debug_console = gr.Textbox(
                    label="",
                    interactive=False,
                    lines=15,
                    elem_classes=["debug-console"]
                )

    # Handle submission
    if DEBUG:
        submit_btn.click(
            fn=generate_answer,
            inputs=[question, pdf_urls, debug_console],
            outputs=[output, debug_console],
        )
    else:
        # generate_answer always returns (answer, debug_text). With the
        # console hidden only one output is wired, so return the answer alone
        # to keep the number of returned values equal to the outputs.
        submit_btn.click(
            fn=lambda q, urls: generate_answer(q, urls)[0],
            inputs=[question, pdf_urls],
            outputs=[output],
        )
# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    log_debug("π Launching interface...")
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",  # listen on all network interfaces
        show_error=True,
        debug=DEBUG,
    )