Update app.py
Browse files
app.py
CHANGED
@@ -1,45 +1,38 @@
|
|
1 |
import streamlit as st
|
2 |
-
from
|
3 |
-
from ultralytics import YOLO
|
4 |
import torch
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
#
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
model
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
for r in results:
|
28 |
-
rendered_img = r.plot() # Draws boxes and labels on the image
|
29 |
-
st.image(rendered_img, caption="Detected Image", use_container_width=True)
|
30 |
-
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
detections = results[0].boxes.data.cpu().numpy()
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
class_id = int(det[5])
|
43 |
-
label = names[class_id]
|
44 |
-
if label in animal_labels:
|
45 |
-
st.markdown(f"- **{label}** (Confidence: {det[4]:.2f})")
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
3 |
import torch
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import seaborn as sns
|
6 |
+
|
7 |
+
# App Title and Styling
|
8 |
+
st.set_page_config(page_title="Transflower 🌸", layout="centered")
|
9 |
+
st.markdown("<h1 style='text-align: center; color: #D16BA5;'>Transflower 🌸</h1>", unsafe_allow_html=True)
|
10 |
+
st.markdown("<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>", unsafe_allow_html=True)
|
11 |
+
|
12 |
+
# Load model and tokenizer
|
13 |
+
model_name = "sshleifer/distilbart-cnn-12-6"
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
15 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
|
16 |
+
|
17 |
+
# Text Input
|
18 |
+
input_text = st.text_area("🌼 Enter text to summarize or visualize:", height=150)
|
19 |
+
|
20 |
+
# When user clicks the button
|
21 |
+
if st.button("✨ Visualize Transformer Magic ✨") and input_text:
|
22 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
23 |
+
with torch.no_grad():
|
24 |
+
output = model.generate(**inputs, output_attentions=True, return_dict_in_generate=True)
|
25 |
|
26 |
+
decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
|
27 |
+
st.success("🌸 Summary:")
|
28 |
+
st.markdown(f"`{decoded}`")
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
st.markdown("🌷 Attention Visualization (Encoder Self-Attention)")
|
31 |
+
|
32 |
+
# Extract attentions
|
33 |
+
with torch.no_grad():
|
34 |
+
outputs = model(**inputs, output_attentions=True)
|
35 |
+
attentions = outputs.encoder_attentions[0][0]
|
|
|
36 |
|
37 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
38 |
+
sns.heatmap(attentions.mean(
|
|
|
|
|
|
|
|