rahideer commited on
Commit
fa35904
·
verified ·
1 Parent(s): 21fe83f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -40
app.py CHANGED
@@ -1,45 +1,38 @@
1
  import streamlit as st
2
- from PIL import Image
3
- from ultralytics import YOLO
4
  import torch
5
-
6
- st.set_page_config(page_title="Animal Detection App", layout="centered")
7
-
8
- # Load YOLOv8 model
9
- @st.cache_resource
10
- def load_model():
11
- return YOLO("yolov8s.pt")
12
-
13
- model = load_model()
14
-
15
- st.title("🐾 Animal Detection App")
16
- st.write("Upload an image and let the YOLOv8 model detect animals!")
17
-
18
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
19
-
20
- if uploaded_file is not None:
21
- image = Image.open(uploaded_file)
22
- st.image(image, caption="Uploaded Image", use_container_width=True)
 
 
 
23
 
24
- with st.spinner("Detecting..."):
25
- results = model.predict(image)
26
-
27
- for r in results:
28
- rendered_img = r.plot() # Draws boxes and labels on the image
29
- st.image(rendered_img, caption="Detected Image", use_container_width=True)
30
-
31
 
32
- result_img = Image.fromarray(results[0].plot()[:, :, ::-1])
33
- st.image(result_img, caption="Detected Animals", use_column_width=True)
34
-
35
- # Filter animal predictions
36
- animal_labels = ["cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "bird"]
37
- names = model.names
38
- detections = results[0].boxes.data.cpu().numpy()
39
 
40
- st.subheader("Detections:")
41
- for det in detections:
42
- class_id = int(det[5])
43
- label = names[class_id]
44
- if label in animal_labels:
45
- st.markdown(f"- **{label}** (Confidence: {det[4]:.2f})")
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
3
  import torch
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+
7
+ # App Title and Styling
8
+ st.set_page_config(page_title="Transflower 🌸", layout="centered")
9
+ st.markdown("<h1 style='text-align: center; color: #D16BA5;'>Transflower 🌸</h1>", unsafe_allow_html=True)
10
+ st.markdown("<p style='text-align: center; color: #8E44AD;'>A girly & elegant app to visualize Transformer models</p>", unsafe_allow_html=True)
11
+
12
+ # Load model and tokenizer
13
+ model_name = "sshleifer/distilbart-cnn-12-6"
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
16
+
17
+ # Text Input
18
+ input_text = st.text_area("🌼 Enter text to summarize or visualize:", height=150)
19
+
20
+ # When user clicks the button
21
+ if st.button(" Visualize Transformer Magic ✨") and input_text:
22
+ inputs = tokenizer(input_text, return_tensors="pt")
23
+ with torch.no_grad():
24
+ output = model.generate(**inputs, output_attentions=True, return_dict_in_generate=True)
25
 
26
+ decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
27
+ st.success("🌸 Summary:")
28
+ st.markdown(f"`{decoded}`")
 
 
 
 
29
 
30
+ st.markdown("🌷 Attention Visualization (Encoder Self-Attention)")
31
+
32
+ # Extract attentions
33
+ with torch.no_grad():
34
+ outputs = model(**inputs, output_attentions=True)
35
+ attentions = outputs.encoder_attentions[0][0]
 
36
 
37
+ fig, ax = plt.subplots(figsize=(10, 6))
38
+ sns.heatmap(attentions.mean(