|
|
|
import argparse
|
|
import binascii
|
|
import os
|
|
import os.path as osp
|
|
import torchvision.transforms.functional as TF
|
|
import torch.nn.functional as F
|
|
|
|
import imageio
|
|
import torch
|
|
import decord
|
|
import torchvision
|
|
from PIL import Image
|
|
import numpy as np
|
|
from rembg import remove, new_session
|
|
import random
|
|
|
|
# Names exported by a star-import of this module; every other helper defined
# in this file has to be imported explicitly by name.
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
def seed_everything(seed: int):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): Value handed to each generator's seeding function.
    """
    # Generators that are always present are seeded unconditionally.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Accelerator generators are only touched when the backend reports
    # itself as available.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed(seed)

    mps_ready = torch.backends.mps.is_available()
    if mps_ready:
        torch.mps.manual_seed(seed)
|
|
|
|
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame):
    """
    Select the source-frame indices needed to play a video at ``target_fps``.

    Args:
        video_fps: Frame rate of the source video. When it is lower than
            ``target_fps`` it is replaced by ``target_fps``, so source frames
            are then taken one after the other.
        video_frames_count: Number of frames in the source video; indices
            greater than or equal to it are never returned.
        max_target_frames_count: Maximum number of indices to return.
            ``0`` means "no limit": indices are produced until the end of the
            video is reached.
        target_fps: Frame rate of the output sequence.
        start_target_frame: Index, counted in target frames, at which
            sampling starts.

    Returns:
        list[int]: Source frame indices in increasing order.
    """
    import math

    if video_fps < target_fps:
        video_fps = target_fps

    video_frame_duration = 1 / video_fps
    target_frame_duration = 1 / target_fps

    # Start on the first source frame that is not earlier than the requested
    # start time.
    target_time = start_target_frame * target_frame_duration
    frame_no = math.ceil(target_time / video_frame_duration)
    cur_time = frame_no * video_frame_duration

    # 0 is the "unlimited" sentinel. The previous implementation ended with
    # ``frame_ids[:max_target_frames_count]``, which turned that sentinel into
    # an always-empty result; the loop condition below already enforces any
    # non-zero limit, so no final slice is needed.
    unlimited = max_target_frames_count == 0
    frame_ids = []
    while unlimited or len(frame_ids) < max_target_frames_count:
        # Rounding to 5 decimals absorbs floating-point noise so that a value
        # such as 2.0000000001 is not pushed up to 3 by ceil().
        diff = round((target_time - cur_time) / video_frame_duration, 5)
        add_frames_count = math.ceil(diff)
        frame_no += add_frames_count
        if frame_no >= video_frames_count:
            break
        frame_ids.append(frame_no)
        cur_time += add_frames_count * video_frame_duration
        target_time += target_frame_duration
    return frame_ids
|
|
|
|
def get_video_frame(file_name, frame_no):
    """
    Decode a single frame of a video file and return it as a PIL image.

    Args:
        file_name: Path of the video, passed to ``decord.VideoReader``.
        frame_no: Index of the frame to fetch.

    Returns:
        PIL.Image.Image: The frame, built from its uint8 pixel array.
    """
    # Ask decord to hand decoded frames back as torch tensors.
    decord.bridge.set_bridge('torch')
    video = decord.VideoReader(file_name)

    # A batch of one frame is requested; drop the leading batch axis.
    batch = video.get_batch([frame_no])
    pixels = batch.squeeze(0).numpy().astype(np.uint8)
    return Image.fromarray(pixels)
|
|
|
|
def resize_lanczos(img, h, w):
    """
    Resize an image tensor to ``h`` x ``w`` using PIL's Lanczos filter.

    Args:
        img: Image tensor whose first axis is moved to the end before the
            conversion to PIL (assumes channels-first floats in [0, 1] --
            TODO confirm against callers).
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        torch.Tensor: float32 tensor scaled by 1/255, with the last axis of
        the resized image moved back to the front.
    """
    # Scale to 0..255, clip and quantise to uint8 so PIL can take the array.
    as_array = img.movedim(0, -1).cpu().numpy()
    as_uint8 = np.clip(255. * as_array, 0, 255).astype(np.uint8)

    # PIL expects the size as (width, height).
    resized = Image.fromarray(as_uint8).resize((w, h), resample=Image.Resampling.LANCZOS)

    rescaled = np.array(resized).astype(np.float32) / 255.0
    return torch.from_numpy(rescaled).movedim(-1, 0)
|
|
|
|
|
|
def remove_background(img, session=None):
    """
    Run rembg background removal on an image tensor.

    Args:
        img: Image tensor whose first axis is moved to the end before the
            conversion to PIL (assumes channels-first floats in [0, 1] --
            TODO confirm against callers).
        session: Optional rembg session to reuse. A new one is created with
            ``new_session()`` when omitted.

    Returns:
        torch.Tensor: float32 tensor scaled by 1/255, built from the RGB
        conversion of rembg's output, with the last axis moved to the front.
    """
    # Identity check: ``== None`` would go through the session object's
    # ``__eq__`` instead of testing for the missing-argument sentinel.
    if session is None:
        session = new_session()

    # Scale to 0..255, clip and quantise to uint8 so PIL can take the array.
    pil_img = Image.fromarray(
        np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    pil_img = remove(
        pil_img,
        session=session,
        alpha_matting=True,
        bgcolor=[255, 255, 255, 0],
    ).convert('RGB')
    return torch.from_numpy(np.array(pil_img).astype(np.float32) / 255.0).movedim(-1, 0)
|
|
|
|
|
|
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size=16):
    """
    Scale ``height`` x ``width`` towards a canvas and snap to a block grid.

    Args:
        canvas_height: Height of the reference canvas.
        canvas_width: Width of the reference canvas.
        height: Original height.
        width: Original width.
        fit_into_canvas: When true, use the largest scale that fits inside
            the canvas, trying the canvas both as given and with its sides
            swapped. Otherwise use the scale that gives the result the same
            area as the canvas.
        block_size: Both returned dimensions are multiples of this value.

    Returns:
        tuple: ``(new_height, new_width)``.
    """
    if not fit_into_canvas:
        # Same pixel count as the canvas, original aspect ratio.
        scale = (canvas_height * canvas_width / (height * width)) ** (1 / 2)
    else:
        as_given = min(canvas_height / height, canvas_width / width)
        swapped = min(canvas_width / height, canvas_height / width)
        scale = max(as_given, swapped)

    def snap(length):
        # Round to the nearest whole number of blocks.
        return round(length * scale / block_size) * block_size

    return snap(height), snap(width)
|
|
|
|
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas=False):
    """
    Resize a list of PIL images and optionally strip their background.

    Args:
        img_list: Iterable of PIL images.
        budget_width: Canvas width (``fit_into_canvas``) or width of the
            pixel budget.
        budget_height: Canvas height (``fit_into_canvas``) or height of the
            pixel budget.
        rm_background: ``1`` removes the background of every image, ``2`` of
            every image except the first one; any other value leaves the
            images untouched.
        fit_into_canvas: When true, each image is scaled to fit inside a
            white ``budget_height`` x ``budget_width`` canvas and centred on
            it. Otherwise each image is scaled to the same area as the budget
            and its sides are rounded to multiples of 16.

    Returns:
        list: The processed PIL images, in input order.
    """
    # The rembg session is created on first use so that no model is loaded
    # when no image ends up needing background removal.
    session = None

    output_list = []
    for i, img in enumerate(img_list):
        width, height = img.size

        if fit_into_canvas:
            scale = min(budget_height / height, budget_width / width)
            # Clamp to one pixel: an extreme aspect ratio would otherwise
            # truncate a side to 0, which PIL refuses to resize to.
            new_height = max(1, int(height * scale))
            new_width = max(1, int(width * scale))
            resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

            top = (budget_height - new_height) // 2
            left = (budget_width - new_width) // 2
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
            # The canvas has three channels; converting to RGB keeps the
            # paste from failing on greyscale, palette or RGBA inputs.
            white_canvas[top:top + new_height, left:left + new_width] = np.array(
                resized_image.convert('RGB'))
            resized_image = Image.fromarray(white_canvas)
        else:
            scale = (budget_height * budget_width / (height * width)) ** (1 / 2)
            new_height = int(round(height * scale / 16) * 16)
            new_width = int(round(width * scale / 16) * 16)
            resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

        # Parentheses only spell out the existing precedence of ``and``
        # over ``or``.
        if rm_background == 1 or (rm_background == 2 and i > 0):
            if session is None:
                session = new_session()
            resized_image = remove(
                resized_image,
                session=session,
                alpha_matting_erode_size=1,
                alpha_matting=True,
                bgcolor=[255, 255, 255, 0],
            ).convert('RGB')
        output_list.append(resized_image)
    return output_list
|
|
|
|
|
|
def rand_name(length=8, suffix=''):
    """
    Build a random lowercase hexadecimal name, optionally with an extension.

    Args:
        length (int): Number of random bytes; the hex part is twice as long.
        suffix (str): Optional extension. A leading dot is added when missing.

    Returns:
        str: The hex string followed by the (dotted) suffix, if any.
    """
    hex_part = os.urandom(length).hex()
    if not suffix:
        return hex_part
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return hex_part + extension
|
|
|
|
|
|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """
    Encode a video tensor to an H.264 file.

    Args:
        tensor: Video tensor. It is split along dimension 2 and each slice is
            tiled with ``torchvision.utils.make_grid`` (presumably shaped
            (batch, channels, frames, height, width) -- TODO confirm).
        save_file: Output path. When ``None`` a random name with ``suffix``
            is created under ``/tmp``.
        fps: Frame rate passed to the writer.
        suffix: Extension used for the generated file name.
        nrow: Number of images per grid row, forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: ``(low, high)`` used both for clamping and for
            ``make_grid`` normalisation.
        retry: Number of attempts before giving up.

    Returns:
        The path of the written file, or ``None`` when every attempt failed
        (the last error is printed).
    """
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    error = None
    # Converted frames are kept separately from ``tensor``. Rebinding
    # ``tensor`` inside the loop made a retry after a failed write run the
    # preprocessing again on the already converted uint8 frames.
    frames = None
    for _ in range(retry):
        try:
            if frames is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = torch.stack([
                    torchvision.utils.make_grid(
                        u, nrow=nrow, normalize=normalize, value_range=value_range)
                    for u in clamped.unbind(2)
                ],
                                    dim=1).permute(1, 2, 3, 0)
                frames = (grids * 255).type(torch.uint8).cpu().numpy()

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Release the encoder even when a frame fails to be written.
                writer.close()
            return cache_file
        except Exception as e:
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
|
|
|
|
|
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """
    Save an image tensor to ``save_file`` as a grid image.

    Args:
        tensor: Image tensor handed to ``torchvision.utils.save_image`` after
            being clamped to ``value_range``.
        save_file: Output path; the image format follows its extension.
        nrow: Number of images per grid row.
        normalize: Forwarded to ``save_image``.
        value_range: ``(low, high)`` used both for clamping and for
            normalisation.
        retry: Number of attempts before giving up.

    Returns:
        ``save_file`` on success, or ``None`` when every attempt failed (the
        last error is printed).
    """
    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e

    # Report the failure instead of silently falling off the end of the
    # function.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
|
|
|
|
|
|
def str2bool(v):
    """
    Interpret a command-line style value as a boolean.

    Accepted spellings are matched case-insensitively:
    true values are 'yes', 'true', 't', 'y', '1';
    false values are 'no', 'false', 'f', 'n', '0'.

    Args:
        v (str | bool): Value to interpret. A bool is returned unchanged.

    Returns:
        bool: The interpreted value.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v

    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    normalized = v.lower()
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|
|