Update src/streamlit_app.py
src/streamlit_app.py  CHANGED  (+59 -39)
@@ -1,40 +1,60 @@
-
-import numpy as np
-import pandas as pd
+# streamlit_app.py
 import streamlit as st
+from datasets import load_dataset
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import pandas as pd
+import nltk
+from nltk.corpus import stopwords
+from nltk.stem.snowball import SnowballStemmer
+from nltk.tokenize import word_tokenize
+import re
+
+# ---------- initial setup ----------
+nltk.download("stopwords", quiet=True)
+nltk.download("punkt", quiet=True)
+
+stemmer = SnowballStemmer("english")
+stop_words = set(stopwords.words("english"))
+
+def tokenizer(text: str):
+    # basic cleanup → NLTK tokenize → stem
+    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text.lower())
+    tokens = word_tokenize(text)
+    return [stemmer.stem(tok) for tok in tokens if tok not in stop_words and tok.isalnum()]
+
+@st.cache_data(show_spinner="Loading data & building index…")
+def load_and_index():
+    # first 1,000 docs only
+    ds = load_dataset("webis/tldr-17", split="train[:1000]")
+    docs = ds["content"]
+    vec = TfidfVectorizer(tokenizer=tokenizer)
+    matrix = vec.fit_transform(docs)
+    return docs, vec, matrix
+
+docs, vectorizer, tfidf_matrix = load_and_index()
+
+# ---------- UI ----------
+st.markdown(
+    """
+    <style>
+    .stTextInput > div {width:100%; display:flex; justify-content:center;}
+    </style>
+    """,
+    unsafe_allow_html=True,
+)
+
+st.markdown("## TF-IDF Reddit Search")
+query = st.text_input(" ", key="query", placeholder="Search…", label_visibility="hidden")
+
+# ---------- search ----------
+if query:
+    q_vec = vectorizer.transform([query])
+    sims = cosine_similarity(q_vec, tfidf_matrix).flatten()
+    top_idx = sims.argsort()[::-1]  # high→low
+    res_df = pd.DataFrame(
+        {"similarity": sims[top_idx], "document": [docs[i] for i in top_idx]}
+    )
+    st.dataframe(
+        res_df.style.format({"similarity": "{:.3f}"}), use_container_width=True
+    )