Update app.py
app.py CHANGED
@@ -10,14 +10,15 @@ st.set_page_config(page_title="Transflower 🌸", page_icon="🌼", layout="cent
 
 st.markdown(
     "<h1 style='text-align: center; color: pink;'>🌸 Transflower 🌸</h1>"
-
+
     unsafe_allow_html=True,
 )
 
 # Load model and tokenizer
 model_name = "t5-small"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+model.eval()
 
 # Input area
 user_input = st.text_area("🌼 Enter text to summarize or visualize:", height=200)
@@ -26,24 +27,25 @@ if st.button("✨ Visualize Transformer Magic ✨"):
     if not user_input.strip():
         st.warning("Please enter some text to visualize.")
     else:
-        #
-
-
-        # Forward pass
+        # Encode input
+        inputs = tokenizer("summarize: " + user_input, return_tensors="pt", truncation=True)
+
+        # Forward pass manually to get attention
         with torch.no_grad():
-
-
+            encoder_outputs = model.encoder(**inputs, output_attentions=True, return_dict=True)
+            attention = encoder_outputs.attentions[-1][0].mean(dim=0).detach().numpy()
 
-
-
+        # Generate summary
+        with torch.no_grad():
+            summary_ids = model.generate(inputs["input_ids"], max_length=50)
+            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
-
-        st.
-        fig, ax = plt.subplots(figsize=(10, 5))
+        st.subheader("🌸 Summary:")
+        st.success(summary)
 
-
-        attention_data = outputs.attentions[-1]  # List of attention tensors from each layer
-        avg_attention = attention_data[0].mean(dim=0).squeeze().detach().numpy()  # mean over heads
+        st.subheader("🌐 Encoder Attention Heatmap:")
 
-
+        fig, ax = plt.subplots(figsize=(10, 6))
+        sns.heatmap(attention, cmap="YlGnBu", ax=ax)
+        ax.set_title("Encoder Self-Attention Heatmap 💫")
         st.pyplot(fig)
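The import block above line 10 sits outside both hunks, so it never appears in this diff. Given the names the file uses (st, torch, AutoTokenizer, AutoModelForSeq2SeqLM, plt, sns), the top of app.py presumably looks like this minimal sketch:

# Presumed imports for app.py (not shown in this diff); inferred from
# the identifiers the changed lines rely on.
import streamlit as st            # st.* UI calls
import torch                      # torch.no_grad() around the forward passes
import matplotlib.pyplot as plt   # fig/ax for the heatmap
import seaborn as sns             # sns.heatmap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM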
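The substantive change is where the attention weights come from. The old code ran the full model and indexed outputs.attentions, but the seq2seq output of a T5 forward pass exposes encoder_attentions, decoder_attentions, and cross_attentions rather than a plain attentions field, which is likely why the block was rewritten. The new code calls the encoder alone with output_attentions=True; encoder_outputs.attentions is then a tuple with one tensor per encoder layer, each shaped (batch, num_heads, seq_len, seq_len), so attentions[-1][0].mean(dim=0) takes the last layer, drops the batch dimension, and averages over heads, leaving a (seq_len, seq_len) matrix that sns.heatmap can plot directly. A standalone sketch to check those shapes (the sample sentence is arbitrary):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

inputs = tokenizer("summarize: the quick brown fox jumps over the lazy dog",
                   return_tensors="pt", truncation=True)
with torch.no_grad():
    enc = model.encoder(**inputs, output_attentions=True, return_dict=True)

print(len(enc.attentions))       # 6 -- one tensor per t5-small encoder layer
print(enc.attentions[-1].shape)  # torch.Size([1, 8, seq_len, seq_len]) -- 8 heads
attention = enc.attentions[-1][0].mean(dim=0)
print(attention.shape)           # torch.Size([seq_len, seq_len])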
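One possible follow-up, not part of this commit: the heatmap axes show bare token indices. Since sns.heatmap accepts xticklabels/yticklabels, the plotting block could reuse inputs to label both axes with the actual tokens; the tokens variable below is hypothetical, introduced only for illustration:

# Hypothetical extension (not in this commit): label the axes with tokens.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(attention, cmap="YlGnBu", ax=ax,
            xticklabels=tokens, yticklabels=tokens)
ax.set_title("Encoder Self-Attention Heatmap 💫")
st.pyplot(fig)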