# Spaces: Sleeping  (hosting-page header residue captured with the file; not part of the program)
import gradio as gr | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import SentenceTransformerEmbeddings | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
from langchain_huggingface import HuggingFacePipeline | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
# --- Load FAISS Vector Store ---
# The embedding model named here is the one handed to FAISS.load_local below.
# NOTE(review): presumably the on-disk index was built with this same model;
# confirm against the indexing script.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

# allow_dangerous_deserialization is an explicit opt-in: only point this at an
# index directory that ships with the app and that you trust.
vector_store = FAISS.load_local(
    folder_path="faiss_index",
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
# --- Setup summarization chain ---
def setup_llm():
    """Build the runnable that turns retrieved reports into a short summary.

    Loads the tokenizer and seq2seq model for the configured checkpoint, wraps
    them in a ``text2text-generation`` pipeline limited to 300 new tokens, and
    composes prompt -> LLM -> string output parser.

    Returns:
        A runnable whose prompt expects the keys ``context`` and ``question``.
    """
    checkpoint = "philschmid/bart-large-cnn-samsum"

    # Tokenizer is loaded before the model, matching the original load order.
    generator = pipeline(
        "text2text-generation",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        max_new_tokens=300,
    )
    hf_llm = HuggingFacePipeline(pipeline=generator)

    # NOTE(review): no input truncation is configured here — confirm that the
    # rendered prompt stays within the model's maximum input length.
    summary_prompt = PromptTemplate.from_template(
        """Summarize the following Annapolis police incident reports.
Only include the most relevant details based on the user’s question. Focus on dates, locations, and type of crime. Write clearly and factually, like a local crime brief.
Reports:
{context}
User Question: {question}
Incident:
"""
    )

    chain = summary_prompt | hf_llm
    return chain | StrOutputParser()
# Build the summarization chain once, at import time, so every request
# handled by the Gradio interface reuses the same loaded model.
summary_chain = setup_llm()

# Stub value (original note: "stub, or refactor if needed").
# NOTE(review): run_query_with_details declares a parameter with this same
# name, which shadows this module-level value inside that function; nothing
# visible in this file reads the constant — confirm before relying on it.
truncate_to_token_limit = 512
# --- Query Handler ---
def _format_incident(position, doc, score):
    """Render one retrieved incident (1-based *position*) as a Markdown section."""
    meta = doc.metadata
    url = meta.get('source_url', '')
    # An empty URL would otherwise render as a broken "[]()" Markdown link.
    link = f"[{url}]({url})" if url else "N/A"
    return (
        f"### \U0001F6A8 Incident {position} — Similarity Score: `{score:.3f}`\n"
        f"- **Incident ID**: {meta.get('incident_id', 'Unknown')}\n"
        f"- **Incident Date**: {meta.get('incident_date', 'Unknown')}\n"
        f"- **Incident Time**: {meta.get('incident_time', 'Unknown')}\n"
        f"- **Crime Type**: {meta.get('crime_type', 'N/A')}\n"
        f"- **Report Link**: {link}\n\n"
        f"**Details:**\n> {doc.page_content}"
    )


def run_query_with_details(query, vector_store, summary_chain, top_k=3, threshold=None, sort_mode=None, truncate_to_token_limit=None):
    """Retrieve the closest incident reports, summarize them, and list their metadata.

    Args:
        query: Free-text user question.
        vector_store: Object exposing ``similarity_search_with_score(query, k=...)``
            that returns ``(document, score)`` pairs; each document must have
            ``page_content`` and a ``metadata`` mapping.
        summary_chain: Object exposing ``invoke({"context": ..., "question": ...})``.
        top_k: Number of hits to request. ``None`` (e.g. a cleared number
            field) falls back to 3; values below 1 are raised to 1.
        threshold: Optional score cutoff; hits whose score is greater than
            this value are dropped. ``None`` disables filtering.
            NOTE(review): assumes the score is a distance (lower = closer),
            as with LangChain's default FAISS index — confirm for this index.
        sort_mode: Accepted for interface compatibility; currently not applied.
            NOTE(review): date ordering needs the format of
            ``metadata['incident_date']``, which is not visible here.
        truncate_to_token_limit: Accepted for interface compatibility;
            currently not applied (no tokenizer is available in this function).

    Returns:
        ``(summary, metadata_markdown)``. When the query is blank or no hit
        survives the threshold, a short message and an empty metadata string
        are returned and the summarizer is not invoked.
    """
    if query is None or not str(query).strip():
        return "Please enter a question.", ""

    k = 3 if top_k is None else max(1, int(top_k))
    hits = vector_store.similarity_search_with_score(query, k=k)

    if threshold is not None:
        hits = [(doc, score) for doc, score in hits if score <= threshold]

    # Summarizing an empty context would only produce an ungrounded answer.
    if not hits:
        return "No incidents matched the query within the similarity threshold.", ""

    context = "\n\n".join(doc.page_content for doc, _ in hits)
    response = summary_chain.invoke({"context": context, "question": query})

    metadata = "\n\n".join(
        _format_incident(position, doc, score)
        for position, (doc, score) in enumerate(hits, start=1)
    )
    return response, metadata
# --- Gradio Interface ---
def _handle_query(query, top_k, threshold, sort_mode):
    """Gradio callback: answer *query* using the module-level store and chain.

    The vector store and summarization chain are shared, read-only resources,
    so they are taken from module scope instead of being routed through
    ``gr.State``. ``gr.State`` is per-session state whose initial value Gradio
    deep-copies, which is unnecessary (and costly) for a loaded index and model.
    """
    return run_query_with_details(
        query,
        vector_store,
        summary_chain,
        top_k=top_k,
        threshold=threshold,
        sort_mode=sort_mode,
    )


with gr.Blocks() as demo:
    gr.Markdown("## ⛵ Annapolis Police Incident Chatbot")
    gr.Markdown("Data sourced from [City of Annapolis Police Reports](https://www.annapolis.gov/list.aspx?PRVMSG=253)")

    with gr.Row():
        query = gr.Textbox(label="Ask your question", placeholder="e.g., Incidents involving knives in 2024", scale=2)
        run_button = gr.Button("Run")

    with gr.Row():
        threshold_slider = gr.Slider(minimum=0.5, maximum=1.5, value=1.15, step=0.01, label="Similarity Threshold", scale=1)
        top_k_input = gr.Number(value=3, precision=0, label="Top K Results", scale=1)
        sort_dropdown = gr.Dropdown(choices=["Sort by Similarity", "Sort by Report Date"], value="Sort by Similarity", label="Sort Results By", scale=1)

    response_box = gr.Textbox(label="Summary", lines=4)
    metadata_box = gr.Markdown(label="Incident Metadata")

    # Button click and Enter-in-textbox share one wiring so they cannot drift apart.
    query_inputs = [query, top_k_input, threshold_slider, sort_dropdown]
    query_outputs = [response_box, metadata_box]
    run_button.click(fn=_handle_query, inputs=query_inputs, outputs=query_outputs)
    query.submit(fn=_handle_query, inputs=query_inputs, outputs=query_outputs)

if __name__ == "__main__":
    demo.launch()