import logging
from typing import Optional

import torch

from comfy_api.input.video_types import VideoInput


def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    """Return the ``(height, width)`` of an image tensor.

    A 4-D tensor yields ``(shape[1], shape[2])`` and a 3-D tensor yields
    ``(shape[0], shape[1])``.

    NOTE(review): this assumes a channels-last layout, i.e. ``[B, H, W, C]``
    or ``[H, W, C]``, which matches how ``validate_image_dimensions`` unpacks
    the result -- confirm against callers.

    Raises:
        ValueError: If the tensor is neither 3-D nor 4-D.
    """
    rank = len(image.shape)
    if rank == 4:
        return image.shape[1], image.shape[2]
    if rank == 3:
        return image.shape[0], image.shape[1]
    raise ValueError("Invalid image tensor shape.")


def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Check an image's width and height against optional inclusive bounds.

    Any bound left as ``None`` is not checked.

    Raises:
        ValueError: If a dimension falls outside a supplied bound, or if the
            tensor shape is rejected by ``get_image_dimensions``.
    """
    height, width = get_image_dimensions(image)

    if min_width is not None and width < min_width:
        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Image height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")


def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    """Check an image's aspect ratio (width / height) against optional bounds.

    Any bound left as ``None`` is not checked.

    Raises:
        ValueError: If the ratio falls outside a supplied bound, if the image
            height is zero, or if the tensor shape is rejected by
            ``get_image_dimensions``.
    """
    # get_image_dimensions returns (height, width); unpacking it in the
    # opposite order would silently invert the ratio.
    height, width = get_image_dimensions(image)
    if height == 0:
        # Report a degenerate image as a validation error rather than letting
        # a bare ZeroDivisionError escape.
        raise ValueError(
            f"Image height must be greater than zero to compute an aspect ratio, got {height}px"
        )
    aspect_ratio = width / height

    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )


def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    """Check a video's width and height against optional inclusive bounds.

    Any bound left as ``None`` is not checked. Validation is best-effort: if
    ``video.get_dimensions()`` raises, the error is logged and the function
    returns without validating.

    Raises:
        ValueError: If a dimension falls outside a supplied bound.
    """
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        # Deliberate best-effort: unreadable metadata skips validation
        # instead of failing the caller.
        logging.error("Error getting dimensions of video: %s", e)
        return

    if min_width is not None and width < min_width:
        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Video height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")


def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    """Check a video's duration against optional bounds.

    Any bound left as ``None`` is not checked. Validation is best-effort: if
    ``video.get_duration()`` raises, the error is logged and the function
    returns without validating.

    Raises:
        ValueError: If the duration falls outside a supplied bound by more
            than the comparison tolerance.
    """
    try:
        duration = video.get_duration()
    except Exception as e:
        # Deliberate best-effort: unreadable metadata skips validation
        # instead of failing the caller.
        logging.error("Error getting duration of video: %s", e)
        return

    # Small tolerance so a duration that differs from a bound only by
    # floating-point noise is not rejected.
    epsilon = 0.0001

    if min_duration is not None and min_duration - epsilon > duration:
        raise ValueError(
            f"Video duration must be at least {min_duration}s, got {duration}s"
        )
    if max_duration is not None and duration > max_duration + epsilon:
        raise ValueError(
            f"Video duration must be at most {max_duration}s, got {duration}s"
        )