wuhp committed on
Commit
651f077
·
verified ·
1 Parent(s): f502e4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -20
app.py CHANGED
@@ -3,24 +3,25 @@ from ultralytics import YOLO
3
  import cv2
4
  import tempfile
5
 
6
- # Function to load the model dynamically based on user input.
7
- def load_model(model_id: str):
8
  try:
9
- model = YOLO(model_id)
 
10
  return model
11
  except Exception as e:
12
  return f"Error loading model: {e}"
13
 
14
- # Inference for images: runs the model and plots predictions.
15
  def predict_image(model, image):
16
  try:
17
  results = model(image)
18
- annotated_frame = results[0].plot() # This handles detection, segmentation, or OBB models.
19
  return annotated_frame
20
  except Exception as e:
21
  return f"Error during image inference: {e}"
22
 
23
- # Inference for videos: processes the video frame by frame.
24
  def predict_video(model, video_file):
25
  try:
26
  cap = cv2.VideoCapture(video_file.name)
@@ -35,7 +36,7 @@ def predict_video(model, video_file):
35
 
36
  if not frames:
37
  return "Error: No frames processed from video."
38
-
39
  height, width, _ = frames[0].shape
40
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
41
  temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
@@ -47,10 +48,11 @@ def predict_video(model, video_file):
47
  except Exception as e:
48
  return f"Error during video inference: {e}"
49
 
50
- # Unified inference function based on media type.
51
- def inference(model_id, input_media, media_type):
52
- model = load_model(model_id)
53
- if isinstance(model, str): # Indicates an error message.
 
54
  return model
55
 
56
  if media_type == "Image":
@@ -60,20 +62,24 @@ def inference(model_id, input_media, media_type):
60
  else:
61
  return "Unsupported media type."
62
 
63
- # Updated Gradio interface components using the new API.
64
- model_text_input = gr.Textbox(label="Model Identifier", placeholder="e.g., yolov8n.pt, yolov8-seg.pt, yolov8-obb.pt")
65
- file_input = gr.File(label="Upload Image/Video File")
 
 
 
66
  media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
67
- output_file = gr.File(label="Processed Output")
68
 
69
- # Create Gradio interface.
70
  iface = gr.Interface(
71
  fn=inference,
72
- inputs=[model_text_input, file_input, media_type_dropdown],
73
- outputs=output_file,
74
- title="Dynamic Ultralytics YOLO Inference",
75
  description=(
76
- "Enter the model identifier (supports detection, segmentation, or OBB models) and upload an image or video."
 
77
  )
78
  )
79
 
 
3
  import cv2
4
  import tempfile
5
 
6
+ # Function to load a custom YOLO model from an uploaded file.
7
+ def load_model(model_file):
8
  try:
9
+ # model_file is a TemporaryFile object. Use .name to get its path.
10
+ model = YOLO(model_file.name)
11
  return model
12
  except Exception as e:
13
  return f"Error loading model: {e}"
14
 
15
+ # Function to perform inference on an image.
16
  def predict_image(model, image):
17
  try:
18
  results = model(image)
19
+ annotated_frame = results[0].plot() # This should work across detection, segmentation, or OBB models.
20
  return annotated_frame
21
  except Exception as e:
22
  return f"Error during image inference: {e}"
23
 
24
+ # Function to perform inference on a video.
25
  def predict_video(model, video_file):
26
  try:
27
  cap = cv2.VideoCapture(video_file.name)
 
36
 
37
  if not frames:
38
  return "Error: No frames processed from video."
39
+
40
  height, width, _ = frames[0].shape
41
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
42
  temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
 
48
  except Exception as e:
49
  return f"Error during video inference: {e}"
50
 
51
+ # Unified inference function that takes an uploaded model file, an input media file, and the selected media type.
52
+ def inference(model_file, input_media, media_type):
53
+ model = load_model(model_file)
54
+ # Check if model loading resulted in an error message.
55
+ if isinstance(model, str):
56
  return model
57
 
58
  if media_type == "Image":
 
62
  else:
63
  return "Unsupported media type."
64
 
65
+ # Updated Gradio components:
66
+ # - A file input for the custom YOLO model (.pt file)
67
+ # - A file input for the image or video to process
68
+ # - A radio button for selecting between image and video processing.
69
+ model_file_input = gr.File(label="Upload Custom YOLO Model (.pt file)")
70
+ media_file_input = gr.File(label="Upload Image/Video File")
71
  media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
72
+ output_component = gr.File(label="Processed Output")
73
 
74
+ # Create the Gradio interface.
75
  iface = gr.Interface(
76
  fn=inference,
77
+ inputs=[model_file_input, media_file_input, media_type_dropdown],
78
+ outputs=output_component,
79
+ title="Custom YOLO Model Inference",
80
  description=(
81
+ "Upload your custom YOLO model (for detection, segmentation, or OBB) along with an image or video file "
82
+ "to run inference. The system dynamically loads your model and processes the media accordingly."
83
  )
84
  )
85