AkashDataScience committed on
Commit
45e6c67
1 Parent(s): 3b4885d

First commit

app.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ import torch.nn.functional as F
+ from PIL import Image
+ from fastsam import FastSAM, FastSAMPrompt
+
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = "mps"
+
+ model = FastSAM('./weights/FastSAM-x.pt')
+ model = model.to(device)
+
+ def inference(image, conf_thres, iou_thres):
+     pred = model(image, device=device, retina_masks=True, imgsz=1024, conf=conf_thres, iou=iou_thres)
+     prompt_process = FastSAMPrompt(image, pred, device=device)
+     ann = prompt_process.everything_prompt()
+     prompt_process.plot(annotations=ann, output_path="./output.jpg", withContours=False, better_quality=False)
+     output = Image.open('./output.jpg')
+     output = np.array(output)
+     return output
+
+ title = "FAST-SAM Segment Anything"
+ description = "A simple Gradio interface for running inference with the FastSAM model"
+ examples = [["image_1.jpg", 0.25, 0.45],
+             ["image_2.jpg", 0.25, 0.45],
+             ["image_3.jpg", 0.25, 0.45],
+             ["image_4.jpg", 0.25, 0.45],
+             ["image_5.jpg", 0.25, 0.45],
+             ["image_6.jpg", 0.25, 0.45],
+             ["image_7.jpg", 0.25, 0.45],
+             ["image_8.jpg", 0.25, 0.45],
+             ["image_9.jpg", 0.25, 0.45],
+             ["image_10.jpg", 0.25, 0.45]]
+
+ demo = gr.Interface(inference,
+                     inputs=[gr.Image(width=320, height=320, label="Input Image"),
+                             gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
+                             gr.Slider(0, 1, 0.45, label="IoU Threshold")],
+                     outputs=[gr.Image(width=640, height=640, label="Output")],
+                     title=title,
+                     description=description,
+                     examples=examples)
+
+ demo.launch()
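
Note: a minimal sketch of running the same pipeline from a plain script, outside Gradio. It assumes ./weights/FastSAM-x.pt and one of the bundled example images (e.g. image_1.jpg) sit in the working directory; the thresholds mirror the demo defaults.

import numpy as np
from PIL import Image
from fastsam import FastSAM, FastSAMPrompt

# Load the bundled checkpoint and an example image (paths assume the repo layout above).
model = FastSAM('./weights/FastSAM-x.pt')
image = np.array(Image.open('image_1.jpg'))

# Segment-everything pass with the same settings the Gradio demo uses.
results = model(image, device='cpu', retina_masks=True, imgsz=1024, conf=0.25, iou=0.45)
prompt = FastSAMPrompt(image, results, device='cpu')
masks = prompt.everything_prompt()
prompt.plot(annotations=masks, output_path='./output.jpg', withContours=False, better_quality=False)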
fastsam/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ from .model import FastSAM
+ from .predict import FastSAMPredictor
+ from .prompt import FastSAMPrompt
+ # from .val import FastSAMValidator
+ from .decoder import FastSAMDecoder
+
+ __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder'
fastsam/decoder.py ADDED
@@ -0,0 +1,130 @@
+ from .model import FastSAM
+ import numpy as np
+ from PIL import Image
+ from typing import Optional, List, Tuple, Union
+
+
+ class FastSAMDecoder:
+     def __init__(
+             self,
+             model: FastSAM,
+             device: str='cpu',
+             conf: float=0.4,
+             iou: float=0.9,
+             imgsz: int=1024,
+             retina_masks: bool=True,
+     ):
+         self.model = model
+         self.device = device
+         self.retina_masks = retina_masks
+         self.imgsz = imgsz
+         self.conf = conf
+         self.iou = iou
+         self.image = None
+         self.image_embedding = None
+
+     def run_encoder(self, image):
+         if isinstance(image, str):
+             image = np.array(Image.open(image))
+         self.image = image
+         image_embedding = self.model(
+             self.image,
+             device=self.device,
+             retina_masks=self.retina_masks,
+             imgsz=self.imgsz,
+             conf=self.conf,
+             iou=self.iou
+         )
+         return image_embedding[0].numpy()
+
+     def run_decoder(
+             self,
+             image_embedding,
+             point_prompt: Optional[np.ndarray]=None,
+             point_label: Optional[np.ndarray]=None,
+             box_prompt: Optional[np.ndarray]=None,
+             text_prompt: Optional[str]=None,
+     ) -> np.ndarray:
+         self.image_embedding = image_embedding
+         if point_prompt is not None:
+             ann = self.point_prompt(points=point_prompt, pointlabel=point_label)
+             return ann
+         elif box_prompt is not None:
+             ann = self.box_prompt(bbox=box_prompt)
+             return ann
+         elif text_prompt is not None:
+             ann = self.text_prompt(text=text_prompt)
+             return ann
+         else:
+             return None
+
+     def box_prompt(self, bbox):
+         assert (bbox[2] != 0 and bbox[3] != 0)
+         masks = self.image_embedding.masks.data
+         target_height = self.image.shape[0]
+         target_width = self.image.shape[1]
+         h = masks.shape[1]
+         w = masks.shape[2]
+         if h != target_height or w != target_width:
+             bbox = [
+                 int(bbox[0] * w / target_width),
+                 int(bbox[1] * h / target_height),
+                 int(bbox[2] * w / target_width),
+                 int(bbox[3] * h / target_height), ]
+         bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
+         bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
+         bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
+         bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+
+         # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+         bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+         masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2))
+         orig_masks_area = np.sum(masks, axis=(1, 2))
+
+         union = bbox_area + orig_masks_area - masks_area
+         IoUs = masks_area / union
+         max_iou_index = np.argmax(IoUs)
+
+         return np.array([masks[max_iou_index].cpu().numpy()])
+
+     def point_prompt(self, points, pointlabel):  # numpy
+
+         masks = self._format_results(self.image_embedding[0], 0)
+         target_height = self.image.shape[0]
+         target_width = self.image.shape[1]
+         h = masks[0]['segmentation'].shape[0]
+         w = masks[0]['segmentation'].shape[1]
+         if h != target_height or w != target_width:
+             points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+         onemask = np.zeros((h, w))
+         masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+         for i, annotation in enumerate(masks):
+             if type(annotation) == dict:
+                 mask = annotation['segmentation']
+             else:
+                 mask = annotation
+             for i, point in enumerate(points):
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                     onemask[mask] = 1
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                     onemask[mask] = 0
+         onemask = onemask >= 1
+         return np.array([onemask])
+
+     def _format_results(self, result, filter=0):
+         annotations = []
+         n = len(result.masks.data)
+         for i in range(n):
+             annotation = {}
+             mask = result.masks.data[i] == 1.0
+
+             if np.sum(mask) < filter:
+                 continue
+             annotation['id'] = i
+             annotation['segmentation'] = mask
+             annotation['bbox'] = result.boxes.data[i]
+             annotation['score'] = result.boxes.conf[i]
+             annotation['area'] = annotation['segmentation'].sum()
+             annotations.append(annotation)
+         return annotations
fastsam/model.py ADDED
@@ -0,0 +1,106 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+ """
+ FastSAM model interface.
+
+ Usage - Predict:
+     from ultralytics import FastSAM
+
+     model = FastSAM('last.pt')
+     results = model.predict('ultralytics/assets/bus.jpg')
+ """
+
+ from ultralytics.yolo.cfg import get_cfg
+ from ultralytics.yolo.engine.exporter import Exporter
+ from ultralytics.yolo.engine.model import YOLO
+ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
+ from ultralytics.yolo.utils.checks import check_imgsz
+
+ from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
+ from .predict import FastSAMPredictor
+
+
+ class FastSAM(YOLO):
+
+     @smart_inference_mode()
+     def predict(self, source=None, stream=False, **kwargs):
+         """
+         Perform prediction using the YOLO model.
+
+         Args:
+             source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                 Accepts all source types accepted by the YOLO model.
+             stream (bool): Whether to stream the predictions or not. Defaults to False.
+             **kwargs : Additional keyword arguments passed to the predictor.
+                 Check the 'configuration' section in the documentation for all available options.
+
+         Returns:
+             (List[ultralytics.yolo.engine.results.Results]): The prediction results.
+         """
+         if source is None:
+             source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
+             LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
+         overrides = self.overrides.copy()
+         overrides['conf'] = 0.25
+         overrides.update(kwargs)  # prefer kwargs
+         overrides['mode'] = kwargs.get('mode', 'predict')
+         assert overrides['mode'] in ['track', 'predict']
+         overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
+         self.predictor = FastSAMPredictor(overrides=overrides)
+         self.predictor.setup_model(model=self.model, verbose=False)
+         try:
+             return self.predictor(source, stream=stream)
+         except Exception as e:
+             return None
+
+     def train(self, **kwargs):
+         """Function trains models but raises an error as FastSAM models do not support training."""
+         raise NotImplementedError("Currently, the training codes are on the way.")
+
+     def val(self, **kwargs):
+         """Run validation given dataset."""
+         overrides = dict(task='segment', mode='val')
+         overrides.update(kwargs)  # prefer kwargs
+         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
+         args.imgsz = check_imgsz(args.imgsz, max_dim=1)
+         validator = FastSAM(args=args)
+         validator(model=self.model)
+         self.metrics = validator.metrics
+         return validator.metrics
+
+     @smart_inference_mode()
+     def export(self, **kwargs):
+         """
+         Export model.
+
+         Args:
+             **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
+         """
+         overrides = dict(task='detect')
+         overrides.update(kwargs)
+         overrides['mode'] = 'export'
+         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
+         args.task = self.task
+         if args.imgsz == DEFAULT_CFG.imgsz:
+             args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
+         if args.batch == DEFAULT_CFG.batch:
+             args.batch = 1  # default to 1 if not modified
+         return Exporter(overrides=args)(model=self.model)
+
+     def info(self, detailed=False, verbose=True):
+         """
+         Logs model info.
+
+         Args:
+             detailed (bool): Show detailed information about model.
+             verbose (bool): Controls verbosity.
+         """
+         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+
+     def __call__(self, source=None, stream=False, **kwargs):
+         """Calls the 'predict' function with given arguments to perform object detection."""
+         return self.predict(source, stream, **kwargs)
+
+     def __getattr__(self, attr):
+         """Raises error if object has no requested attribute."""
+         name = self.__class__.__name__
+         raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
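
Note: beyond everything_prompt (used by app.py), FastSAMPrompt also exposes box, point and text prompts on top of these predictions. A hedged usage sketch follows; the paths and coordinates are illustrative only, and text_prompt additionally needs the openai/CLIP package installed.

import numpy as np
from PIL import Image
from fastsam import FastSAM, FastSAMPrompt

model = FastSAM('./weights/FastSAM-x.pt')
image = np.array(Image.open('image_1.jpg'))
results = model(image, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
prompt = FastSAMPrompt(image, results, device='cpu')

box_mask = prompt.box_prompt(bbox=[100, 100, 400, 400])                # xyxy box in image pixels (made-up values)
point_mask = prompt.point_prompt(points=[[250, 250]], pointlabel=[1])  # label 1 marks a foreground point
# text_mask = prompt.text_prompt(text='a dog')                         # requires CLIP; the text is illustrative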
fastsam/predict.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+
+ from ultralytics.yolo.engine.results import Results
+ from ultralytics.yolo.utils import DEFAULT_CFG, ops
+ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
+ from .utils import bbox_iou
+
+ class FastSAMPredictor(DetectionPredictor):
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         super().__init__(cfg, overrides, _callbacks)
+         self.args.task = 'segment'
+
+     def postprocess(self, preds, img, orig_imgs):
+         """TODO: filter by classes."""
+         p = ops.non_max_suppression(preds[0],
+                                     self.args.conf,
+                                     self.args.iou,
+                                     agnostic=self.args.agnostic_nms,
+                                     max_det=self.args.max_det,
+                                     nc=len(self.model.names),
+                                     classes=self.args.classes)
+
+         results = []
+         if len(p) == 0 or len(p[0]) == 0:
+             print("No object detected.")
+             return results
+
+         full_box = torch.zeros_like(p[0][0])
+         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
+         full_box = full_box.view(1, -1)
+         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
+         if critical_iou_index.numel() != 0:
+             full_box[0][4] = p[0][critical_iou_index][:, 4]
+             full_box[0][6:] = p[0][critical_iou_index][:, 6:]
+             p[0][critical_iou_index] = full_box
+
+         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
+         for i, pred in enumerate(p):
+             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
+             path = self.batch[0]
+             img_path = path[i] if isinstance(path, list) else path
+             if not len(pred):  # save empty boxes
+                 results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
+                 continue
+             if self.args.retina_masks:
+                 if not isinstance(orig_imgs, torch.Tensor):
+                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+                 masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
+             else:
+                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
+                 if not isinstance(orig_imgs, torch.Tensor):
+                     pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+             results.append(
+                 Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
+         return results
fastsam/prompt.py ADDED
@@ -0,0 +1,457 @@
+ import os
+ import sys
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from .utils import image_to_np_ndarray
+ from PIL import Image
+
+
+ class FastSAMPrompt:
+
+     def __init__(self, image, results, device='cuda'):
+         if isinstance(image, str) or isinstance(image, Image.Image):
+             image = image_to_np_ndarray(image)
+         self.device = device
+         self.results = results
+         self.img = image
+
+     def _segment_image(self, image, bbox):
+         if isinstance(image, Image.Image):
+             image_array = np.array(image)
+         else:
+             image_array = image
+         segmented_image_array = np.zeros_like(image_array)
+         x1, y1, x2, y2 = bbox
+         segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
+         segmented_image = Image.fromarray(segmented_image_array)
+         black_image = Image.new('RGB', image.size, (255, 255, 255))
+         # transparency_mask = np.zeros_like((), dtype=np.uint8)
+         transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
+         transparency_mask[y1:y2, x1:x2] = 255
+         transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+         black_image.paste(segmented_image, mask=transparency_mask_image)
+         return black_image
+
+     def _format_results(self, result, filter=0):
+         annotations = []
+         n = len(result.masks.data)
+         for i in range(n):
+             annotation = {}
+             mask = result.masks.data[i] == 1.0
+
+             if torch.sum(mask) < filter:
+                 continue
+             annotation['id'] = i
+             annotation['segmentation'] = mask.cpu().numpy()
+             annotation['bbox'] = result.boxes.data[i]
+             annotation['score'] = result.boxes.conf[i]
+             annotation['area'] = annotation['segmentation'].sum()
+             annotations.append(annotation)
+         return annotations
+
+     def filter_masks(annotations):  # filter overlapping masks
+         annotations.sort(key=lambda x: x['area'], reverse=True)
+         to_remove = set()
+         for i in range(0, len(annotations)):
+             a = annotations[i]
+             for j in range(i + 1, len(annotations)):
+                 b = annotations[j]
+                 if i != j and j not in to_remove:
+                     # remove b if it is largely covered by the larger mask a
+                     if b['area'] < a['area']:
+                         if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
+                             to_remove.add(j)
+
+         return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
+
+     def _get_bbox_from_mask(self, mask):
+         mask = mask.astype(np.uint8)
+         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+         x1, y1, w, h = cv2.boundingRect(contours[0])
+         x2, y2 = x1 + w, y1 + h
+         if len(contours) > 1:
+             for b in contours:
+                 x_t, y_t, w_t, h_t = cv2.boundingRect(b)
+                 # Merge multiple bounding boxes into one.
+                 x1 = min(x1, x_t)
+                 y1 = min(y1, y_t)
+                 x2 = max(x2, x_t + w_t)
+                 y2 = max(y2, y_t + h_t)
+             h = y2 - y1
+             w = x2 - x1
+         return [x1, y1, x2, y2]
+
+     def plot_to_result(self,
+                        annotations,
+                        bboxes=None,
+                        points=None,
+                        point_label=None,
+                        mask_random_color=True,
+                        better_quality=True,
+                        retina=False,
+                        withContours=True) -> np.ndarray:
+         if isinstance(annotations[0], dict):
+             annotations = [annotation['segmentation'] for annotation in annotations]
+         image = self.img
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         original_h = image.shape[0]
+         original_w = image.shape[1]
+         if sys.platform == "darwin":
+             plt.switch_backend("TkAgg")
+         plt.figure(figsize=(original_w / 100, original_h / 100))
+         # Add subplot with no margin.
+         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+         plt.margins(0, 0)
+         plt.gca().xaxis.set_major_locator(plt.NullLocator())
+         plt.gca().yaxis.set_major_locator(plt.NullLocator())
+
+         plt.imshow(image)
+         if better_quality:
+             if isinstance(annotations[0], torch.Tensor):
+                 annotations = np.array(annotations.cpu())
+             for i, mask in enumerate(annotations):
+                 mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+                 annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+         if self.device == 'cpu':
+             annotations = np.array(annotations)
+             self.fast_show_mask(
+                 annotations,
+                 plt.gca(),
+                 random_color=mask_random_color,
+                 bboxes=bboxes,
+                 points=points,
+                 pointlabel=point_label,
+                 retinamask=retina,
+                 target_height=original_h,
+                 target_width=original_w,
+             )
+         else:
+             if isinstance(annotations[0], np.ndarray):
+                 annotations = torch.from_numpy(annotations)
+             self.fast_show_mask_gpu(
+                 annotations,
+                 plt.gca(),
+                 random_color=mask_random_color,
+                 bboxes=bboxes,
+                 points=points,
+                 pointlabel=point_label,
+                 retinamask=retina,
+                 target_height=original_h,
+                 target_width=original_w,
+             )
+         if isinstance(annotations, torch.Tensor):
+             annotations = annotations.cpu().numpy()
+         if withContours:
+             contour_all = []
+             temp = np.zeros((original_h, original_w, 1))
+             for i, mask in enumerate(annotations):
+                 if type(mask) == dict:
+                     mask = mask['segmentation']
+                 annotation = mask.astype(np.uint8)
+                 if not retina:
+                     annotation = cv2.resize(
+                         annotation,
+                         (original_w, original_h),
+                         interpolation=cv2.INTER_NEAREST,
+                     )
+                 contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                 for contour in contours:
+                     contour_all.append(contour)
+             cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+             color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+             contour_mask = temp / 255 * color.reshape(1, 1, -1)
+             plt.imshow(contour_mask)
+
+         plt.axis('off')
+         fig = plt.gcf()
+         plt.draw()
+
+         try:
+             buf = fig.canvas.tostring_rgb()
+         except AttributeError:
+             fig.canvas.draw()
+             buf = fig.canvas.tostring_rgb()
+         cols, rows = fig.canvas.get_width_height()
+         img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
+         result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+         plt.close()
+         return result
+
+     # Remark for refactoring: IMO a function should do one thing only; storing the image and plotting should be separated, and they do not necessarily need to be class methods but standalone utility functions that the user can chain in scripts for more fine-grained control.
+     def plot(self,
+              annotations,
+              output_path,
+              bboxes=None,
+              points=None,
+              point_label=None,
+              mask_random_color=True,
+              better_quality=True,
+              retina=False,
+              withContours=True):
+         if len(annotations) == 0:
+             return None
+         result = self.plot_to_result(
+             annotations,
+             bboxes,
+             points,
+             point_label,
+             mask_random_color,
+             better_quality,
+             retina,
+             withContours,
+         )
+
+         path = os.path.dirname(os.path.abspath(output_path))
+         if not os.path.exists(path):
+             os.makedirs(path)
+         result = result[:, :, ::-1]
+         cv2.imwrite(output_path, result)
+
+     # CPU post process
+     def fast_show_mask(
+         self,
+         annotation,
+         ax,
+         random_color=False,
+         bboxes=None,
+         points=None,
+         pointlabel=None,
+         retinamask=True,
+         target_height=960,
+         target_width=960,
+     ):
+         msak_sum = annotation.shape[0]
+         height = annotation.shape[1]
+         weight = annotation.shape[2]
+         # Sort annotations based on area.
+         areas = np.sum(annotation, axis=(1, 2))
+         sorted_indices = np.argsort(areas)
+         annotation = annotation[sorted_indices]
+
+         index = (annotation != 0).argmax(axis=0)
+         if random_color:
+             color = np.random.random((msak_sum, 1, 1, 3))
+         else:
+             color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
+         transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
+         visual = np.concatenate([color, transparency], axis=-1)
+         mask_image = np.expand_dims(annotation, -1) * visual
+
+         show = np.zeros((height, weight, 4))
+         h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
+         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+         # Use vectorized indexing to update the values of 'show'.
+         show[h_indices, w_indices, :] = mask_image[indices]
+         if bboxes is not None:
+             for bbox in bboxes:
+                 x1, y1, x2, y2 = bbox
+                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+         # draw points
+         if points is not None:
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 s=20,
+                 c='y',
+             )
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 s=20,
+                 c='m',
+             )
+
+         if not retinamask:
+             show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+         ax.imshow(show)
+
+     def fast_show_mask_gpu(
+         self,
+         annotation,
+         ax,
+         random_color=False,
+         bboxes=None,
+         points=None,
+         pointlabel=None,
+         retinamask=True,
+         target_height=960,
+         target_width=960,
+     ):
+         msak_sum = annotation.shape[0]
+         height = annotation.shape[1]
+         weight = annotation.shape[2]
+         areas = torch.sum(annotation, dim=(1, 2))
+         sorted_indices = torch.argsort(areas, descending=False)
+         annotation = annotation[sorted_indices]
+         # Find the index of the first non-zero value at each position.
+         index = (annotation != 0).to(torch.long).argmax(dim=0)
+         if random_color:
+             color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
+         else:
+             color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
+                 30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
+         transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
+         visual = torch.cat([color, transparency], dim=-1)
+         mask_image = torch.unsqueeze(annotation, -1) * visual
+         # Select data according to the index: it indicates which mask's data to use at each pixel, collapsing mask_image into a single image.
+         show = torch.zeros((height, weight, 4)).to(annotation.device)
+         try:
+             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
+         except:
+             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
+         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+         # Use vectorized indexing to update the values of 'show'.
+         show[h_indices, w_indices, :] = mask_image[indices]
+         show_cpu = show.cpu().numpy()
+         if bboxes is not None:
+             for bbox in bboxes:
+                 x1, y1, x2, y2 = bbox
+                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+         # draw points
+         if points is not None:
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
+                 s=20,
+                 c='y',
+             )
+             plt.scatter(
+                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
+                 s=20,
+                 c='m',
+             )
+         if not retinamask:
+             show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+         ax.imshow(show_cpu)
+
+     # CLIP-based text retrieval
+     @torch.no_grad()
+     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+         preprocessed_images = [preprocess(image).to(device) for image in elements]
+         try:
+             import clip  # for linear_assignment
+
+         except (ImportError, AssertionError, AttributeError):
+             from ultralytics.yolo.utils.checks import check_requirements
+
+             check_requirements('git+https://github.com/openai/CLIP.git')  # required before installing lap from source
+             import clip
+
+
+         tokenized_text = clip.tokenize([search_text]).to(device)
+         stacked_images = torch.stack(preprocessed_images)
+         image_features = model.encode_image(stacked_images)
+         text_features = model.encode_text(tokenized_text)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         text_features /= text_features.norm(dim=-1, keepdim=True)
+         probs = 100.0 * image_features @ text_features.T
+         return probs[:, 0].softmax(dim=0)
+
+     def _crop_image(self, format_results):
+
+         image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
+         ori_w, ori_h = image.size
+         annotations = format_results
+         mask_h, mask_w = annotations[0]['segmentation'].shape
+         if ori_w != mask_w or ori_h != mask_h:
+             image = image.resize((mask_w, mask_h))
+         cropped_boxes = []
+         cropped_images = []
+         not_crop = []
+         filter_id = []
+         # annotations, _ = filter_masks(annotations)
+         # filter_id = list(_)
+         for _, mask in enumerate(annotations):
+             if np.sum(mask['segmentation']) <= 100:
+                 filter_id.append(_)
+                 continue
+             bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
+             cropped_boxes.append(self._segment_image(image, bbox))
+             # cropped_boxes.append(segment_image(image,mask["segmentation"]))
+             cropped_images.append(bbox)  # Save the bounding box of the cropped image.
+
+         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+
+     def box_prompt(self, bbox=None, bboxes=None):
+         if self.results is None:
+             return []
+         assert bbox or bboxes
+         if bboxes is None:
+             bboxes = [bbox]
+         max_iou_index = []
+         for bbox in bboxes:
+             assert (bbox[2] != 0 and bbox[3] != 0)
+             masks = self.results[0].masks.data
+             target_height = self.img.shape[0]
+             target_width = self.img.shape[1]
+             h = masks.shape[1]
+             w = masks.shape[2]
+             if h != target_height or w != target_width:
+                 bbox = [
+                     int(bbox[0] * w / target_width),
+                     int(bbox[1] * h / target_height),
+                     int(bbox[2] * w / target_width),
+                     int(bbox[3] * h / target_height), ]
+             bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
+             bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
+             bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
+             bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
+
+             # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+             bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+             masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+             orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+             union = bbox_area + orig_masks_area - masks_area
+             IoUs = masks_area / union
+             max_iou_index.append(int(torch.argmax(IoUs)))
+         max_iou_index = list(set(max_iou_index))
+         return np.array(masks[max_iou_index].cpu().numpy())
+
+     def point_prompt(self, points, pointlabel):  # numpy
+         if self.results is None:
+             return []
+         masks = self._format_results(self.results[0], 0)
+         target_height = self.img.shape[0]
+         target_width = self.img.shape[1]
+         h = masks[0]['segmentation'].shape[0]
+         w = masks[0]['segmentation'].shape[1]
+         if h != target_height or w != target_width:
+             points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+         onemask = np.zeros((h, w))
+         masks = sorted(masks, key=lambda x: x['area'], reverse=True)
+         for i, annotation in enumerate(masks):
+             if type(annotation) == dict:
+                 mask = annotation['segmentation']
+             else:
+                 mask = annotation
+             for i, point in enumerate(points):
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                     onemask[mask] = 1
+                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                     onemask[mask] = 0
+         onemask = onemask >= 1
+         return np.array([onemask])
+
+     def text_prompt(self, text):
+         if self.results is None:
+             return []
+         format_results = self._format_results(self.results[0], 0)
+         cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+         import clip  # local import: the text prompt requires the openai/CLIP package
+         clip_model, preprocess = clip.load('ViT-B/32', device=self.device)
+         scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
+         max_idx = scores.argsort()
+         max_idx = max_idx[-1]
+         max_idx += sum(np.array(filter_id) <= int(max_idx))
+         return np.array([annotations[max_idx]['segmentation']])
+
+     def everything_prompt(self):
+         if self.results is None:
+             return []
+         return self.results[0].masks.data
+
fastsam/utils.py ADDED
@@ -0,0 +1,86 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+     '''Adjust bounding boxes to stick to the image border if they are within a certain threshold.
+     Args:
+         boxes: (n, 4)
+         image_shape: (height, width)
+         threshold: pixel threshold
+     Returns:
+         adjusted_boxes: adjusted bounding boxes
+     '''
+
+     # Image dimensions
+     h, w = image_shape
+
+     # Adjust boxes
+     boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(
+         0, dtype=torch.float, device=boxes.device), boxes[:, 0])  # x1
+     boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(
+         0, dtype=torch.float, device=boxes.device), boxes[:, 1])  # y1
+     boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(
+         w, dtype=torch.float, device=boxes.device), boxes[:, 2])  # x2
+     boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(
+         h, dtype=torch.float, device=boxes.device), boxes[:, 3])  # y2
+
+     return boxes
+
+
+
+ def convert_box_xywh_to_xyxy(box):
+     x1 = box[0]
+     y1 = box[1]
+     x2 = box[0] + box[2]
+     y2 = box[1] + box[3]
+     return [x1, y1, x2, y2]
+
+
+ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
+     '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
+     Args:
+         box1: (4, )
+         boxes: (n, 4)
+     Returns:
+         high_iou_indices: Indices of boxes with IoU > thres
+     '''
+     boxes = adjust_bboxes_to_image_border(boxes, image_shape)
+     # obtain coordinates for intersections
+     x1 = torch.max(box1[0], boxes[:, 0])
+     y1 = torch.max(box1[1], boxes[:, 1])
+     x2 = torch.min(box1[2], boxes[:, 2])
+     y2 = torch.min(box1[3], boxes[:, 3])
+
+     # compute the area of intersection
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+
+     # compute the area of both individual boxes
+     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+     # compute the area of union
+     union = box1_area + box2_area - intersection
+
+     # compute the IoU
+     iou = intersection / union  # Should be shape (n, )
+     if raw_output:
+         if iou.numel() == 0:
+             return 0
+         return iou
+
+     # get indices of boxes with IoU > thres
+     high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
+
+     return high_iou_indices
+
+
+ def image_to_np_ndarray(image):
+     if type(image) is str:
+         return np.array(Image.open(image))
+     elif issubclass(type(image), Image.Image):
+         return np.array(image)
+     elif type(image) is np.ndarray:
+         return image
+     return None
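
Note: a toy sanity check of the box helpers above (the box values are arbitrary and a 640x640 image is assumed):

import torch
from fastsam.utils import convert_box_xywh_to_xyxy, bbox_iou

# xywh -> xyxy: [50, 40, 100, 80] becomes [50, 40, 150, 120].
box_xyxy = torch.tensor(convert_box_xywh_to_xyxy([50, 40, 100, 80]), dtype=torch.float)
candidates = torch.tensor([[60., 50., 140., 110.], [300., 300., 400., 400.]])
# Indices of candidates whose IoU with box_xyxy exceeds 0.5 (here: tensor([0]), since the first IoU is 0.6).
print(bbox_iou(box_xyxy, candidates, iou_thres=0.5, image_shape=(640, 640)))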
image_1.jpg ADDED
image_10.jpg ADDED
image_2.jpg ADDED
image_3.jpg ADDED
image_4.jpg ADDED
image_5.jpg ADDED
image_6.jpg ADDED
image_7.jpg ADDED
image_8.jpg ADDED
image_9.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,83 @@
+ aiofiles==23.2.1
+ altair==5.3.0
+ annotated-types==0.7.0
+ anyio==4.4.0
+ attrs==23.2.0
+ certifi==2024.6.2
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.1
+ cycler==0.12.1
+ dnspython==2.6.1
+ email_validator==2.1.1
+ fastapi==0.111.0
+ fastapi-cli==0.0.4
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.53.0
+ fsspec==2024.2.0
+ gradio==4.36.1
+ gradio_client==1.0.1
+ h11==0.14.0
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.23.4
+ idna==3.7
+ importlib_resources==6.4.0
+ intel-openmp==2021.4.0
+ Jinja2==3.1.3
+ jsonschema==4.22.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.0
+ mdurl==0.1.2
+ mkl==2021.4.0
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.26.3
+ orjson==3.10.5
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.2.0
+ pydantic==2.7.4
+ pydantic_core==2.18.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.1.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.35.1
+ regex==2024.5.15
+ requests==2.32.3
+ rich==13.7.1
+ rpds-py==0.18.1
+ ruff==0.4.9
+ semantic-version==2.10.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.37.2
+ sympy==1.12
+ tbb==2021.11.0
+ tiktoken==0.7.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchvision==0.18.1
+ tqdm==4.66.4
+ typer==0.12.3
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ ujson==5.10.0
+ urllib3==2.2.1
+ uvicorn==0.30.1
+ watchfiles==0.22.0
+ websockets==11.0.3
weights/FastSAM-x.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0be4e7ddbe4c15333d15a859c676d053c486d0a746a3be6a7a9790d52a9b6d7
+ size 144943063