"""Gradio demo for YOLOv8-nano detection, segmentation and classification.

Each tab takes an uploaded image and a confidence threshold, runs the
matching ultralyticsplus YOLO model and shows its predictions.
"""
import functools

import gradio as gr
from ultralyticsplus import YOLO, render_result, postprocess_classify_output

# Weight files used by each tab.
DETECTION_WEIGHTS = "yolov8n.pt"
SEGMENTATION_WEIGHTS = "yolov8n-seg.pt"
CLASSIFICATION_WEIGHTS = "yolov8n-cls.pt"


@functools.lru_cache(maxsize=None)
def _load_model(weights):
    """Return the YOLO model for *weights*, constructing it at most once.

    Building a YOLO model loads its weights, so doing it per request is
    wasteful. Only the three weight constants above are ever passed, so an
    unbounded cache cannot grow without limit.
    """
    return YOLO(weights)


def _predict(weights, image, threshold):
    """Run *image* through the cached model for *weights*.

    Args:
        weights: Weight file name identifying the model.
        image: Image value supplied by a ``gr.Image`` input; ``None`` when
            nothing was uploaded.
        threshold: Confidence threshold from the slider (0..1).

    Returns:
        ``(model, result)`` where *result* is the first entry of the list
        returned by ``model.predict``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # Shown to the user in the UI instead of an opaque traceback.
        raise gr.Error("Please upload an image first.")
    model = _load_model(weights)
    # Pass ``conf`` per call rather than mutating ``model.overrides``:
    # ultralytics merges its own method default (conf=0.25) over
    # ``model.overrides``, while call kwargs take precedence. Per-call args
    # also avoid leaving request state on the shared, cached model.
    results = model.predict(image, conf=threshold)
    return model, results[0]


def classification(image, threshold):
    """Classify *image* and return a label->score mapping for ``gr.Label``.

    *threshold* is forwarded to ``predict`` as ``conf``.
    """
    model, result = _predict(CLASSIFICATION_WEIGHTS, image, threshold)
    return postprocess_classify_output(model=model, result=result)


def detection(image, threshold):
    """Detect objects in *image* and return the rendered predictions."""
    model, result = _predict(DETECTION_WEIGHTS, image, threshold)
    return render_result(model=model, image=image, result=result)


def segmentation(image, threshold):
    """Segment *image* and return the rendered predictions."""
    model, result = _predict(SEGMENTATION_WEIGHTS, image, threshold)
    return render_result(model=model, image=image, result=result)


with gr.Blocks() as demo:
    with gr.Tab("Detection"):
        with gr.Row():
            with gr.Column():
                detect_input = gr.Image()
                detect_threshold = gr.Slider(
                    maximum=1, step=0.01, value=0.25,
                    label="Threshold:", interactive=True)
                detect_button = gr.Button("Detect!")
            with gr.Column():
                detect_output = gr.Image(
                    label="Predictions:", interactive=False)

    with gr.Tab("Segmentation"):
        with gr.Row():
            with gr.Column():
                segment_input = gr.Image()
                segment_threshold = gr.Slider(
                    maximum=1, step=0.01, value=0.25,
                    label="Threshold:", interactive=True)
                segment_button = gr.Button("Segment!")
            with gr.Column():
                segment_output = gr.Image(
                    label="Predictions:", interactive=False)

    with gr.Tab("Classification"):
        with gr.Row():
            with gr.Column():
                classify_input = gr.Image()
                classify_threshold = gr.Slider(
                    maximum=1, step=0.01, value=0.25,
                    label="Threshold:", interactive=True)
                classify_button = gr.Button("Classify!")
            with gr.Column():
                classify_output = gr.Label(
                    label="Predictions:", show_label=True,
                    num_top_classes=5)

    # api_name values are part of the public HTTP API; keep them unchanged.
    detect_button.click(
        detection,
        inputs=[detect_input, detect_threshold],
        outputs=detect_output,
        api_name="Detect")
    segment_button.click(
        segmentation,
        inputs=[segment_input, segment_threshold],
        outputs=segment_output,
        api_name="Segmentation")
    classify_button.click(
        classification,
        inputs=[classify_input, classify_threshold],
        outputs=classify_output,
        api_name="classify")

# ``launch(enable_queue=...)`` is deprecated in Gradio 3 and removed in
# Gradio 4; ``.queue()`` is the supported way to enable the request queue.
demo.queue().launch(debug=True)