# Author: bengballard · commit cf84a8a — "updating app"
from datetime import datetime

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
# --- Load FAISS Vector Store ---
# Queries are embedded with this model at search time, so it presumably has to
# be the same model the "faiss_index" directory was built with (not visible here).
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

# NOTE(security): allow_dangerous_deserialization=True lets load_local unpickle
# the saved docstore. Only point this at an index directory you built and trust.
vector_store = FAISS.load_local(
    folder_path="faiss_index",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
# --- Setup summarization chain ---
def setup_llm(model_id="philschmid/bart-large-cnn-samsum", max_new_tokens=300):
    """Build the summarization chain: prompt template -> HF seq2seq pipeline -> str.

    Args:
        model_id: Hugging Face model id loaded with ``AutoModelForSeq2SeqLM``.
            Defaults to the model this app was written for.
        max_new_tokens: Maximum number of tokens generated per summary.

    Returns:
        A LangChain runnable. ``invoke({"context": str, "question": str})``
        fills the prompt, runs the model and returns the generated text.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        # Without truncation, a prompt longer than the encoder's position limit
        # (1024 tokens for BART-large models) fails at generation time. The
        # tokenizer truncates from the right by default, so on very long contexts
        # the tail of the prompt (the question line) is what gets cut.
        truncation=True,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    prompt_template = PromptTemplate.from_template(
        """Summarize the following Annapolis police incident reports.
Only include the most relevant details based on the user’s question. Focus on dates, locations, and type of crime. Write clearly and factually, like a local crime brief.
Reports:
{context}
User Question: {question}
Incident:
"""
    )
    return prompt_template | llm | StrOutputParser()
# Build the summarization chain once at import time; every request reuses it.
summary_chain = setup_llm()

# Placeholder token budget. Not read anywhere in this file:
# run_query_with_details has its own ``truncate_to_token_limit`` parameter,
# which shadows this module-level name.
truncate_to_token_limit = 512
# --- Query Handler ---
# Dropdown value that requests date ordering (see the Gradio "Sort Results By" choices).
_SORT_BY_DATE = "Sort by Report Date"

# Formats tried after ISO 8601 when parsing ``incident_date`` metadata. The format
# actually stored in the index is not visible here; extend this tuple as needed.
_DATE_FORMATS = ("%m/%d/%Y", "%m/%d/%y", "%B %d, %Y", "%b %d, %Y")

# Sort key used for documents whose date cannot be parsed (sorts them last).
_UNKNOWN_DATE = datetime.min.date()


def _parse_incident_date(value):
    """Return ``value`` parsed as a ``datetime.date``, or None if it is unparseable."""
    text = str(value).strip() if value is not None else ""
    if not text:
        return None
    try:
        return datetime.fromisoformat(text).date()
    except ValueError:
        pass
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(text, fmt).date()
        except ValueError:
            continue
    return None


def _format_incident(position, doc, score):
    """Render one retrieved document as a Markdown section for the metadata panel."""
    meta = doc.metadata or {}
    url = meta.get("source_url", "")
    # An empty URL would otherwise render as a broken "[]()" link.
    link = f"[{url}]({url})" if url else "N/A"
    # Prefix every line so multi-paragraph report text stays inside the blockquote.
    details = "\n".join(f"> {line}" for line in (doc.page_content.splitlines() or [""]))
    return (
        f"### \U0001F6A8 Incident {position} — Similarity Score: `{score:.3f}`\n"
        f"- **Incident ID**: {meta.get('incident_id', 'Unknown')}\n"
        f"- **Incident Date**: {meta.get('incident_date', 'Unknown')}\n"
        f"- **Incident Time**: {meta.get('incident_time', 'Unknown')}\n"
        f"- **Crime Type**: {meta.get('crime_type', 'N/A')}\n"
        f"- **Report Link**: {link}\n\n"
        f"**Details:**\n{details}"
    )


def run_query_with_details(query, vector_store, summary_chain, top_k=3, threshold=None, sort_mode=None, truncate_to_token_limit=None):
    """Retrieve incidents for ``query``, summarize them, and render their details.

    Args:
        query: The user's free-text question.
        vector_store: Object exposing ``similarity_search_with_score(query, k=...)``
            that returns ``(document, score)`` pairs (a LangChain FAISS store here).
        summary_chain: Runnable whose ``invoke({"context": ..., "question": ...})``
            returns the summary text.
        top_k: Number of documents to retrieve. ``None`` (e.g. a cleared number
            box) falls back to 3; values below 1 are clamped to 1.
        threshold: Maximum score a document may have to be kept. ``None``
            disables filtering. NOTE(review): this treats scores as distances
            (lower = closer), which matches LangChain FAISS's default Euclidean
            strategy; confirm that the saved index uses that strategy.
        sort_mode: ``"Sort by Report Date"`` orders results newest first by
            ``incident_date``, with unparseable dates last. Any other value keeps
            the vector store's ranking.
        truncate_to_token_limit: Optional callable applied to the joined context
            before summarization. Non-callable values are ignored.

    Returns:
        ``(summary, metadata_markdown)``. When the query is blank or no document
        passes ``threshold``, the summary is a short notice, the markdown is
        empty, and the LLM is not called.
    """
    if query is None or not str(query).strip():
        return "Please enter a question.", ""

    k = 3 if top_k is None else max(1, int(top_k))
    docs_and_scores = vector_store.similarity_search_with_score(query, k=k)

    if threshold is not None:
        docs_and_scores = [(doc, score) for doc, score in docs_and_scores if score <= threshold]

    if not docs_and_scores:
        return "No incidents matched your question within the similarity threshold.", ""

    if sort_mode == _SORT_BY_DATE:
        # sorted() stays stable with reverse=True, so ties (including all
        # unparseable dates) keep their similarity order.
        docs_and_scores = sorted(
            docs_and_scores,
            key=lambda pair: _parse_incident_date((pair[0].metadata or {}).get("incident_date")) or _UNKNOWN_DATE,
            reverse=True,
        )

    context = "\n\n".join(doc.page_content for doc, _ in docs_and_scores)
    if callable(truncate_to_token_limit):
        context = truncate_to_token_limit(context)

    response = summary_chain.invoke({"context": context, "question": query})
    metadata = "\n\n".join(
        _format_incident(position, doc, score)
        for position, (doc, score) in enumerate(docs_and_scores, start=1)
    )
    return response, metadata
# --- Gradio Interface ---
def _answer(query, top_k, threshold, sort_mode):
    """Gradio callback: run a query against the module-level store and chain.

    The FAISS store and summarization chain are read from module globals instead
    of being passed through ``gr.State``. Gradio deep-copies a State's initial
    value (and again for each new session), which would duplicate the vector
    index and the seq2seq model in memory.
    """
    return run_query_with_details(
        query,
        vector_store,
        summary_chain,
        top_k=top_k,
        threshold=threshold,
        sort_mode=sort_mode,
    )


with gr.Blocks() as demo:
    gr.Markdown("## ⛵ Annapolis Police Incident Chatbot")
    gr.Markdown("Data sourced from [City of Annapolis Police Reports](https://www.annapolis.gov/list.aspx?PRVMSG=253)")
    with gr.Row():
        query = gr.Textbox(label="Ask your question", placeholder="e.g., Incidents involving knives in 2024", scale=2)
        run_button = gr.Button("Run")
    with gr.Row():
        threshold_slider = gr.Slider(minimum=0.5, maximum=1.5, value=1.15, step=0.01, label="Similarity Threshold", scale=1)
        top_k_input = gr.Number(value=3, precision=0, label="Top K Results", scale=1)
        sort_dropdown = gr.Dropdown(choices=["Sort by Similarity", "Sort by Report Date"], value="Sort by Similarity", label="Sort Results By", scale=1)
    response_box = gr.Textbox(label="Summary", lines=4)
    metadata_box = gr.Markdown(label="Incident Metadata")

    # The button and pressing Enter in the textbox trigger the same query.
    query_inputs = [query, top_k_input, threshold_slider, sort_dropdown]
    query_outputs = [response_box, metadata_box]
    run_button.click(fn=_answer, inputs=query_inputs, outputs=query_outputs)
    query.submit(fn=_answer, inputs=query_inputs, outputs=query_outputs)


if __name__ == "__main__":
    demo.launch()