saritha5 commited on
Commit
de28250
·
1 Parent(s): 797c064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -1
app.py CHANGED
@@ -10,7 +10,16 @@ import streamlit as st
10
  warnings.filterwarnings("ignore", category=UserWarning)
11
  from tempfile import NamedTemporaryFile
12
 
 
 
13
 
 
 
 
 
 
 
 
14
 
15
 
16
  MODEL_PATH = "SD_model_weights.pth"
@@ -43,8 +52,62 @@ def detect_object(IMAGE_PATH):
43
  num_list = filtered_indices[0].tolist()
44
  filtered_labels = [labels[i] for i in num_list]
45
  show_labeled_image(image, filtered_boxes, filtered_labels)
46
- #st.image(image,filtered_boxes,filtered_labels)
 
47
  #img_array = img_to_array(img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  file = st.file_uploader('Upload an Image',type=(["jpeg","jpg","png"]))
50
 
 
10
  warnings.filterwarnings("ignore", category=UserWarning)
11
  from tempfile import NamedTemporaryFile
12
 
13
+ import cv2
14
+ import matplotlib.patches as patches
15
 
16
+ import torch
17
+
18
+ import matplotlib.image as mpimg
19
+ import os
20
+
21
+ from detecto.utils import reverse_normalize, normalize_transform, _is_iterable
22
+ from torchvision import transforms
23
 
24
 
25
  MODEL_PATH = "SD_model_weights.pth"
 
52
  num_list = filtered_indices[0].tolist()
53
  filtered_labels = [labels[i] for i in num_list]
54
  show_labeled_image(image, filtered_boxes, filtered_labels)
55
+
56
+ show_image(image,filtered_boxes,filtered_labels)
57
  #img_array = img_to_array(img)
58
+ def show_image(image, boxes, labels=None):
59
+ """Show the image along with the specified boxes around detected objects.
60
+ Also displays each box's label if a list of labels is provided.
61
+ :param image: The image to plot. If the image is a normalized
62
+ torch.Tensor object, it will automatically be reverse-normalized
63
+ and converted to a PIL image for plotting.
64
+ :type image: numpy.ndarray or torch.Tensor
65
+ :param boxes: A torch tensor of size (N, 4) where N is the number
66
+ of boxes to plot, or simply size 4 if N is 1.
67
+ :type boxes: torch.Tensor
68
+ :param labels: (Optional) A list of size N giving the labels of
69
+ each box (labels[i] corresponds to boxes[i]). Defaults to None.
70
+ :type labels: torch.Tensor or None
71
+ **Example**::
72
+ >>> from detecto.core import Model
73
+ >>> from detecto.utils import read_image
74
+ >>> from detecto.visualize import show_labeled_image
75
+ >>> model = Model.load('model_weights.pth', ['tick', 'gate'])
76
+ >>> image = read_image('image.jpg')
77
+ >>> labels, boxes, scores = model.predict(image)
78
+ >>> show_labeled_image(image, boxes, labels)
79
+ """
80
+ fig, ax = plt.subplots(1)
81
+ # If the image is already a tensor, convert it back to a PILImage
82
+ # and reverse normalize it
83
+ if isinstance(image, torch.Tensor):
84
+ image = reverse_normalize(image)
85
+ image = transforms.ToPILImage()(image)
86
+ ax.imshow(image)
87
+
88
+ # Show a single box or multiple if provided
89
+ if boxes.ndim == 1:
90
+ boxes = boxes.view(1, 4)
91
+
92
+ if labels is not None and not _is_iterable(labels):
93
+ labels = [labels]
94
+
95
+ # Plot each box
96
+ for i in range(2):
97
+ box = boxes[i]
98
+ width, height = (box[2] - box[0]).item(), (box[3] - box[1]).item()
99
+ initial_pos = (box[0].item(), box[1].item())
100
+ rect = patches.Rectangle(initial_pos, width, height, linewidth=1,
101
+ edgecolor='r', facecolor='none')
102
+ if labels:
103
+ ax.text(box[0] + 5, box[1] - 5, '{}'.format(labels[i]), color='red')
104
+
105
+ ax.add_patch(rect)
106
+
107
+ cp = os.path.abspath(os.getcwd()) + '/foo.png'
108
+ plt.savefig(cp)
109
+ plt.close(fig)
110
+ #print(type(plt
111
 
112
  file = st.file_uploader('Upload an Image',type=(["jpeg","jpg","png"]))
113