'''
Crop utilities for torch tensors.
Given an image and a bbox (center, bbox_size), return the cropped image and the
tform used to transform keypoints into the crop accordingly.
Only cropping to square images is supported.
'''
import torch
from kornia.geometry.transform.imgwarp import (warp_perspective,
                                               get_perspective_transform,
                                               warp_affine)


def points2bbox(points, points_scale=None):
    '''Compute the center and square side length of the bbox enclosing `points`.
    If `points_scale` is given, points are assumed to be in normalized [-1, 1]
    coordinates and are first mapped back to pixel coordinates.
    '''
    if points_scale:
        assert points_scale[0] == points_scale[1]
        points = points.clone()
        points[:, :, :2] = (points[:, :, :2] * 0.5 + 0.5) * points_scale[0]
    min_coords, _ = torch.min(points, dim=1)
    xmin, ymin = min_coords[:, 0], min_coords[:, 1]
    max_coords, _ = torch.max(points, dim=1)
    xmax, ymax = max_coords[:, 0], max_coords[:, 1]
    center = torch.stack([xmax + xmin, ymax + ymin], dim=-1) * 0.5
    width = (xmax - xmin)
    height = (ymax - ymin)
    # convert the bounding box to a square box
    size = torch.max(width, height).unsqueeze(-1)
    return center, size
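

# Example (a minimal sketch; the shapes and sizes below are illustrative
# assumptions, not taken from any particular dataset):
#   points = torch.rand(8, 68, 2) * 2 - 1              # landmarks in [-1, 1]
#   center, size = points2bbox(points, points_scale=[224, 224])
#   # center: [8, 2] pixel coordinates, size: [8, 1] square side length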


def augment_bbox(center, bbox_size, scale=[1.0, 1.0], trans_scale=0.):
    '''Randomly jitter the bbox center and scale its side length.'''
    batch_size = center.shape[0]
    trans_scale = (torch.rand([batch_size, 2], device=center.device) * 2. -
                   1.) * trans_scale
    center = center + trans_scale * bbox_size
    scale = torch.rand([batch_size, 1], device=center.device) * \
        (scale[1] - scale[0]) + scale[0]
    size = bbox_size * scale
    return center, size
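

# Example (sketch): shift the center by up to 10% of the box size and scale the
# side length by a factor drawn uniformly from [1.0, 1.2]; shapes as above.
#   center, size = augment_bbox(center, size, scale=[1.0, 1.2], trans_scale=0.1)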


def crop_tensor(image,
                center,
                bbox_size,
                crop_size,
                interpolation='bilinear',
                align_corners=False):
    ''' Crop a batch of images.
    Args:
        image (torch.Tensor): the reference tensor of shape [bz, C, H, W].
        center (torch.Tensor): bbox centers, shape [bz, 2].
        bbox_size (torch.Tensor): bbox side lengths, shape [bz, 1].
        crop_size (int): output resolution of the square crop.
        interpolation (str): interpolation flag. Default: 'bilinear'.
        align_corners (bool): mode for grid generation. Default: False. See
            https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.interpolate for details.
    Returns:
        cropped_image (torch.Tensor): [bz, C, crop_size, crop_size].
        tform (torch.Tensor): [bz, 3, 3] transform (transposed), used to map
            points into crop coordinates via right-multiplication.
    '''
    dtype = image.dtype
    device = image.device
    batch_size = image.shape[0]
    # source points: top-left, top-right, bottom-right, bottom-left of the bbox
    src_pts = torch.zeros([4, 2], dtype=dtype,
                          device=device).unsqueeze(0).expand(
                              batch_size, -1, -1).contiguous()
    src_pts[:, 0, :] = center - bbox_size * 0.5
    src_pts[:, 1, 0] = center[:, 0] + bbox_size[:, 0] * 0.5
    src_pts[:, 1, 1] = center[:, 1] - bbox_size[:, 0] * 0.5
    src_pts[:, 2, :] = center + bbox_size * 0.5
    src_pts[:, 3, 0] = center[:, 0] - bbox_size[:, 0] * 0.5
    src_pts[:, 3, 1] = center[:, 1] + bbox_size[:, 0] * 0.5
    # destination points: the corners of the crop_size x crop_size output
    DST_PTS = torch.tensor([[
        [0, 0],
        [crop_size - 1, 0],
        [crop_size - 1, crop_size - 1],
        [0, crop_size - 1],
    ]],
                           dtype=dtype,
                           device=device).expand(batch_size, -1, -1)
    # estimate the transformation between source and destination points
    dst_trans_src = get_perspective_transform(src_pts, DST_PTS)
    # warp images with the affine part of the transform
    cropped_image = warp_affine(image,
                                dst_trans_src[:, :2, :],
                                (crop_size, crop_size),
                                mode=interpolation,
                                align_corners=align_corners)
    # transpose so homogeneous points can be mapped via torch.bmm(points, tform)
    tform = torch.transpose(dst_trans_src, 2, 1)
    return cropped_image, tform
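

# Example (sketch, assuming a float image batch of shape [bz, 3, H, W]):
#   image = torch.rand(8, 3, 512, 512)
#   cropped, tform = crop_tensor(image, center, size, crop_size=224)
#   # cropped: [8, 3, 224, 224]; tform: [8, 3, 3]. For homogeneous points p of
#   # shape [8, N, 3], crop-space coordinates are torch.bmm(p, tform)[:, :, :2].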


class Cropper(object):
    def __init__(self, crop_size, scale=[1, 1], trans_scale=0.):
        self.crop_size = crop_size
        self.scale = scale
        self.trans_scale = trans_scale

    def crop(self, image, points, points_scale=None):
        # points to bbox
        center, bbox_size = points2bbox(points.clone(), points_scale)
        # augment bbox. TODO: add rotation?
        center, bbox_size = augment_bbox(center,
                                         bbox_size,
                                         scale=self.scale,
                                         trans_scale=self.trans_scale)
        # crop
        cropped_image, tform = crop_tensor(image, center, bbox_size,
                                           self.crop_size)
        return cropped_image, tform

    def transform_points(self,
                         points,
                         tform,
                         points_scale=None,
                         normalize=True):
        points_2d = points[:, :, :2]
        # input points must be in the original image range
        if points_scale:
            assert points_scale[0] == points_scale[1]
            points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]
        batch_size, n_points, _ = points.shape
        # append a homogeneous coordinate and apply the (transposed) transform
        trans_points_2d = torch.bmm(
            torch.cat([
                points_2d,
                torch.ones([batch_size, n_points, 1],
                           device=points.device,
                           dtype=points.dtype)
            ],
                      dim=-1), tform)
        trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
                                 dim=-1)
        if normalize:
            # map pixel coordinates in the crop to normalized [-1, 1] coordinates
            trans_points[:, :, :2] = trans_points[:, :, :2] / self.crop_size * 2 - 1
        return trans_points


def transform_points(points, tform, points_scale=None):
    '''Functional version of Cropper.transform_points (without normalization).'''
    points_2d = points[:, :, :2]
    # input points must be in the original image range
    if points_scale:
        assert points_scale[0] == points_scale[1]
        points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]
    batch_size, n_points, _ = points.shape
    trans_points_2d = torch.bmm(
        torch.cat([
            points_2d,
            torch.ones([batch_size, n_points, 1],
                       device=points.device,
                       dtype=points.dtype)
        ],
                  dim=-1), tform)
    trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]],
                             dim=-1)
    return trans_points
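

if __name__ == '__main__':
    # Minimal usage sketch. The sizes and landmark count below are illustrative
    # assumptions, not tied to any particular dataset or model.
    bz, n_points, image_size, crop_size = 2, 68, 512, 224
    image = torch.rand(bz, 3, image_size, image_size)
    points = torch.rand(bz, n_points, 2) * 2 - 1  # landmarks in [-1, 1]

    cropper = Cropper(crop_size=crop_size, scale=[1.2, 1.2], trans_scale=0.)
    cropped_image, tform = cropper.crop(image,
                                        points,
                                        points_scale=[image_size, image_size])
    # map the original landmarks into the crop; normalize=True returns [-1, 1] coords
    cropped_points = cropper.transform_points(points,
                                              tform,
                                              points_scale=[image_size, image_size],
                                              normalize=True)
    print(cropped_image.shape)   # torch.Size([2, 3, 224, 224])
    print(cropped_points.shape)  # torch.Size([2, 68, 2])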