Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
import argparse | |
import binascii | |
import logging | |
import os | |
import os.path as osp | |
import imageio | |
import torch | |
import torchvision | |
# Names exported by `from <module> import *`. The other helpers defined in
# this file (rand_name, masks_like, best_output_size) are not listed here and
# therefore have to be imported by name.
__all__ = ['cache_video', 'cache_image', 'str2bool']
def rand_name(length=8, suffix=''):
    """Return a random hexadecimal name, optionally followed by an extension.

    Args:
        length (int): Number of random bytes to draw; the hexadecimal part of
            the name is twice this many characters long.
        suffix (str): Optional extension. A leading dot is added when missing.

    Returns:
        str: The random hexadecimal string with the normalized suffix appended.
    """
    token = os.urandom(length).hex()
    if not suffix:
        return token
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return token + extension
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Render a video tensor to a file, retrying on failure.

    Every slice of `tensor` along dimension 2 is turned into an image grid
    with `torchvision.utils.make_grid`, and the grids are written as
    consecutive frames with imageio (libx264).

    Args:
        tensor (torch.Tensor): Video batch. Presumably laid out as
            (batch, channels, frames, height, width) -- TODO confirm against
            callers; the code only requires that dimension 2 indexes frames.
        save_file (str | None): Output path. When None, a randomly named file
            with `suffix` is created under '/tmp'.
        fps (int): Frame rate passed to the video writer.
        suffix (str): Extension used only when `save_file` is None.
        nrow (int): Forwarded to `make_grid`.
        normalize (bool): Forwarded to `make_grid`.
        value_range (tuple): Clamp bounds, also forwarded to `make_grid`.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: The path written, or None when every attempt failed.
    """
    # cache file
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    # save to cache
    error = None
    for _ in range(retry):
        try:
            # Preprocess into a separate local: rebinding `tensor` here would
            # make a retry re-process frames that were already converted.
            frames = tensor.clamp(min(value_range), max(value_range))
            frames = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in frames.unbind(2)
            ],
                                 dim=1).permute(1, 2, 3, 0)
            frames = (frames * 255).type(torch.uint8).cpu()

            # The context manager closes the writer even when a frame fails
            # to encode, so no half-open file handle survives into the retry.
            with imageio.get_writer(
                    cache_file, fps=fps, codec='libx264',
                    quality=8) as writer:
                for frame in frames.numpy():
                    writer.append_data(frame)
            return cache_file
        except Exception as e:
            # Best-effort by design: remember the error and try again.
            error = e

    # Reached only when no attempt returned. logging calls do not accept a
    # `flush` keyword (that belongs to print), so it must not be passed here.
    logging.info(f'cache_video failed, error: {error}')
    return None
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor to `save_file`, retrying on failure.

    The tensor is clamped to `value_range` and written with
    `torchvision.utils.save_image`.

    Args:
        tensor (torch.Tensor): Image data accepted by `save_image`.
        save_file (str): Output path.
        nrow (int): Forwarded to `save_image`.
        normalize (bool): Forwarded to `save_image`.
        value_range (tuple): Clamp bounds, also forwarded to `save_image`.
        retry (int): Maximum number of attempts.

    Returns:
        str | None: `save_file` on success, or None when every attempt failed.
    """
    error = None
    for _ in range(retry):
        try:
            # Keep the caller's tensor untouched across attempts.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best-effort by design: remember the error and try again.
            error = e

    # Report the last error instead of failing silently, mirroring
    # cache_video's failure path.
    logging.info(f'cache_image failed, error: {error}')
    return None
def str2bool(v):
    """
    Convert a string to a boolean.

    Matching is case-insensitive and ignores surrounding whitespace.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str | bool): Value to convert. A bool is returned unchanged.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to
            boolean, including values that are neither str nor bool.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Without this guard `.lower()` would raise AttributeError, which is
        # not the exception this function documents.
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')

    v_lower = v.strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
def masks_like(tensor, zero=False, generator=None, p=0.2):
    """Build two lists of all-ones masks shaped like the given tensors.

    With `zero=False` both lists contain plain all-ones tensors. With
    `zero=True`, index 0 along dimension 1 of each mask is modified:

    * `generator is None`: that slice is set to 0 in both masks.
    * `generator` given: for each tensor a uniform number is drawn; when it
      is below `p`, the slice in the first mask is filled with a single value
      ``exp(x)``, ``x ~ Normal(-3.5, 0.5)``, and the slice in the second mask
      is set to 0. Otherwise both masks stay all-ones.

    Args:
        tensor (list[torch.Tensor]): Reference tensors; each mask copies the
            shape, dtype and device of its reference.
        zero (bool): Whether to modify index 0 along dimension 1.
        generator (torch.Generator | None): Random source for the stochastic
            variant.
        p (float): Probability threshold for modifying a mask pair.

    Returns:
        tuple[list[torch.Tensor], list[torch.Tensor]]: The two mask lists.
    """
    assert isinstance(tensor, list), 'tensor must be a list of tensors'
    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]

    if not zero:
        return out1, out2

    for u, v in zip(out1, out2):
        if generator is not None:
            random_num = torch.rand(
                1, generator=generator, device=generator.device).item()
            if random_num >= p:
                # Leave this pair as all-ones.
                continue
            # Sample on the generator's own device and move the result, so a
            # generator living on a different device than the mask still
            # works; for matching devices `.to` is a no-op.
            noise = torch.normal(
                mean=-3.5,
                std=0.5,
                size=(1,),
                device=generator.device,
                generator=generator)
            u[:, 0] = noise.to(u.device).expand_as(u[:, 0]).exp()
        else:
            u[:, 0] = torch.zeros_like(u[:, 0])
        v[:, 0] = torch.zeros_like(v[:, 0])

    return out1, out2
def best_output_size(w, h, dw, dh, expected_area):
    """Pick an output size aligned to (dw, dh) that best keeps the aspect ratio.

    Two candidates are built from the ideal floating-point size whose area is
    `expected_area` and whose aspect ratio is `w / h`: one rounds the width
    down to a multiple of `dw` first and derives the height from the remaining
    area, the other rounds the height down to a multiple of `dh` first and
    derives the width. The candidate whose aspect ratio deviates least from
    `w / h` is returned; on a tie the height-first candidate wins.

    Args:
        w: Source width.
        h: Source height.
        dw: Required width granularity.
        dh: Required height granularity.
        expected_area: Upper bound for the output area (width * height).

    Returns:
        tuple[int, int]: The chosen (width, height).
    """
    target_ratio = w / h

    def _distortion(candidate_ratio):
        # Symmetric relative deviation from the target ratio (>= 1.0).
        return max(target_ratio / candidate_ratio,
                   candidate_ratio / target_ratio)

    # Ideal (unaligned) size with exactly the requested area and ratio.
    ideal_w = (expected_area * target_ratio)**0.5
    ideal_h = expected_area / ideal_w

    # Candidate A: snap the width first, then fit the height into the area.
    width_a = int(ideal_w // dw * dw)
    height_a = int(expected_area / width_a // dh * dh)
    assert (width_a % dw == 0 and height_a % dh == 0 and
            width_a * height_a <= expected_area)
    ratio_a = width_a / height_a

    # Candidate B: snap the height first, then fit the width into the area.
    height_b = int(ideal_h // dh * dh)
    width_b = int(expected_area / height_b // dw * dw)
    assert (height_b % dh == 0 and width_b % dw == 0 and
            width_b * height_b <= expected_area)
    ratio_b = width_b / height_b

    # Keep whichever candidate distorts the aspect ratio less.
    if _distortion(ratio_a) < _distortion(ratio_b):
        return width_a, height_a
    return width_b, height_b