from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union, Dict, TypedDict
import numpy as np
from modules import shared
from lib_controlnet.logging import logger
from lib_controlnet.enums import InputMode, HiResFixOption
from modules.api import api
def get_api_version() -> int:
return 2
class ControlMode(Enum):
"""
    The improved guess mode: controls how generation is balanced between the
    text prompt and the ControlNet guidance.
"""
BALANCED = "Balanced"
PROMPT = "My prompt is more important"
CONTROL = "ControlNet is more important"
class BatchOption(Enum):
DEFAULT = "All ControlNet units for all images in a batch"
SEPARATE = "Each ControlNet unit for each image in a batch"
class ResizeMode(Enum):
"""
Resize modes for ControlNet input images.
"""
RESIZE = "Just Resize"
INNER_FIT = "Crop and Resize"
OUTER_FIT = "Resize and Fill"
def int_value(self):
if self == ResizeMode.RESIZE:
return 0
elif self == ResizeMode.INNER_FIT:
return 1
elif self == ResizeMode.OUTER_FIT:
return 2
assert False, "NOTREACHED"
resize_mode_aliases = {
'Inner Fit (Scale to Fit)': 'Crop and Resize',
'Outer Fit (Shrink to Fit)': 'Resize and Fill',
'Scale to Fit (Inner Fit)': 'Crop and Resize',
'Envelope (Outer Fit)': 'Resize and Fill',
}
def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode:
if isinstance(value, str):
return ResizeMode(resize_mode_aliases.get(value, value))
elif isinstance(value, int):
assert value >= 0
if value == 3: # 'Just Resize (Latent upscale)'
return ResizeMode.RESIZE
if value >= len(ResizeMode):
logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.')
return ResizeMode.RESIZE
        return list(ResizeMode)[value]
else:
return value
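# Illustrative sketch (kept as comments so nothing runs at import time): the
# accepted value forms are legacy UI label aliases, legacy int indices, or the
# enum itself.
#
#   resize_mode_from_value('Inner Fit (Scale to Fit)')  # -> ResizeMode.INNER_FIT
#   resize_mode_from_value(3)                           # -> ResizeMode.RESIZE ('Just Resize (Latent upscale)')
#   resize_mode_from_value(ResizeMode.OUTER_FIT)        # -> ResizeMode.OUTER_FIT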
def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode:
if isinstance(value, str):
return ControlMode(value)
elif isinstance(value, int):
        return list(ControlMode)[value]
else:
return value
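# Same idea for ControlMode (comments only, not executed):
#
#   control_mode_from_value("Balanced")  # -> ControlMode.BALANCED
#   control_mode_from_value(1)           # -> ControlMode.PROMPT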
def visualize_inpaint_mask(img):
if img.ndim == 3 and img.shape[2] == 4:
result = img.copy()
mask = result[:, :, 3]
mask = 255 - mask // 2
result[:, :, 3] = mask
        return np.ascontiguousarray(result)
return img
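# Illustrative sketch (comments only): the alpha channel appears to double as
# the inpaint mask, so fully masked pixels (alpha=255) are rendered at roughly
# half opacity while unmasked pixels (alpha=0) become fully opaque. The 4x4
# shape below is an arbitrary placeholder.
#
#   rgba = np.zeros((4, 4, 4), dtype=np.uint8)
#   rgba[..., 3] = 255                    # mark every pixel as masked
#   visualize_inpaint_mask(rgba)[..., 3]  # -> all 128 (255 - 255 // 2)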
def pixel_perfect_resolution(
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: ResizeMode,
) -> int:
"""
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
The function first calculates scaling factors for height and width of the image based on the target
height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger
scaling factor to estimate the new resolution.
If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image
fits within the target dimensions, potentially leaving some empty space.
If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target
dimensions are fully filled, potentially cropping the image.
    After calculating the estimated resolution, the function logs some debugging information.
Args:
image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels].
target_H (int): The target height for the image.
target_W (int): The target width for the image.
resize_mode (ResizeMode): The mode for resizing.
Returns:
int: The estimated resolution after resizing.
"""
raw_H, raw_W, _ = image.shape
k0 = float(target_H) / float(raw_H)
k1 = float(target_W) / float(raw_W)
if resize_mode == ResizeMode.OUTER_FIT:
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else:
estimation = max(k0, k1) * float(min(raw_H, raw_W))
logger.debug(f"Pixel Perfect Computation:")
logger.debug(f"resize_mode = {resize_mode}")
logger.debug(f"raw_H = {raw_H}")
logger.debug(f"raw_W = {raw_W}")
logger.debug(f"target_H = {target_H}")
logger.debug(f"target_W = {target_W}")
logger.debug(f"estimation = {estimation}")
return int(np.round(estimation))
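# Worked example (illustrative numbers only, kept as a comment): a 600x800
# (H x W) image resized toward a 1024x1024 target gives k0 = 1024/600 ~= 1.707
# and k1 = 1024/800 = 1.28.
#
#   pixel_perfect_resolution(
#       np.zeros((600, 800, 3), dtype=np.uint8),
#       target_H=1024, target_W=1024,
#       resize_mode=ResizeMode.OUTER_FIT,
#   )  # -> round(1.28 * 600) = 768; any other mode -> round(1.707 * 600) = 1024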
class GradioImageMaskPair(TypedDict):
"""Represents the dict object from Gradio's image component if `tool="sketch"`
is specified.
{
"image": np.ndarray,
"mask": np.ndarray,
}
"""
image: np.ndarray
mask: np.ndarray
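# Illustrative sketch (comments only): building the pair by hand. The shapes and
# the 3-channel mask below are placeholders, not a requirement of Gradio.
#
#   h, w = 512, 512
#   pair: GradioImageMaskPair = {
#       "image": np.zeros((h, w, 3), dtype=np.uint8),
#       "mask": np.zeros((h, w, 3), dtype=np.uint8),
#   }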
@dataclass
class ControlNetUnit:
input_mode: InputMode = InputMode.SIMPLE
use_preview_as_input: bool = False
batch_image_dir: str = ''
batch_mask_dir: str = ''
batch_input_gallery: Optional[List[str]] = None
batch_mask_gallery: Optional[List[str]] = None
generated_image: Optional[np.ndarray] = None
mask_image: Optional[GradioImageMaskPair] = None
mask_image_fg: Optional[GradioImageMaskPair] = None
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
enabled: bool = True
module: str = "None"
model: str = "None"
weight: float = 1.0
image: Optional[GradioImageMaskPair] = None
image_fg: Optional[GradioImageMaskPair] = None
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
processor_res: int = -1
threshold_a: float = -1
threshold_b: float = -1
guidance_start: float = 0.0
guidance_end: float = 1.0
pixel_perfect: bool = False
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
save_detected_map: bool = True
@staticmethod
def infotext_fields():
"""Fields that should be included in infotext.
        You should also define a Gradio element with the exact same name in
        ControlNetUiGroup, so that infotext pasting can wire the value to the
        correct field.
"""
return (
"module",
"model",
"weight",
"resize_mode",
"processor_res",
"threshold_a",
"threshold_b",
"guidance_start",
"guidance_end",
"pixel_perfect",
"control_mode",
"hr_option",
)
@staticmethod
def from_dict(d: Dict) -> "ControlNetUnit":
"""Create ControlNetUnit from dict. This is primarily used to convert
API json dict to ControlNetUnit."""
unit = ControlNetUnit(
**{k: v for k, v in d.items() if k in vars(ControlNetUnit)}
)
if isinstance(unit.image, str):
img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
unit.image = {
"image": img,
"mask": np.zeros_like(img),
}
if isinstance(unit.mask_image, str):
mask = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
if unit.image is not None:
                # Attach the mask to the image if the ControlNet unit has an input image.
assert isinstance(unit.image, dict)
unit.image["mask"] = mask
unit.mask_image = None
else:
# Otherwise, wire to standalone mask.
# This happens in img2img when using A1111 img2img input.
unit.mask_image = {
"image": mask,
"mask": np.zeros_like(mask),
}
return unit
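# Illustrative sketch of an API-style payload (comments only; the base64 string
# is a placeholder, not real data). Keys follow the dataclass field names above,
# and keys that are not dataclass attributes are silently dropped by the vars()
# filter.
#
#   unit = ControlNetUnit.from_dict({
#       "enabled": True,
#       "module": "canny",
#       "model": "None",
#       "weight": 0.8,
#       "image": "<base64-encoded image>",  # decoded and paired with an all-zero mask
#   })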
# Backward-compatibility alias.
UiControlNetUnit = ControlNetUnit
def to_base64_nparray(encoding: str):
"""
    Convert a base64-encoded image into the uint8 numpy array type the extension uses.
"""
return np.array(api.decode_base64_to_image(encoding)).astype('uint8')
def get_max_models_num():
"""
Fetch the maximum number of allowed ControlNet models.
"""
max_models_num = shared.opts.data.get("control_net_unit_count", 3)
return max_models_num