File size: 2,831 Bytes
eb25988
 
 
 
 
f502e4f
eb25988
 
 
 
 
 
 
f502e4f
eb25988
 
 
f502e4f
eb25988
 
 
 
f502e4f
eb25988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f502e4f
eb25988
 
f502e4f
eb25988
 
 
 
 
 
 
 
 
f502e4f
eb25988
f502e4f
eb25988
f502e4f
eb25988
f502e4f
eb25988
 
 
f502e4f
eb25988
 
f502e4f
eb25988
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from ultralytics import YOLO
import cv2
import tempfile

# Build a YOLO model from a user-supplied identifier (weights name or path).
def load_model(model_id: str):
    """Instantiate a YOLO model for the given identifier.

    Returns the loaded model on success. On any failure, returns a
    human-readable error string — callers detect the error case by
    checking ``isinstance(result, str)``.
    """
    try:
        return YOLO(model_id)
    except Exception as e:
        return f"Error loading model: {e}"

# Single-image inference: run the model and return the rendered prediction.
def predict_image(model, image):
    """Run *model* on one image and return the annotated frame.

    ``Results.plot`` renders detection, segmentation, or OBB output alike.
    On failure, returns an error string instead of raising (file-wide
    error convention).
    """
    try:
        predictions = model(image)
        first = predictions[0]
        return first.plot()
    except Exception as e:
        return f"Error during image inference: {e}"

# Inference for videos: processes the video frame by frame.
def predict_video(model, video_file):
    """Annotate every frame of *video_file* and return the output file path.

    Streams annotated frames directly into a ``cv2.VideoWriter`` (instead of
    buffering the whole video in memory) and preserves the source frame rate.

    Args:
        model: A loaded YOLO model (callable on a single BGR frame).
        video_file: File-like object with a ``.name`` path attribute
            (as provided by the Gradio file input).

    Returns:
        Path to the annotated .mp4 on success, or an error string on failure
        (file-wide error convention).
    """
    try:
        cap = cv2.VideoCapture(video_file.name)
        if not cap.isOpened():
            return "Error: Could not open video file."

        # Use the source frame rate so playback speed matches the input;
        # fall back to 20 fps when the container reports none (0/NaN).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 20.0

        out = None
        temp_path = None
        try:
            success, frame = cap.read()
            while success:
                results = model(frame)
                annotated_frame = results[0].plot()
                if out is None:
                    # Lazily open the writer once the first annotated frame
                    # fixes the output dimensions.
                    height, width = annotated_frame.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
                    temp_path = tmp.name
                    # Close the handle so VideoWriter can open the same path
                    # (required on Windows, harmless elsewhere).
                    tmp.close()
                    out = cv2.VideoWriter(temp_path, fourcc, fps, (width, height))
                out.write(annotated_frame)
                success, frame = cap.read()
        finally:
            cap.release()
            if out is not None:
                out.release()

        if temp_path is None:
            return "Error: No frames processed from video."
        return temp_path
    except Exception as e:
        return f"Error during video inference: {e}"

# Route a request to the image or video pipeline based on media type.
def inference(model_id, input_media, media_type):
    """Load the requested model and dispatch to the matching predictor.

    A string returned by ``load_model`` signals a load error and is
    passed straight through to the UI.
    """
    model = load_model(model_id)
    if isinstance(model, str):  # Error-string convention from load_model.
        return model

    handlers = {
        "Image": predict_image,
        "Video": predict_video,
    }
    handler = handlers.get(media_type)
    if handler is None:
        return "Unsupported media type."
    return handler(model, input_media)

# Updated Gradio interface components using the new API.
# NOTE(review): predict_image returns a numpy array while the output component
# is gr.File — the image branch may not render as a downloadable file; confirm
# with a live run (gr.Image or writing the array to disk may be needed).
model_text_input = gr.Textbox(label="Model Identifier", placeholder="e.g., yolov8n.pt, yolov8-seg.pt, yolov8-obb.pt")
file_input = gr.File(label="Upload Image/Video File")
media_type_dropdown = gr.Radio(choices=["Image", "Video"], label="Select Media Type", value="Image")
output_file = gr.File(label="Processed Output")

# Create Gradio interface: inputs map positionally onto inference(model_id,
# input_media, media_type).
iface = gr.Interface(
    fn=inference,
    inputs=[model_text_input, file_input, media_type_dropdown],
    outputs=output_file,
    title="Dynamic Ultralytics YOLO Inference",
    description=(
        "Enter the model identifier (supports detection, segmentation, or OBB models) and upload an image or video."
    )
)

# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    iface.launch()