sergiopaniego HF Staff commited on
Commit
579099e
·
verified ·
1 Parent(s): fce10b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -17,31 +17,41 @@ from PIL import Image
17
  from gradio_image_prompter import ImagePrompter
18
 
19
 
20
- #sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
21
- #sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
22
- #sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
23
- #sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
24
 
25
  @spaces.GPU
26
  def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
 
 
 
27
  if model_id == 'sam':
28
- model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
29
- processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
30
  else:
31
- model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to("cuda")
32
- processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
33
-
34
- inputs = processor(raw_image, input_boxes=[input_boxes] if input_boxes else None,
35
- input_points=[input_points] if input_points else None, return_tensors="pt")
36
 
37
  original_sizes = inputs["original_sizes"]
38
  reshaped_sizes = inputs["reshaped_input_sizes"]
39
- inputs = inputs.to("cuda")
40
 
41
- with torch.no_grad():
42
- outputs = model(**inputs)
 
 
 
 
 
 
43
 
44
- masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), original_sizes, reshaped_sizes)
 
 
 
 
 
 
 
45
  scores = outputs.iou_scores
46
  return masks, scores
47
 
 
17
  from gradio_image_prompter import ImagePrompter
18
 
19
 
20
+ sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
21
+ sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
22
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
23
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
24
 
25
  @spaces.GPU
26
  def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
27
+ if input_boxes is not None:
28
+ input_boxes = [input_boxes]
29
+
30
  if model_id == 'sam':
31
+ inputs = sam_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
 
32
  else:
33
+ inputs = sam_hq_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
 
 
 
 
34
 
35
  original_sizes = inputs["original_sizes"]
36
  reshaped_sizes = inputs["reshaped_input_sizes"]
 
37
 
38
+ if model_id == 'sam':
39
+ inputs = inputs.to(sam_model.device)
40
+ with torch.no_grad():
41
+ outputs = sam_model(**inputs)
42
+ else:
43
+ inputs = inputs.to(sam_hq_model.device)
44
+ with torch.no_grad():
45
+ outputs = sam_hq_model(**inputs)
46
 
47
+ if model_id == 'sam':
48
+ masks = sam_processor.image_processor.post_process_masks(
49
+ outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
50
+ )
51
+ else:
52
+ masks = sam_hq_processor.image_processor.post_process_masks(
53
+ outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
54
+ )
55
  scores = outputs.iou_scores
56
  return masks, scores
57