Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import numpy as np | |
| import torch | |
| # from shapely import affinity | |
| # from shapely.geometry import box | |
# Transpose methods understood by ``BoxList.transpose``. The values mirror
# ``PIL.Image.FLIP_LEFT_RIGHT`` and ``PIL.Image.FLIP_TOP_BOTTOM``.
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as a Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each bounding box, such as
    labels.

    Coordinates follow an inclusive-pixel convention: a box spanning
    ``xmin..xmax`` has width ``xmax - xmin + 1`` (see ``_TO_REMOVE``).
    """

    # Offset between inclusive pixel coordinates and widths/heights.
    _TO_REMOVE = 1

    def __init__(self, bbox, image_size, mode="xyxy", use_char_ann=True, is_fake=False):
        """
        :param bbox: Nx4 tensor (or nested sequence) of box coordinates.
        :param image_size: ``(image_width, image_height)`` of the image.
        :param mode: ``"xyxy"`` or ``"xywh"``.
        :param use_char_ann: when False, ``rotate`` and ``crop`` leave the
            ``"char_masks"`` field untouched.
        :param is_fake: accepted for interface compatibility; not used here.
        :raises ValueError: if ``bbox`` is not Nx4 or ``mode`` is unknown.
        """
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
            )
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimenion of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1))
            )
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}
        self.use_char_ann = use_char_ann

    def _carry_fields(self, target, transform, skip_unused_char_masks=False):
        """
        Copy every extra field onto ``target``.

        Tensor fields are copied as-is; every other field is passed through
        ``transform`` first. With ``skip_unused_char_masks`` set, the
        ``"char_masks"`` field is also copied as-is when ``use_char_ann`` is
        False. Returns ``target``.
        """
        for name, value in self.extra_fields.items():
            untouched = isinstance(value, torch.Tensor) or (
                skip_unused_char_masks
                and not self.use_char_ann
                and name == "char_masks"
            )
            target.add_field(name, value if untouched else transform(value))
        return target

    def set_size(self, size):
        """
        Set the image size on this object and return a new BoxList with the
        same boxes, whose non-tensor fields were updated via ``set_size``.
        """
        self.size = size
        bbox = BoxList(
            self.bbox, size, mode=self.mode, use_char_ann=self.use_char_ann
        )
        self._carry_fields(bbox, lambda field: field.set_size(size))
        return bbox.convert(self.mode)

    def add_field(self, field, field_data):
        """Attach (or overwrite) the extra field ``field``."""
        self.extra_fields[field] = field_data

    def get_field(self, field):
        """Return the extra field ``field``; raises KeyError if missing."""
        return self.extra_fields[field]

    def has_field(self, field):
        """Return True if the extra field ``field`` is present."""
        return field in self.extra_fields

    def fields(self):
        """Return the names of all extra fields as a list."""
        return list(self.extra_fields)

    def _copy_extra_fields(self, bbox):
        """Copy (by reference) all extra fields of ``bbox`` onto this object."""
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def convert(self, mode):
        """
        Return the boxes expressed in ``mode`` ("xyxy" or "xywh").

        Returns ``self`` when the mode already matches; otherwise a new
        BoxList sharing the same extra fields.
        """
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            columns = (xmin, ymin, xmax, ymax)
        else:
            columns = (
                xmin,
                ymin,
                xmax - xmin + self._TO_REMOVE,
                ymax - ymin + self._TO_REMOVE,
            )
        bbox = BoxList(
            torch.cat(columns, dim=-1),
            self.size,
            mode=mode,
            use_char_ann=self.use_char_ann,
        )
        bbox._copy_extra_fields(self)
        return bbox

    def _split_into_xyxy(self):
        """Return ``(xmin, ymin, xmax, ymax)`` as four Nx1 tensors."""
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - self._TO_REMOVE).clamp(min=0),
                ymin + (h - self._TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).

        Extra positional/keyword arguments are forwarded to the ``resize``
        method of every non-tensor extra field.
        """

        def resize_field(field):
            return field.resize(size, *args, **kwargs)

        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            # Uniform scaling: all four columns scale alike in either mode.
            bbox = BoxList(
                self.bbox * ratios[0],
                size,
                mode=self.mode,
                use_char_ann=self.use_char_ann,
            )
            return self._carry_fields(bbox, resize_field)

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_box = torch.cat(
            (
                xmin * ratio_width,
                ymin * ratio_height,
                xmax * ratio_width,
                ymax * ratio_height,
            ),
            dim=-1,
        )
        bbox = BoxList(scaled_box, size, mode="xyxy", use_char_ann=self.use_char_ann)
        self._carry_fields(bbox, resize_field)
        return bbox.convert(self.mode)

    def poly2box(self, poly):
        """
        Return the axis-aligned box ``[xmin, ymin, xmax, ymax]`` enclosing a
        flat polygon given as ``[x0, y0, x1, y1, ...]``.
        """
        xs = poly[0::2]
        ys = poly[1::2]
        return [min(xs), min(ys), max(xs), max(ys)]

    def rotate(self, angle, r_c, start_h, start_w):
        """
        Rotate the ``"masks"`` field and rebuild the boxes from the rotated
        polygons.

        All arguments are forwarded unchanged to the ``rotate`` method of the
        non-tensor extra fields. ``r_c`` is also used to derive the new image
        size as ``(r_c[0] * 2, r_c[1] * 2)`` (presumably the rotation centre;
        confirm against the caller).

        NOTE(review): this also overwrites ``self.size`` on the source object.
        :raises KeyError: if there is no ``"masks"`` field.
        """
        masks = self.extra_fields["masks"].rotate(angle, r_c, start_h, start_w)
        boxes = [self.poly2box(poly.polygons[0].numpy()) for poly in masks.polygons]
        # An empty list would become a 1-D tensor and be rejected by
        # __init__; force the Nx4 shape so zero instances are supported.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(len(boxes), 4)
        self.size = (r_c[0] * 2, r_c[1] * 2)
        bbox = BoxList(boxes, self.size, mode="xyxy", use_char_ann=self.use_char_ann)
        for k, v in self.extra_fields.items():
            if k == "masks":
                v = masks  # already rotated above
            elif not isinstance(v, torch.Tensor) and (
                self.use_char_ann or k != "char_masks"
            ):
                v = v.rotate(angle, r_c, start_h, start_w)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)

        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
          :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
          :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
          :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        :raises NotImplementedError: for anything but the two flips.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            transposed_xmin = image_width - xmax - self._TO_REMOVE
            transposed_xmax = image_width - xmin - self._TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        else:
            # Same inclusive-pixel offset as the horizontal flip; without it
            # a box touching the top row ends up with ymax == image_height.
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax - self._TO_REMOVE
            transposed_ymax = image_height - ymin - self._TO_REMOVE

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        bbox = BoxList(
            transposed_boxes, self.size, mode="xyxy", use_char_ann=self.use_char_ann
        )
        self._carry_fields(bbox, lambda field: field.transpose(method))
        return bbox.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.

        Boxes that become degenerate (zero width or height) are dropped.
        Non-tensor extra fields are cropped via their own
        ``crop(box, keep_ind)``, where ``keep_ind`` is a NumPy array of the
        surviving row indices, or ``None`` when no box survives.

        NOTE(review): tensor fields are carried over unfiltered, so their
        length can exceed the number of surviving boxes; confirm callers
        expect this.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        survives = (
            (cropped_xmin != cropped_xmax) & (cropped_ymin != cropped_ymax)
        ).squeeze(-1)
        # NumPy cannot read non-CPU tensors, so move the mask to host first.
        not_empty = np.where(survives.cpu().numpy())[0]
        keep_ind = not_empty if len(not_empty) > 0 else None

        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )[survives]
        bbox = BoxList(cropped_box, (w, h), mode="xyxy", use_char_ann=self.use_char_ann)
        self._carry_fields(
            bbox,
            lambda field: field.crop(box, keep_ind),
            skip_unused_char_masks=True,
        )
        return bbox.convert(self.mode)

    # Tensor-like methods

    def to(self, device):
        """Return a copy with boxes and every field that has ``.to`` moved to ``device``."""
        bbox = BoxList(self.bbox.to(device), self.size, self.mode, self.use_char_ann)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        """Return a BoxList with the boxes and every extra field indexed by ``item``."""
        bbox = BoxList(self.bbox[item], self.size, self.mode, self.use_char_ann)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        """Return the number of boxes."""
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        """
        Clamp the stored coordinates in place to the image bounds.

        Columns 0/2 are clamped to ``[0, width - 1]`` and columns 1/3 to
        ``[0, height - 1]``. With ``remove_empty`` a new BoxList without
        zero-width/height boxes is returned, otherwise ``self``.
        """
        max_x = self.size[0] - self._TO_REMOVE
        max_y = self.size[1] - self._TO_REMOVE
        self.bbox[:, 0].clamp_(min=0, max=max_x)
        self.bbox[:, 1].clamp_(min=0, max=max_y)
        self.bbox[:, 2].clamp_(min=0, max=max_x)
        self.bbox[:, 3].clamp_(min=0, max=max_y)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self

    def area(self):
        """Return the per-box area, computed as ``(c2 - c0 + 1) * (c3 - c1 + 1)`` on the stored columns."""
        box = self.bbox
        width = box[:, 2] - box[:, 0] + self._TO_REMOVE
        height = box[:, 3] - box[:, 1] + self._TO_REMOVE
        return width * height

    def copy_with_fields(self, fields):
        """Return a BoxList sharing the boxes but carrying only ``fields`` (a name or list/tuple of names)."""
        bbox = BoxList(self.bbox, self.size, self.mode, self.use_char_ann)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            bbox.add_field(field, self.get_field(field))
        return bbox

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s
if __name__ == "__main__":
    # Small smoke demo: build two boxes on a 10x10 image, then show the
    # result of halving the image size and of a left-right flip (method 0).
    demo_boxes = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
    demo_steps = (
        lambda boxes: boxes.resize((5, 5)),
        lambda boxes: boxes.transpose(0),
    )
    # Each step is evaluated right before its output is printed, so the
    # order of computation and printing matches a straight-line script.
    for step in demo_steps:
        outcome = step(demo_boxes)
        print(outcome)
        print(outcome.bbox)