Update app.py
Browse files
app.py
CHANGED
@@ -3,24 +3,25 @@ from ultralytics import YOLO
|
|
3 |
import cv2
|
4 |
import tempfile
|
5 |
|
6 |
-
# Function to load
|
7 |
-
def load_model(
|
8 |
try:
|
9 |
-
|
|
|
10 |
return model
|
11 |
except Exception as e:
|
12 |
return f"Error loading model: {e}"
|
13 |
|
14 |
-
#
|
15 |
def predict_image(model, image):
|
16 |
try:
|
17 |
results = model(image)
|
18 |
-
annotated_frame = results[0].plot() # This
|
19 |
return annotated_frame
|
20 |
except Exception as e:
|
21 |
return f"Error during image inference: {e}"
|
22 |
|
23 |
-
#
|
24 |
def predict_video(model, video_file):
|
25 |
try:
|
26 |
cap = cv2.VideoCapture(video_file.name)
|
@@ -35,7 +36,7 @@ def predict_video(model, video_file):
|
|
35 |
|
36 |
if not frames:
|
37 |
return "Error: No frames processed from video."
|
38 |
-
|
39 |
height, width, _ = frames[0].shape
|
40 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
41 |
temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
@@ -47,10 +48,11 @@ def predict_video(model, video_file):
|
|
47 |
except Exception as e:
|
48 |
return f"Error during video inference: {e}"
|
49 |
|
50 |
-
# Unified inference function
|
51 |
-
def inference(
|
52 |
-
model = load_model(
|
53 |
-
if
|
|
|
54 |
return model
|
55 |
|
56 |
if media_type == "Image":
|
@@ -60,20 +62,24 @@ def inference(model_id, input_media, media_type):
|
|
60 |
else:
|
61 |
return "Unsupported media type."
|
62 |
|
63 |
-
# Updated Gradio
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
|
67 |
-
|
68 |
|
69 |
-
# Create Gradio interface.
|
70 |
iface = gr.Interface(
|
71 |
fn=inference,
|
72 |
-
inputs=[
|
73 |
-
outputs=
|
74 |
-
title="
|
75 |
description=(
|
76 |
-
"
|
|
|
77 |
)
|
78 |
)
|
79 |
|
|
|
3 |
import cv2
|
4 |
import tempfile
|
5 |
|
6 |
+
# Function to load a custom YOLO model from an uploaded file.
|
7 |
+
def load_model(model_file):
|
8 |
try:
|
9 |
+
# model_file is a TemporaryFile object. Use .name to get its path.
|
10 |
+
model = YOLO(model_file.name)
|
11 |
return model
|
12 |
except Exception as e:
|
13 |
return f"Error loading model: {e}"
|
14 |
|
15 |
+
# Function to perform inference on an image.
|
16 |
def predict_image(model, image):
|
17 |
try:
|
18 |
results = model(image)
|
19 |
+
annotated_frame = results[0].plot() # This should work across detection, segmentation, or OBB models.
|
20 |
return annotated_frame
|
21 |
except Exception as e:
|
22 |
return f"Error during image inference: {e}"
|
23 |
|
24 |
+
# Function to perform inference on a video.
|
25 |
def predict_video(model, video_file):
|
26 |
try:
|
27 |
cap = cv2.VideoCapture(video_file.name)
|
|
|
36 |
|
37 |
if not frames:
|
38 |
return "Error: No frames processed from video."
|
39 |
+
|
40 |
height, width, _ = frames[0].shape
|
41 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
42 |
temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
|
|
48 |
except Exception as e:
|
49 |
return f"Error during video inference: {e}"
|
50 |
|
51 |
+
# Unified inference function that takes an uploaded model file, an input media file, and the selected media type.
|
52 |
+
def inference(model_file, input_media, media_type):
|
53 |
+
model = load_model(model_file)
|
54 |
+
# Check if model loading resulted in an error message.
|
55 |
+
if isinstance(model, str):
|
56 |
return model
|
57 |
|
58 |
if media_type == "Image":
|
|
|
62 |
else:
|
63 |
return "Unsupported media type."
|
64 |
|
65 |
+
# Updated Gradio components:
|
66 |
+
# - A file input for the custom YOLO model (.pt file)
|
67 |
+
# - A file input for the image or video to process
|
68 |
+
# - A radio button for selecting between image and video processing.
|
69 |
+
model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
|
70 |
+
media_file_input = gr.File(label="Upload Image/Video File")
|
71 |
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
|
72 |
+
output_component = gr.File(label="Processed Output")
|
73 |
|
74 |
+
# Create the Gradio interface.
|
75 |
iface = gr.Interface(
|
76 |
fn=inference,
|
77 |
+
inputs=[model_file_input, media_file_input, media_type_dropdown],
|
78 |
+
outputs=output_component,
|
79 |
+
title="Custom YOLO Model Inference",
|
80 |
description=(
|
81 |
+
"Upload your custom YOLO model (for detection, segmentation, or OBB) along with an image or video file "
|
82 |
+
"to run inference. The system dynamically loads your model and processes the media accordingly."
|
83 |
)
|
84 |
)
|
85 |
|