# streamlit_app.py
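# TF-IDF search over a slice of the webis/tldr-17 Reddit dataset.
# Run with:  streamlit run streamlit_app.py
# Requires:  streamlit, datasets, scikit-learn, pandas, nltk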
import streamlit as st
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import word_tokenize
import re


# ---------- initial setup ----------
nltk.download("stopwords", quiet=True)
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)  # newer NLTK (>= 3.9) needs punkt_tab for word_tokenize

stemmer = SnowballStemmer("english")
stop_words = set(stopwords.words("english"))

def tokenizer(text: str):
    # basic cleanup → NLTK tokenize → stem
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text.lower())
    tokens = word_tokenize(text)
    return [stemmer.stem(tok) for tok in tokens if tok not in stop_words and tok.isalnum()]
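# e.g. tokenizer("The running dogs!") -> ["run", "dog"]
# ("the" is dropped as a stop word; the remaining tokens are Snowball-stemmed)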

@st.cache_data(show_spinner="Loading data & building index…")
def load_and_index():
    # cached, so the dataset download and TF-IDF fit run once, not on every rerun
    # index only the first 1,000 docs to keep startup fast
    ds = load_dataset("webis/tldr-17", split="train[:1000]")
    docs = ds["content"]
    # token_pattern=None silences sklearn's warning that its default token
    # pattern is unused when a custom tokenizer is supplied
    vec = TfidfVectorizer(tokenizer=tokenizer, token_pattern=None)
    matrix = vec.fit_transform(docs)
    return docs, vec, matrix

docs, vectorizer, tfidf_matrix = load_and_index()

# ---------- UI ----------
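# small CSS shim: center the search box inside its container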
st.markdown(
    """
    <style>
    .stTextInput > div {width:100%; display:flex; justify-content:center;}
    </style>
    """,
    unsafe_allow_html=True,
)

st.markdown("## TF-IDF Reddit Search")
query = st.text_input("Search query", key="query", placeholder="Search…", label_visibility="hidden")

# ---------- search ----------
if query:
    q_vec = vectorizer.transform([query])
    sims = cosine_similarity(q_vec, tfidf_matrix).flatten()
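    # TfidfVectorizer L2-normalizes rows by default, so this cosine score
    # equals the plain sparse dot product (tfidf_matrix @ q_vec.T)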
    top_idx = sims.argsort()[::-1]  # high→low
    res_df = pd.DataFrame(
        {"similarity": sims[top_idx], "document": [docs[i] for i in top_idx]}
    )
    st.dataframe(
        res_df.style.format({"similarity": "{:.3f}"}), use_container_width=True
    )