ML / app.py
bunnyroshan's picture
Create app.py
4f23172 verified
raw
history blame
1.08 kB
import torch
from torchvision import transforms
from PIL import Image
import gradio as gr
# Weights file of the custom-trained YOLOv5 detector; change this if your
# trained checkpoint lives somewhere else.
WEIGHTS_PATH = "best.pt"
# Pull the YOLOv5 repo through torch.hub and build a detector from the weights
# above. This runs once at import time; detect_objects() reuses the object.
model = torch.hub.load("ultralytics/yolov5", "custom", path=WEIGHTS_PATH)
# Define a function to process the image and make predictions
def detect_objects(image):
# Preprocess the image
transform = transforms.Compose([
transforms.ToTensor()
])
image = transform(image).unsqueeze(0) # Add batch dimension
# Perform inference
results = model(image)
# Extract bounding boxes and labels
bbox_img = results.render()[0] # This gets the image with bounding boxes drawn
return Image.fromarray(bbox_img)
# Create the Gradio interface.
# gr.inputs / gr.outputs were deprecated in Gradio 3.x and removed in 4.0;
# gr.Image now serves as both the input and the output component. The old
# shape=(640, 480) resize is gone too (removed from gr.Image in 4.0); the
# YOLOv5 hub model letterboxes inputs to its own inference size anyway.
# type="numpy" keeps the callback receiving the same kind of value as before.
inputs = gr.Image(type="numpy")
outputs = gr.Image(type="pil")
gr_interface = gr.Interface(
    fn=detect_objects,
    inputs=inputs,
    outputs=outputs,
    title="YOLO Object Detection",
    description="Upload an image to detect objects using a YOLO model.",
)

# Run the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    gr_interface.launch()