# FastSAM / app.py
import torch
import numpy as np
import gradio as gr
from PIL import Image
from fastsam import FastSAM, FastSAMPrompt

# Pick the best available device: CUDA first, then Apple MPS, otherwise CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"

# FastSAM-x weights are expected at ./weights/FastSAM-x.pt.
model = FastSAM('./weights/FastSAM-x.pt')
model.to(device)

def inference(image, conf_thres, iou_thres):
    # Run FastSAM on the input image and collect masks for everything it finds.
    pred = model(image, device=device, retina_masks=True, imgsz=1024,
                 conf=conf_thres, iou=iou_thres)
    prompt_process = FastSAMPrompt(image, pred, device="cpu")
    ann = prompt_process.everything_prompt()
    # Render the annotated image to disk, then return it as a numpy array for Gradio.
    prompt_process.plot(annotations=ann, output_path="./output.jpg",
                        withContours=False, better_quality=False)
    output = Image.open('./output.jpg')
    return np.array(output)
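
# Note: this demo only uses everything_prompt(). The upstream FastSAM repo's
# FastSAMPrompt also exposes box/point/text prompts; a rough sketch (argument
# names may differ between FastSAM versions, and this is not wired into the app):
#   ann = prompt_process.box_prompt(bbox=[x1, y1, x2, y2])
#   ann = prompt_process.point_prompt(points=[[x, y]], pointlabel=[1])
#   ann = prompt_process.text_prompt(text="a photo of a dog")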

title = "FAST-SAM Segment Anything"
description = "A simple Gradio interface for running inference with the FastSAM model."
examples = [["image_1.jpg", 0.25, 0.45],
            ["image_2.jpg", 0.25, 0.45],
            ["image_3.jpg", 0.25, 0.45],
            ["image_4.jpg", 0.25, 0.45],
            ["image_5.jpg", 0.25, 0.45],
            ["image_6.jpg", 0.25, 0.45],
            ["image_7.jpg", 0.25, 0.45],
            ["image_8.jpg", 0.25, 0.45],
            ["image_9.jpg", 0.25, 0.45],
            ["image_10.jpg", 0.25, 0.45]]

demo = gr.Interface(inference,
                    inputs=[gr.Image(width=320, height=320, label="Input Image"),
                            gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
                            gr.Slider(0, 1, 0.45, label="IoU Threshold")],
                    outputs=[gr.Image(width=640, height=640, label="Output")],
                    title=title,
                    description=description,
                    examples=examples)
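
# Note (assumption, not part of the original app): when running locally you can
# pass share=True to demo.launch() to get a temporary public Gradio link.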
demo.launch()