|
import numpy as np |
|
import PIL |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
def pil_to_tensor(
    img: Image.Image,
    target_image_size=512,
    lock_ratio=True,
    center_crop=True,
    padding=False,
    standardize=True,
    **kwarg
) -> torch.Tensor:
    """Convert a PIL image into a float32 channel-first tensor.

    The image is converted to RGB if needed, resized according to the
    options below, scaled to ``[0, 1]`` and optionally remapped to
    ``[-1, 1]``.

    Args:
        img: Source image.
        target_image_size: A non-negative int for a square target, a
            negative int to use the image's own size as the target, or a
            ``(width, height)`` pair.
        lock_ratio: Keep the aspect ratio while resizing. When False the
            image is resized straight to the target size.
        center_crop: Only used with ``lock_ratio``. Scale the image so it
            covers the target, then crop the centre. Takes precedence over
            ``padding``.
        padding: Only used with ``lock_ratio`` when ``center_crop`` is
            False. Scale the image so it fits inside the target, then
            centre it on a black canvas of the target size.
        standardize: Remap values from ``[0, 1]`` to ``[-1, 1]``.
        **kwarg: Accepted and ignored.

    Returns:
        A float32 tensor of shape ``(3, height, width)``. With
        ``lock_ratio=True`` and both ``center_crop`` and ``padding`` False
        the image keeps its own size.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    if isinstance(target_image_size, int):
        if target_image_size < 0:
            # A negative size means "use the image's own size as the target".
            target_size = img.size
        else:
            target_size = (target_image_size, target_image_size)
    else:
        target_size = target_image_size

    if lock_ratio:
        original_width, original_height = img.size
        target_width, target_height = target_size

        scale_w = target_width / original_width
        scale_h = target_height / original_height

        if center_crop:
            # Cover the target: the larger factor leaves no empty border.
            scale = max(scale_w, scale_h)
        elif padding:
            # Fit inside the target: the smaller factor avoids overflow.
            scale = min(scale_w, scale_h)
        else:
            scale = 1.0

        new_width = round(original_width * scale)
        new_height = round(original_height * scale)
        if scale > 0:
            # Extreme aspect ratios can round the short side down to 0,
            # which PIL refuses to resize to; keep at least one pixel.
            new_width = max(1, new_width)
            new_height = max(1, new_height)
        img = img.resize((new_width, new_height), Image.LANCZOS)

        if center_crop:
            left = (img.width - target_width) // 2
            top = (img.height - target_height) // 2
            img = img.crop((left, top, left + target_width, top + target_height))
        elif padding:
            new_img = Image.new("RGB", target_size, (0, 0, 0))
            left = (target_width - img.width) // 2
            top = (target_height - img.height) // 2
            new_img.paste(img, (left, top))
            img = new_img
    else:
        img = img.resize(target_size, Image.LANCZOS)

    np_img = np.array(img) / 255.0
    if standardize:
        np_img = np_img * 2 - 1

    # HWC -> CHW, then cast to float32.
    return torch.from_numpy(np_img).permute(2, 0, 1).float()
|
|
|
|
|
def tensor_to_pil(chw_tensor: torch.Tensor, standardize=True, **kwarg) -> Image.Image:
    """Convert a channel-first image tensor into an RGB PIL image.

    Args:
        chw_tensor: Tensor laid out as ``(channels, height, width)``.
            Values outside the expected range are clamped.
        standardize: When True the tensor is treated as ``[-1, 1]`` data
            and remapped to ``[0, 1]``; otherwise it is treated as
            ``[0, 1]`` data.
        **kwarg: Accepted and ignored.

    Returns:
        The image in RGB mode.
    """
    detached_chw_tensor = chw_tensor.detach().cpu()
    if detached_chw_tensor.dtype in (torch.float16, torch.bfloat16):
        # NumPy has no bfloat16, and float16 loses precision near 255.
        detached_chw_tensor = detached_chw_tensor.float()

    if standardize:
        normalized_chw_tensor = (
            torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
        ) / 2.0
    else:
        normalized_chw_tensor = torch.clamp(detached_chw_tensor, 0.0, 1.0)

    hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()

    # Round to nearest instead of truncating: float error can leave a value
    # such as 254.99999, which truncation would drop a full intensity level.
    image_array_uint8 = np.rint(hwc_array * 255).astype(np.uint8)

    if image_array_uint8.shape[2] == 1:
        # Image.fromarray rejects (H, W, 1); pass single-channel data as 2-D.
        image_array_uint8 = image_array_uint8[:, :, 0]

    pil_image = Image.fromarray(image_array_uint8)

    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    return pil_image
|
|