# information-retrieval-demo / src/streamlit_app.py
# (Hugging Face Space page header captured during export — author julian-schelb,
#  commit c4892ac "Update src/streamlit_app.py"; kept as a comment so the file
#  remains valid Python.)
# streamlit_app.py
import streamlit as st
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize
import re
# ---------- initial setup ----------
# Fetch the NLTK resources the tokenizer needs (no-op if already cached locally);
# quiet=True suppresses download progress output in the Streamlit log.
nltk.download("stopwords", quiet=True)
nltk.download("punkt", quiet=True)  # sentence/word tokenizer models for word_tokenize
# Shared, module-level NLP state used by tokenizer() below.
stemmer = SnowballStemmer("english")
stop_words = set(stopwords.words("english"))  # set for O(1) membership tests
def tokenizer(text: str):
    """Normalize *text* into a list of stemmed tokens.

    Steps: lowercase, replace non-alphanumeric chars with spaces, split with
    NLTK's word_tokenize, drop English stop words and any non-alphanumeric
    leftovers, then Snowball-stem each surviving token.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9\s]", " ", text.lower())
    stems = []
    for token in word_tokenize(cleaned):
        # Skip stop words and anything that is not purely alphanumeric.
        if token in stop_words or not token.isalnum():
            continue
        stems.append(stemmer.stem(token))
    return stems
@st.cache_data(show_spinner="Loading data & building index…")
def load_and_index():
    """Load a small corpus and build a TF-IDF index over it.

    Returns:
        docs: list[str] — raw document texts (first 1000 of webis/tldr-17).
        vec: fitted TfidfVectorizer (needed to transform queries later).
        matrix: sparse TF-IDF matrix, shape (len(docs), vocab_size).
    """
    # first 1 000 docs only — keeps startup fast for a demo
    ds = load_dataset("webis/tldr-17", split="train[:1000]")
    docs = ds["content"]
    # token_pattern=None silences sklearn's UserWarning about the default
    # token_pattern being ignored when a custom tokenizer is supplied.
    vec = TfidfVectorizer(tokenizer=tokenizer, token_pattern=None)
    matrix = vec.fit_transform(docs)
    return docs, vec, matrix
docs, vectorizer, tfidf_matrix = load_and_index()
# ---------- UI ----------
st.markdown(
"""
<style>
.stTextInput > div {width:100%; display:flex; justify-content:center;}
</style>
""",
unsafe_allow_html=True,
)
st.markdown("## TF-IDF Reddit Search")
query = st.text_input(" ", key="query", placeholder="Search…", label_visibility="hidden")
# ---------- search ----------
# Rank every document by cosine similarity to the query and show the full table.
if query:
    query_vector = vectorizer.transform([query])
    scores = cosine_similarity(query_vector, tfidf_matrix).ravel()
    ranking = scores.argsort()[::-1]  # indices sorted high → low similarity
    results = pd.DataFrame(
        {
            "similarity": scores[ranking],
            "document": [docs[idx] for idx in ranking],
        }
    )
    st.dataframe(
        results.style.format({"similarity": "{:.3f}"}), use_container_width=True
    )