bunnyroshan committed
Commit 4ece717 · verified · 1 Parent(s): 4102c80

Update app.py

Files changed (1):
  1. app.py +56 -26
app.py CHANGED
@@ -1,33 +1,63 @@
-import torch
-from torchvision import transforms
-from PIL import Image
 import gradio as gr

-# Load the YOLO model
-model = torch.hub.load('ultralytics/yolov8', 'custom', path='best.pt')  # Ensure 'best.pt' is the path to your trained model
-
-# Define a function to process the image and make predictions
-def detect_objects(image):
-    # Preprocess the image
-    transform = transforms.Compose([
-        transforms.ToTensor()
-    ])
-    image = transform(image).unsqueeze(0)  # Add batch dimension

-    # Perform inference
     results = model(image)

-    # Extract bounding boxes and labels
-    bbox_img = results.render()[0]  # This gets the image with bounding boxes drawn
-
-    return Image.fromarray(bbox_img)
-
-# Create the Gradio interface
-inputs = gr.inputs.Image(shape=(640, 480))
-outputs = gr.outputs.Image(type="pil")

-gr_interface = gr.Interface(fn=detect_objects, inputs=inputs, outputs=outputs, title="YOLO Object Detection", description="Upload an image to detect objects using a YOLO model.")

-# Run the Gradio app
-if __name__ == "__main__":
-    gr_interface.launch()
 import gradio as gr
+import torch
+from ultralytics import YOLO
+import cv2
+import tempfile

+# Load the trained YOLOv8 model
+model = YOLO('best.pt')

+def predict(image):
     results = model(image)
+    # plot() draws boxes, labels, and confidences, returning the image as a BGR numpy array
+    annotated_image = results[0].plot()
+    # Convert BGR -> RGB so colors render correctly in the Gradio Image component
+    return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
+
+def predict_video(video):
+    # Read the input video file
+    cap = cv2.VideoCapture(video)
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+    # Create a temporary file to save the output video
+    out_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
+    out_path = out_file.name
+
+    # Define the codec and create the VideoWriter object
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
+
+    # Run detection frame by frame; plot() returns BGR, which is what VideoWriter expects
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        results = model(frame)
+        annotated_frame = results[0].plot()
+        out.write(annotated_frame)
+
+    cap.release()
+    out.release()
+
+    return out_path
+
+# Create the Gradio interface; gr.Image/gr.Video replace the gr.inputs/gr.outputs API
+# that was removed in Gradio 4. An input left empty is passed through as None.
+interface = gr.Interface(
+    fn=lambda img, vid: (predict(img) if img is not None else None,
+                         predict_video(vid) if vid is not None else None),
+    inputs=[
+        gr.Image(type="numpy", label="Input Image"),
+        gr.Video(label="Input Video")
+    ],
+    outputs=[
+        gr.Image(type="numpy", label="Output Image"),
+        gr.Video(label="Output Video")
+    ],
+    title="YOLOv8 Object Detection",
+    description="Upload an image or a video and get the object detection results using a YOLOv8 model."
+)
+
+if __name__ == "__main__":
+    interface.launch()
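For a quick local check of the two new entry points, a minimal smoke-test sketch follows. It assumes best.pt sits next to app.py, and the file names sample.jpg / sample.mp4 are hypothetical placeholders, not part of this commit. Importing app loads the model and builds the interface but does not launch the server, since launch() is guarded by __name__ == "__main__".

# Hypothetical smoke test for the new functions; run from the repo root.
# 'sample.jpg' and 'sample.mp4' are placeholder assets, not part of the commit.
import cv2
from app import predict, predict_video

image = cv2.imread('sample.jpg')   # OpenCV loads images as BGR numpy arrays
annotated = predict(image)         # RGB numpy array with detections drawn
cv2.imwrite('annotated.jpg', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))

print('annotated video at:', predict_video('sample.mp4'))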