AkashDataScience commited on
Commit
4dcd8cb
·
1 Parent(s): 96836cb

Adding text promp

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -14,33 +14,41 @@ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
14
  model = FastSAM('./weights/FastSAM-x.pt')
15
  model.to(device)
16
 
17
- def inference(image, conf_thres, iou_thres,):
18
  pred = model(image, device=device, retina_masks=True, imgsz=1024, conf=conf_thres, iou=iou_thres)
19
  prompt_process = FastSAMPrompt(image, pred, device="cpu")
20
  ann = prompt_process.everything_prompt()
21
- prompt_process.plot(annotations=ann, output_path="./output.jpg", withContours=False, better_quality=False)
22
- output = Image.open('./output.jpg')
23
- output = np.array(output)
24
- return output
 
 
 
 
 
 
25
 
26
  title = "FAST-SAM Segment Anything"
27
  description = "A simple Gradio interface to infer on FAST-SAM model"
28
- examples = [["image_1.jpg", 0.25, 0.45],
29
- ["image_2.jpg", 0.25, 0.45],
30
- ["image_3.jpg", 0.25, 0.45],
31
- ["image_4.jpg", 0.25, 0.45],
32
- ["image_5.jpg", 0.25, 0.45],
33
- ["image_6.jpg", 0.25, 0.45],
34
- ["image_7.jpg", 0.25, 0.45],
35
- ["image_8.jpg", 0.25, 0.45],
36
- ["image_9.jpg", 0.25, 0.45],
37
- ["image_10.jpg", 0.25, 0.45]]
38
 
39
  demo = gr.Interface(inference,
40
- inputs = [gr.Image(width=320, height=320, label="Input Image"),
41
  gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
42
- gr.Slider(0, 1, 0.45, label="IoU Thresold")],
43
- outputs= [gr.Image(width=640, height=640, label="Output")],
 
 
44
  title=title,
45
  description=description,
46
  examples=examples)
 
14
  model = FastSAM('./weights/FastSAM-x.pt')
15
  model.to(device)
16
 
17
+ def inference(image, conf_thres, iou_thres, text):
18
  pred = model(image, device=device, retina_masks=True, imgsz=1024, conf=conf_thres, iou=iou_thres)
19
  prompt_process = FastSAMPrompt(image, pred, device="cpu")
20
  ann = prompt_process.everything_prompt()
21
+ prompt_process.plot(annotations=ann, output_path="./output_sam.jpg", withContours=False, better_quality=False)
22
+ output_sam = Image.open('./output_sam.jpg')
23
+ output_sam = np.array(output_sam)
24
+ output_text = None
25
+ if text:
26
+ ann = prompt_process.text_prompt(text=text)
27
+ prompt_process.plot(annotations=ann, output_path="./output_text.jpg", withContours=False, better_quality=False)
28
+ output_text = Image.open('./output_text.jpg')
29
+ output_text = np.array(output_text)
30
+ return output_sam, output_text
31
 
32
  title = "FAST-SAM Segment Anything"
33
  description = "A simple Gradio interface to infer on FAST-SAM model"
34
+ examples = [["image_1.jpg", 0.25, 0.45, 'A black tire'],
35
+ ["image_2.jpg", 0.25, 0.45, 'Shades of blue'],
36
+ ["image_3.jpg", 0.25, 0.45, 'A spiral staircase'],
37
+ ["image_4.jpg", 0.25, 0.45, 'A clock and a plane'],
38
+ ["image_5.jpg", 0.25, 0.45, 'Clouds in the sky'],
39
+ ["image_6.jpg", 0.25, 0.45, 'Front wheel'],
40
+ ["image_7.jpg", 0.25, 0.45, 'A white chair'],
41
+ ["image_8.jpg", 0.25, 0.45, 'The grassy field'],
42
+ ["image_9.jpg", 0.25, 0.45, 'Rock formation'],
43
+ ["image_10.jpg", 0.25, 0.45, 'A rope railing']]
44
 
45
  demo = gr.Interface(inference,
46
+ inputs = [gr.Image(width=640, height=640, label="Input Image"),
47
  gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
48
+ gr.Slider(0, 1, 0.45, label="IoU Thresold"),
49
+ gr.Textbox(label="Enter text promp", type="text"),],
50
+ outputs= [gr.Image(width=640, height=640, label="Output SAM"),
51
+ gr.Image(width=640, height=640, label="Output text prompt")],
52
  title=title,
53
  description=description,
54
  examples=examples)