rahideer committed
Commit 41d5ab8 · verified · 1 Parent(s): 44cb668

Update app.py

Files changed (1)
  1. app.py +43 -49
app.py CHANGED
@@ -1,51 +1,45 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  import torch
- import matplotlib.pyplot as plt
- import seaborn as sns
- import numpy as np
-
- # Page setup
- st.set_page_config(page_title="Transflower 🌸", page_icon="🌼", layout="centered")
-
- st.markdown(
-     "<h1 style='text-align: center; color: pink;'>🌸 Transflower 🌸</h1>",
-     unsafe_allow_html=True,
- )
-
- # Load model and tokenizer
- model_name = "t5-small"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
- model.eval()
-
- # Input area
- user_input = st.text_area("🌼 Enter text to summarize or visualize:", height=200)
-
- if st.button("✨ Visualize Transformer Magic ✨"):
-     if not user_input.strip():
-         st.warning("Please enter some text to visualize.")
-     else:
-         # Encode input
-         inputs = tokenizer("summarize: " + user_input, return_tensors="pt", truncation=True)
-
-         # Forward pass manually to get attention
-         with torch.no_grad():
-             encoder_outputs = model.encoder(**inputs, output_attentions=True, return_dict=True)
-             attention = encoder_outputs.attentions[-1][0].mean(dim=0).detach().numpy()
-
-         # Generate summary
-         with torch.no_grad():
-             summary_ids = model.generate(inputs["input_ids"], max_length=50)
-             summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-         st.subheader("🌸 Summary:")
-         st.success(summary)
-
-         st.subheader("💖 Encoder Attention Heatmap:")
-
-         fig, ax = plt.subplots(figsize=(10, 6))
-         sns.heatmap(attention, cmap="YlGnBu", ax=ax)
-         ax.set_title("Encoder Self-Attention Heatmap 💫")
-         st.pyplot(fig)
 
  import streamlit as st
+ from PIL import Image
+ from ultralytics import YOLO
  import torch
+
+ st.set_page_config(page_title="Animal Detection App", layout="centered")
+
+ # Load YOLOv8 model (cached so the weights load only once per session)
+ @st.cache_resource
+ def load_model():
+     return YOLO("yolov8s.pt")
+
+ model = load_model()
+
+ st.title("🐾 Animal Detection App")
+ st.write("Upload an image and let the YOLOv8 model detect animals!")
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+
+ if uploaded_file:
+     image = Image.open(uploaded_file).convert("RGB")
+     st.image(image, caption="Uploaded Image", use_container_width=True)
+
+     with st.spinner("Detecting..."):
+         results = model(image)
+
+     # Display detection results; plot() returns a BGR array, so reverse
+     # the channel order to RGB before handing it to st.image
+     result_img = Image.fromarray(results[0].plot()[:, :, ::-1])
+     st.image(result_img, caption="Detected Animals", use_container_width=True)
+
+     # Filter animal predictions
+     animal_labels = ["cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "bird"]
+     names = model.names
+     detections = results[0].boxes.data.cpu().numpy()
+
+     st.subheader("Detections:")
+     for det in detections:
+         class_id = int(det[5])
+         label = names[class_id]
+         if label in animal_labels:
+             st.markdown(f"- **{label}** (Confidence: {det[4]:.2f})")
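For sanity-checking the detection and filtering logic outside Streamlit, here is a minimal standalone sketch of the same pipeline. It assumes the stock COCO-pretrained `yolov8s.pt` checkpoint used above; the `detect_animals` helper and the `test.jpg` path are hypothetical names introduced for illustration, not part of this repo.

```python
from ultralytics import YOLO

# Same animal subset of the 80 COCO classes as in app.py
ANIMAL_LABELS = {"cat", "dog", "horse", "sheep", "cow",
                 "elephant", "bear", "zebra", "giraffe", "bird"}

def detect_animals(image_path):
    """Return (label, confidence) pairs for animal detections in one image."""
    model = YOLO("yolov8s.pt")      # stock COCO-pretrained checkpoint (assumption)
    results = model(image_path)     # list of Results objects, one per input image
    animals = []
    for box in results[0].boxes:
        label = model.names[int(box.cls)]   # class id -> COCO label
        conf = float(box.conf)              # detection confidence in [0, 1]
        if label in ANIMAL_LABELS:
            animals.append((label, conf))
    return animals

if __name__ == "__main__":
    # "test.jpg" is a placeholder path for illustration only
    for label, conf in detect_animals("test.jpg"):
        print(f"{label}: {conf:.2f}")
```

The app itself should run locally with `streamlit run app.py`, assuming `streamlit`, `ultralytics`, and `Pillow` are installed.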