# FastSAM / app.py
# AkashDataScience -- "Updating image size" (eaa892e)
import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from fastsam import FastSAM, FastSAMPrompt
# Select the compute device: prefer CUDA, then the MPS backend, and fall back
# to the CPU when neither accelerator reports itself as available.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    # The hasattr guard covers torch builds whose ``torch.backends`` has no
    # ``mps`` attribute at all.
    device = "mps"
else:
    device = 'cpu'
# Load the FastSAM checkpoint once, when the module is executed, and move it
# to the device selected above.
_WEIGHTS_PATH = './weights/FastSAM-x.pt'
model = FastSAM(_WEIGHTS_PATH)
model.to(device)
def _plot_to_array(prompt_process, annotations, output_path):
    """Render *annotations* to *output_path* and load the image as an array.

    Args:
        prompt_process: The ``FastSAMPrompt`` instance used for plotting.
        annotations: Annotations returned by one of its ``*_prompt`` methods.
        output_path: File path the rendered image is written to.

    Returns:
        The rendered image as a NumPy array, or ``None`` when ``plot`` did
        not produce a file at *output_path*.
    """
    prompt_process.plot(annotations=annotations, output_path=output_path,
                        withContours=False, better_quality=False)
    # NOTE(review): plot() presumably skips writing when there is nothing to
    # draw (e.g. no mask matched the prompt) -- confirm against the fastsam
    # package. Returning None keeps the UI output empty instead of raising.
    if not os.path.exists(output_path):
        return None
    # The context manager closes the file handle; np.array() reads the pixel
    # data while the file is still open.
    with Image.open(output_path) as rendered:
        return np.array(rendered)


def inference(image, conf_thres, iou_thres, text):
    """Segment *image* with FastSAM and optionally apply a text prompt.

    Args:
        image: Input image handed over by the Gradio image component
            (assumed to be a NumPy array -- TODO confirm), or ``None`` when
            nothing was supplied.
        conf_thres: Value forwarded to the model as ``conf``.
        iou_thres: Value forwarded to the model as ``iou``.
        text: Optional text prompt; when empty, no text-prompted output is
            produced.

    Returns:
        A ``(output_sam, output_text)`` tuple. ``output_sam`` is the rendered
        "everything" segmentation and ``output_text`` the rendered
        text-prompted segmentation; either may be ``None``.
    """
    # Nothing to segment: leave both outputs empty instead of failing inside
    # the model call.
    if image is None:
        return None, None

    pred = model(image, device=device, retina_masks=True, imgsz=1024,
                 conf=conf_thres, iou=iou_thres)
    prompt_process = FastSAMPrompt(image, pred, device="cpu")

    # Render into a per-call temporary directory. Fixed file names in the
    # working directory would be shared by every request, so concurrent calls
    # could overwrite each other's images and a stale file from an earlier
    # request could be returned.
    with tempfile.TemporaryDirectory() as workdir:
        ann = prompt_process.everything_prompt()
        output_sam = _plot_to_array(
            prompt_process, ann, os.path.join(workdir, "output_sam.jpg"))

        output_text = None
        if text:
            ann = prompt_process.text_prompt(text=text)
            output_text = _plot_to_array(
                prompt_process, ann, os.path.join(workdir, "output_text.jpg"))

    return output_sam, output_text
# Heading and blurb displayed by the Gradio interface.
title = "FAST-SAM Segment Anything"
description = "A simple Gradio interface to infer on FAST-SAM model"

# Sample (image file, text prompt) pairs offered as clickable examples.
_EXAMPLE_PROMPTS = (
    ("image_1.jpg", 'A black tire'),
    ("image_2.jpg", 'Shades of blue'),
    ("image_3.jpg", 'A spiral staircase'),
    ("image_4.jpg", 'A clock and a plane'),
    ("image_5.jpg", 'Clouds in the sky'),
    ("image_6.jpg", 'Front wheel'),
    ("image_7.jpg", 'A white chair'),
    ("image_8.jpg", 'The grassy field'),
    ("image_9.jpg", 'Rock formation'),
    ("image_10.jpg", 'A rope railing'),
)

# Every example row follows the parameter order of ``inference``:
# [image, confidence threshold, IoU threshold, text prompt]. All rows use the
# same two threshold values.
examples = [
    [image_path, 0.25, 0.45, prompt]
    for image_path, prompt in _EXAMPLE_PROMPTS
]
# Build the Gradio UI around ``inference``: an image, two threshold sliders
# and a text box as inputs (in the order of the function's parameters), and
# two images as outputs (in the order of the function's return tuple).
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(width=320, height=320, label="Input Image"),
        # Sliders range over 0..1 and start at the third positional value.
        gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
        gr.Slider(0, 1, 0.45, label="IoU Threshold"),
        gr.Textbox(label="Enter text prompt", type="text"),
    ],
    outputs=[
        gr.Image(width=320, height=320, label="Output SAM"),
        gr.Image(width=320, height=320, label="Output text prompt"),
    ],
    title=title,
    description=description,
    examples=examples,
)

# Start the app as soon as this module is executed.
demo.launch()