|
from fractions import Fraction

import torch
import torch.nn.functional as F
import torchvision
|
|
|
|
|
def remap_image_torch(image):
    """Map an image from ``[-1, 1]`` to ``uint8`` intensities in ``[0, 255]``.

    Values outside ``[-1, 1]`` are clipped. The scaled values are rounded to
    the nearest integer before the cast: a plain ``.to(torch.uint8)`` truncates
    toward zero, so values that land slightly below an integer (common after
    normalizing with ``x / 127.5 - 1``) would otherwise drop one level.

    Args:
        image (torch.Tensor): Image expected to lie in ``[-1, 1]``.

    Returns:
        torch.Tensor: ``uint8`` tensor with the same shape as ``image``.
    """
    image_torch = ((image + 1) / 2.0) * 255.0
    # Clip first so rounding can never push a value outside the uint8 range.
    image_torch = torch.clip(image_torch, 0, 255).round().to(torch.uint8)
    return image_torch
|
|
|
|
|
class CenterCrop(torch.nn.Module):
    """Crop the given image at the center.

    With an explicit ``size`` the crop has exactly that size. Without one, the
    crop is the largest centered region whose width:height proportion matches
    ``ratio``.

    Args:
        size (sequence or int, optional): Desired output size of the crop. If
            ``size`` is an int instead of a sequence like ``(h, w)``, a square
            crop ``(size, size)`` is made. Takes precedence over ``ratio``.
        ratio (str): Desired ``"width:height"`` proportion of the crop, e.g.
            ``"16:9"`` or ``"1.5:1"``. Only used when ``size`` is ``None``.
    """

    def __init__(self, size=None, ratio="1:1"):
        super().__init__()
        self.size = size
        self.ratio = ratio

    @staticmethod
    def _parse_ratio(ratio):
        """Parse a ``"W:H"`` string into an exact, strictly positive ``Fraction``.

        Raises:
            ValueError: if ``ratio`` is not two numbers separated by a single
                colon, or if either number is not strictly positive.
        """
        parts = ratio.split(":")
        if len(parts) != 2:
            raise ValueError(f"ratio must look like 'W:H', got {ratio!r}")
        # Fraction parses "16", "1.5", "1e3", ... exactly (raises ValueError otherwise).
        num, den = (Fraction(part) for part in parts)
        if num <= 0 or den <= 0:
            raise ValueError(f"ratio terms must be positive, got {ratio!r}")
        return num / den

    def _crop_size(self, img):
        """Return the crop size ``(h, w)`` (or ``self.size``) to use for ``img``."""
        if self.size is not None:
            return self.size
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]
        else:
            w, h = img.size  # PIL reports (width, height)
        ratio = self._parse_ratio(self.ratio)
        # Exact integer arithmetic: a float product/quotient can land just
        # below an integer and be truncated one pixel short.
        ratioed_w = h * ratio.numerator // ratio.denominator  # width when keeping full height
        ratioed_h = w * ratio.denominator // ratio.numerator  # height when keeping full width
        if w >= h:
            if ratioed_h <= h:
                return (ratioed_h, w)
            return (h, ratioed_w)
        if ratioed_w <= w:
            return (h, ratioed_w)
        return (ratioed_h, w)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.

        Raises:
            ValueError: if ``size`` is ``None`` and ``ratio`` is malformed.
        """
        return torchvision.transforms.functional.center_crop(img, self._crop_size(img))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, ratio={self.ratio!r})"
|
|