Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,632 Bytes
77f10a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    """Return the two spatial extents of an image tensor as ``(height, width)``.

    A 4-D tensor yields axes 1 and 2 of its shape; a 3-D tensor yields axes
    0 and 1. This assumes the ComfyUI IMAGE layout, i.e. ``[B, H, W, C]``
    batched or ``[H, W, C]`` unbatched -- confirm for other callers.

    Raises:
        ValueError: if the tensor is neither 3-D nor 4-D.
    """
    rank = len(image.shape)
    if rank not in (3, 4):
        raise ValueError("Invalid image tensor shape.")
    # A leading batch axis (rank 4) shifts the spatial axes right by one.
    row_axis = rank - 3
    return image.shape[row_axis], image.shape[row_axis + 1]
def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Raise ``ValueError`` if the image's pixel size is outside the given bounds.

    Every bound is optional and inclusive; passing ``None`` disables that
    check. Checks are evaluated in the order min width, max width,
    min height, max height, and only the first violation is reported.
    """
    height, width = get_image_dimensions(image)

    # Find the first violated bound (if any), then raise once.
    violation = None
    if min_width is not None and width < min_width:
        violation = f"Image width must be at least {min_width}px, got {width}px"
    elif max_width is not None and width > max_width:
        violation = f"Image width must be at most {max_width}px, got {width}px"
    elif min_height is not None and height < min_height:
        violation = f"Image height must be at least {min_height}px, got {height}px"
    elif max_height is not None and height > max_height:
        violation = f"Image height must be at most {max_height}px, got {height}px"

    if violation is not None:
        raise ValueError(violation)
def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    """Raise ``ValueError`` if the image's width/height ratio is out of bounds.

    The aspect ratio is computed as ``width / height``. Both bounds are
    optional and inclusive; ``None`` disables that check.

    Raises:
        ValueError: if the image height is zero (ratio undefined), or if the
            ratio is below ``min_aspect_ratio`` or above ``max_aspect_ratio``.
    """
    # get_image_dimensions returns (height, width): shape[1], shape[2] for a
    # [B, H, W, C] tensor. Unpacking it as (width, height) would invert the
    # ratio, so keep the same order as validate_image_dimensions.
    height, width = get_image_dimensions(image)
    if height == 0:
        # Report this as a validation failure instead of a bare ZeroDivisionError.
        raise ValueError("Image height must be non-zero to compute aspect ratio")
    aspect_ratio = width / height
    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )
def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Raise ``ValueError`` if the video's frame size is outside the given bounds.

    ``video.get_dimensions()`` is unpacked as ``(width, height)``. Every bound
    is optional and inclusive; ``None`` disables that check, and only the
    first violation (min width, max width, min height, max height) is
    reported.

    This is best-effort: if ``get_dimensions()`` raises (or its result cannot
    be unpacked into two values), the error is logged and the function
    returns without validating.
    """
    try:
        width, height = video.get_dimensions()
    except Exception as exc:
        # An unreadable video is logged rather than rejected.
        logging.error("Error getting dimensions of video: %s", exc)
    else:
        violation = None
        if min_width is not None and width < min_width:
            violation = f"Video width must be at least {min_width}px, got {width}px"
        elif max_width is not None and width > max_width:
            violation = f"Video width must be at most {max_width}px, got {width}px"
        elif min_height is not None and height < min_height:
            violation = f"Video height must be at least {min_height}px, got {height}px"
        elif max_height is not None and height > max_height:
            violation = f"Video height must be at most {max_height}px, got {height}px"

        if violation is not None:
            raise ValueError(violation)
def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    """Raise ``ValueError`` if the video's duration (seconds) is out of bounds.

    Both bounds are optional; ``None`` disables that check. Each bound is
    widened by a fixed tolerance of 0.0001 s before comparing, so values a
    hair outside a bound still pass.

    This is best-effort: if ``video.get_duration()`` raises, the error is
    logged and the function returns without validating.
    """
    try:
        duration = video.get_duration()
    except Exception as exc:
        # An unreadable video is logged rather than rejected.
        logging.error("Error getting duration of video: %s", exc)
        return

    # Slack on both bounds; presumably absorbs floating-point imprecision in
    # the reported duration -- confirm intent.
    tolerance = 0.0001

    if min_duration is not None:
        floor = min_duration - tolerance
        if floor > duration:
            raise ValueError(
                f"Video duration must be at least {min_duration}s, got {duration}s"
            )

    if max_duration is not None:
        ceiling = max_duration + tolerance
        if duration > ceiling:
            raise ValueError(
                f"Video duration must be at most {max_duration}s, got {duration}s"
            )
|