sergiopaniego HF Staff commited on
Commit
fce10b4
·
verified ·
1 Parent(s): 938994c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -29
app.py CHANGED
@@ -16,45 +16,32 @@ from utils import *
16
  from PIL import Image
17
  from gradio_image_prompter import ImagePrompter
18
 
19
- #sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
20
- #sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
21
- sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
22
- sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
23
 
24
- #sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
25
- #sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
26
- sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
27
- sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
28
 
29
  @spaces.GPU
30
  def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
31
- if input_boxes is not None:
32
- input_boxes = [input_boxes]
33
-
34
  if model_id == 'sam':
35
- inputs = sam_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
 
36
  else:
37
- inputs = sam_hq_processor(raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt")
 
 
 
 
 
38
  original_sizes = inputs["original_sizes"]
39
  reshaped_sizes = inputs["reshaped_input_sizes"]
 
40
 
41
- if model_id == 'sam':
42
- inputs = inputs.to(sam_model.device)
43
- with torch.no_grad():
44
- outputs = sam_model(**inputs)
45
- else:
46
- inputs = inputs.to(sam_hq_model.device)
47
- with torch.no_grad():
48
- outputs = sam_hq_model(**inputs)
49
 
50
- if model_id == 'sam':
51
- masks = sam_processor.image_processor.post_process_masks(
52
- outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
53
- )
54
- else:
55
- masks = sam_hq_processor.image_processor.post_process_masks(
56
- outputs.pred_masks.cpu(), original_sizes, reshaped_sizes
57
- )
58
  scores = outputs.iou_scores
59
  return masks, scores
60
 
 
16
  from PIL import Image
17
  from gradio_image_prompter import ImagePrompter
18
 
 
 
 
 
19
 
20
+ #sam_hq_model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base", device_map="auto", torch_dtype="auto")
21
+ #sam_hq_processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
22
+ #sam_model = SamModel.from_pretrained("facebook/sam-vit-base", device_map="auto", torch_dtype="auto")
23
+ #sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
24
 
25
  @spaces.GPU
26
  def predict_masks_and_scores(model_id, raw_image, input_points=None, input_boxes=None):
 
 
 
27
  if model_id == 'sam':
28
+ model = SamModel.from_pretrained("facebook/sam-vit-base").to("cuda")
29
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
30
  else:
31
+ model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to("cuda")
32
+ processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
33
+
34
+ inputs = processor(raw_image, input_boxes=[input_boxes] if input_boxes else None,
35
+ input_points=[input_points] if input_points else None, return_tensors="pt")
36
+
37
  original_sizes = inputs["original_sizes"]
38
  reshaped_sizes = inputs["reshaped_input_sizes"]
39
+ inputs = inputs.to("cuda")
40
 
41
+ with torch.no_grad():
42
+ outputs = model(**inputs)
 
 
 
 
 
 
43
 
44
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), original_sizes, reshaped_sizes)
 
 
 
 
 
 
 
45
  scores = outputs.iou_scores
46
  return masks, scores
47