sergiopaniego HF Staff commited on
Commit
938994c
·
verified ·
1 Parent(s): 9c9f88f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -27,44 +27,37 @@ sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto",
27
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
  @spaces.GPU
30
- def predict_masks_and_scores_sam_hq(raw_image, input_points=None, input_boxes=None):
31
  if input_boxes is not None:
32
  input_boxes = [input_boxes]
33
 
34
- inputs = sam_hq_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
 
 
 
35
  original_sizes = inputs["original_sizes"]
36
  reshaped_sizes = inputs["reshaped_input_sizes"]
37
 
38
- inputs = inputs.to(sam_hq_model.device)
39
-
40
- with torch.no_grad():
41
- outputs = sam_hq_model(**inputs)
42
-
43
- masks = sam_hq_processor.image_processor.post_process_masks(
44
- outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
45
- )
 
 
 
 
 
 
 
 
 
46
  scores = outputs.iou_scores
47
  return masks, scores
48
 
49
- @spaces.GPU
50
- def predict_masks_and_scores_sam(raw_image, input_points=None, input_boxes=None):
51
- if input_boxes is not None:
52
- input_boxes = [input_boxes]
53
-
54
- inputs = sam_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
55
- original_sizes = inputs["original_sizes"]
56
- reshaped_sizes = inputs["reshaped_input_sizes"]
57
-
58
- inputs = inputs.to(sam_model.device)
59
-
60
- with torch.no_grad():
61
- outputs = sam_model(**inputs)
62
-
63
- masks = sam_processor.image_processor.post_process_masks(
64
- outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
65
- )
66
- scores = outputs.iou_scores
67
- return masks, scores
68
 
69
 
70
  def process_inputs(prompts):
@@ -88,8 +81,8 @@ def process_inputs(prompts):
88
  input_points = [input_points] if input_points else None
89
  user_image = prompts['image']
90
 
91
- sam_masks, sam_scores = predict_masks_and_scores_sam(user_image, input_boxes=input_boxes, input_points=input_points)
92
- sam_hq_masks, sam_hq_scores = predict_masks_and_scores_sam_hq(user_image, input_boxes=input_boxes, input_points=input_points)
93
 
94
  if input_boxes and input_points:
95
  img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')
 
27
  sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
  @spaces.GPU
30
+ def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
31
  if input_boxes is not None:
32
  input_boxes = [input_boxes]
33
 
34
+ if model_id == 'sam':
35
+ inputs = sam_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
36
+ else:
37
+ inputs = sam_hq_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
38
  original_sizes = inputs["original_sizes"]
39
  reshaped_sizes = inputs["reshaped_input_sizes"]
40
 
41
+ if model_id == 'sam':
42
+ inputs = inputs.to(sam_model.device)
43
+ with torch.no_grad():
44
+ outputs = sam_model(**inputs)
45
+ else:
46
+ inputs = inputs.to(sam_hq_model.device)
47
+ with torch.no_grad():
48
+ outputs = sam_hq_model(**inputs)
49
+
50
+ if model_id == 'sam':
51
+ masks = sam_processor.image_processor.post_process_masks(
52
+ outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
53
+ )
54
+ else:
55
+ masks = sam_hq_processor.image_processor.post_process_masks(
56
+ outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
57
+ )
58
  scores = outputs.iou_scores
59
  return masks, scores
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def process_inputs(prompts):
 
81
  input_points = [input_points] if input_points else None
82
  user_image = prompts['image']
83
 
84
+ sam_masks, sam_scores = predict_masks_and_scores('sam', user_image, input_boxes=input_boxes, input_points=input_points)
85
+ sam_hq_masks, sam_hq_scores = predict_masks_and_scores('sam_hq', user_image, input_boxes=input_boxes, input_points=input_points)
86
 
87
  if input_boxes and input_points:
88
  img1_b64 = show_all_annotations_on_image_base64(user_image, sam_masks[0][0], sam_scores[:, 0, :], input_boxes[0], input_points[0], model_name='SAM')