# yolohost / app.py — Hugging Face Space app (commit f502e4f)
import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile
# Load a YOLO model by identifier. Failures are reported as a string rather
# than raised, so the caller can surface the message directly in the UI.
def load_model(model_id: str):
    """Instantiate a YOLO model from *model_id* (e.g. ``yolov8n.pt``).

    Returns the model on success, or an error-message string on failure.
    """
    try:
        return YOLO(model_id)
    except Exception as exc:
        return f"Error loading model: {exc}"
# Inference for a single image: run the model and return the rendered frame.
def predict_image(model, image):
    """Run *model* on *image* and return the annotated frame.

    ``Results.plot()`` renders detection, segmentation, or OBB output
    alike. On failure, returns an error-message string instead of raising.
    """
    try:
        first_result = model(image)[0]
        return first_result.plot()
    except Exception as exc:
        return f"Error during image inference: {exc}"
# Inference for videos: annotate each frame and write a new .mp4.
def predict_video(model, video_file):
    """Run *model* frame-by-frame over a video and write an annotated copy.

    Parameters
    ----------
    model:
        A loaded Ultralytics model (detection, segmentation, or OBB).
    video_file:
        Uploaded file object; its ``.name`` attribute is the path on disk.

    Returns
    -------
    str
        Path to the annotated ``.mp4`` on success, or an error-message
        string on failure (consistent with the other handlers).
    """
    try:
        cap = cv2.VideoCapture(video_file.name)
        try:
            success, frame = cap.read()
            if not success:
                return "Error: No frames processed from video."
            # Preserve the source frame rate so playback speed is correct;
            # fall back to the previous hard-coded 20 fps when the container
            # does not report one (cap.get returns 0.0 in that case).
            fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
            height, width = frame.shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            # delete=False so the path outlives this function (Gradio serves
            # it); close our handle immediately so only cv2 writes to it.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
                out_path = tmp.name
            out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
            try:
                # Stream annotated frames straight to the writer instead of
                # buffering the whole video in memory.
                while success:
                    out.write(model(frame)[0].plot())
                    success, frame = cap.read()
            finally:
                out.release()
            return out_path
        finally:
            # Release the capture even if inference raises mid-loop.
            cap.release()
    except Exception as e:
        return f"Error during video inference: {e}"
# Unified entry point: load the requested model, then dispatch by media type.
def inference(model_id, input_media, media_type):
    """Load *model_id* and run it on *input_media* as *media_type*.

    Returns the annotated result, or an error string if model loading
    fails or the media type is not recognised.
    """
    model = load_model(model_id)
    # load_model signals failure by returning the error text itself.
    if isinstance(model, str):
        return model
    handler = {"Image": predict_image, "Video": predict_video}.get(media_type)
    if handler is None:
        return "Unsupported media type."
    return handler(model, input_media)
# Updated Gradio interface components using the new API.
# Free-text identifier lets any Ultralytics weight file be loaded on demand.
model_text_input = gr.Textbox(label="Model Identifier", placeholder="e.g., yolov8n.pt, yolov8-seg.pt, yolov8-obb.pt")
# One upload widget serves both media types; the radio below disambiguates.
file_input = gr.File(label="Upload Image/Video File")
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
# NOTE(review): predict_image returns an array while predict_video returns a
# file path; both are routed through this gr.File output — confirm Gradio
# renders the image case as intended.
output_file = gr.File(label="Processed Output")

# Create Gradio interface.
iface = gr.Interface(
    fn=inference,
    inputs=[model_text_input, file_input, media_type_dropdown],
    outputs=output_file,
    title="Dynamic Ultralytics YOLO Inference",
    description=(
        "Enter the model identifier (supports detection, segmentation, or OBB models) and upload an image or video."
    )
)

# Launch the app only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()