Update pipelines/pipeline_seesr.py

pipelines/pipeline_seesr.py  CHANGED  (+199 −295)
@@ -22,7 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from torch.nn.functional import unfold, fold
 from torchvision.utils import save_image
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
@@ -96,31 +95,7 @@ EXAMPLE_DOC_STRING = """
 ... ).images[0]
 ```
 """
-
-def kde_grad(x0, patch_size=16, bandwidth=0.1):
-    # x0: (N, C, H, W) in float32
-    N, C, H, W = x0.shape
-    patches = unfold(
-        x0, kernel_size=patch_size, stride=patch_size//2
-    )  # (N, C*ps*ps, M)
-    P, M = patches.shape[1], patches.shape[2]
-    p_i = patches.unsqueeze(1)  # (N,1,P,M)
-    p_j = patches.unsqueeze(0)  # (1,N,P,M)
-    diff = p_j - p_i  # (N,N,P,M)
-    # Gaussian weights
-    w = torch.exp((-0.5 / bandwidth**2) *
-                  (diff.square().sum(dim=2)))  # (N,N,M)
-    # mean-shift numerator & normalizer
-    num = (w.unsqueeze(2) * diff).sum(dim=1)  # (N,P,M)
-    denom = w.sum(dim=1, keepdim=True) + 1e-8  # (N,1,M)
-    mshift = num / denom  # (N,P,M)
-    # fold back
-    grad = fold(
-        mshift / bandwidth**2,
-        output_size=(H, W),
-        kernel_size=patch_size,
-        stride=patch_size//2
-    )  # (N, C, H, W)
-    return grad
+
 
 class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
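The deleted helper's `def` line does not survive in this view; the signature shown above is a reconstruction from the call site (`kde_grad(x0_cond, bandwidth=bandwidth)`) and the removed `patch_size`/`bandwidth` keyword arguments, not a line recovered from the commit. As a minimal standalone sketch of the same patch-wise Gaussian mean-shift computation (function name and toy shapes are ours):

# Standalone sketch of the deleted patch-KDE mean-shift gradient.
import torch
from torch.nn.functional import unfold, fold

def kde_grad_sketch(x0, patch_size=16, bandwidth=0.1):
    # x0: (N, C, H, W); N indexes the particles being steered toward each other.
    N, C, H, W = x0.shape
    stride = patch_size // 2
    patches = unfold(x0, kernel_size=patch_size, stride=stride)    # (N, P, M), P = C*ps*ps
    diff = patches.unsqueeze(0) - patches.unsqueeze(1)             # (N, N, P, M): patch_j - patch_i
    w = torch.exp(-0.5 * diff.square().sum(dim=2) / bandwidth**2)  # (N, N, M) Gaussian kernel weights
    num = (w.unsqueeze(2) * diff).sum(dim=1)                       # (N, P, M) kernel-weighted offsets
    mshift = num / (w.sum(dim=1, keepdim=True) + 1e-8)             # (N, P, M) mean-shift vectors
    # fold sums overlapping patches back into image space
    return fold(mshift / bandwidth**2, output_size=(H, W),
                kernel_size=patch_size, stride=stride)             # (N, C, H, W)

print(kde_grad_sketch(torch.randn(4, 4, 64, 64)).shape)  # torch.Size([4, 4, 64, 64])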
@@ -803,6 +778,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
 
     @perfcount
+    @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
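The new decorator is exactly what the removed KDS code was incompatible with: kernel density steering kept `latents.requires_grad_(True)` alive across the sampling loop, while plain sampling needs no autograd graph at all, so disabling gradient tracking saves activation memory. A small illustration (ours) of the decorator's effect:

import torch

@torch.no_grad()
def denoise_step(latents):
    # gradient tracking is disabled inside, so no graph is stored
    return latents * 0.99

x = torch.randn(1, 4, 64, 64, requires_grad=True)
print(denoise_step(x).requires_grad)  # False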
@@ -832,12 +808,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         ram_encoder_hidden_states=None,
         latent_tiled_size=320,
         latent_tiled_overlap=4,
-
-        gamma_0: Optional[float] = 0.1,  # base steering strength
-        use_KDS = True,
-        patch_size = 16,
-        bandwidth = 0.1,
-        args=None,
+        args=None
     ):
         r"""
         Function invoked when calling the pipeline for generation.
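Of the removed keywords, `gamma_0` set the base steering strength for a schedule that decays linearly to zero over the trajectory (see the deleted `delta_t` line in the denoising loop below), while `use_KDS`, `patch_size`, and `bandwidth` gated and parameterized the kernel steering. A toy evaluation (ours) of that schedule:

# delta_t = gamma_0 * (1 - i / (T - 1)) decays from gamma_0 to 0 across T steps
gamma_0, T = 0.1, 50
delta = [gamma_0 * (1 - i / (T - 1)) for i in range(T)]
print(round(delta[0], 3), round(delta[25], 3), round(delta[-1], 3))  # 0.1 0.049 0.0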
@@ -1025,17 +996,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        if use_KDS:
-            # 1) update batch_size to account for the new particles
-            batch_size = batch_size * num_particles
-
-            # 2) now repeat latents/images/prompts
-            latents = latents.repeat_interleave(num_particles, dim=0)
-            image = image.repeat_interleave(num_particles, dim=0)
-            ram_encoder_hidden_states = ram_encoder_hidden_states.repeat_interleave(num_particles, dim=0)
-            prompt_embeds = prompt_embeds.repeat_interleave(num_particles, dim=0)
-            latents.requires_grad_(True)
-
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
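For context, the deleted setup fanned each batch element out into `num_particles` copies; `repeat_interleave` keeps the copies of one image adjacent, which is what let the later (also deleted) selection code slice particles per source image. A minimal illustration (ours):

import torch

num_particles = 3
latents = torch.tensor([0., 1.]).view(2, 1, 1, 1)  # toy batch of 2 latents
particles = latents.repeat_interleave(num_particles, dim=0)
print(particles.flatten().tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]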
@@ -1048,220 +1008,184 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
 
             for i, t in enumerate(timesteps):
+                # pass, if the timestep is larger than start_steps
+                if t > start_steps:
+                    print(f'pass {t} steps.')
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
                     controlnet_latent_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
 
                 if h*w<=tile_size*tile_size:  # the latent fits in one tile, no tiling needed
                     down_block_res_samples, mid_block_res_sample = [None]*10, None
                     down_block_res_samples, mid_block_res_sample = self.controlnet(
                         controlnet_latent_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
                         controlnet_cond=image,
                         conditioning_scale=conditioning_scale,
                         guess_mode=guess_mode,
                         return_dict=False,
                         image_encoder_hidden_states=ram_encoder_hidden_states,
                     )
 
                     if guess_mode and do_classifier_free_guidance:
                         # Inferred ControlNet only for the conditional batch.
                         # To apply the output of ControlNet to both the unconditional and conditional batches,
                         # add 0 to the unconditional batch to keep it unchanged.
                         down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                         mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
                     # predict the noise residual
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
                         encoder_hidden_states=prompt_embeds,
                         cross_attention_kwargs=cross_attention_kwargs,
                         down_block_additional_residuals=down_block_res_samples,
                         mid_block_additional_residual=mid_block_res_sample,
                         return_dict=False,
                         image_encoder_hidden_states=ram_encoder_hidden_states,
                     )[0]
                 else:
                     tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
                     tile_size = min(tile_size, min(h, w))
                     tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
 
                     grid_rows = 0
                     cur_x = 0
                     while cur_x < latent_model_input.size(-1):
                         cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
                         grid_rows += 1
 
                     grid_cols = 0
                     cur_y = 0
                     while cur_y < latent_model_input.size(-2):
                         cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
                         grid_cols += 1
 
                     input_list = []
                     cond_list = []
                     img_list = []
                     noise_preds = []
                     for row in range(grid_rows):
                         noise_preds_row = []
                         for col in range(grid_cols):
                             if col < grid_cols-1 or row < grid_rows-1:
                                 # extract tile from input image
                                 ofs_x = max(row * tile_size-tile_overlap * row, 0)
                                 ofs_y = max(col * tile_size-tile_overlap * col, 0)
                                 # input tile area on total image
                             if row == grid_rows-1:
                                 ofs_x = w - tile_size
                             if col == grid_cols-1:
                                 ofs_y = h - tile_size
 
                             input_start_x = ofs_x
                             input_end_x = ofs_x + tile_size
                             input_start_y = ofs_y
                             input_end_y = ofs_y + tile_size
 
                             # input tile dimensions
                             input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
                             input_list.append(input_tile)
                             cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
                             cond_list.append(cond_tile)
                             img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
                             img_list.append(img_tile)
 
                             if len(input_list) == batch_size or col == grid_cols-1:
                                 input_list_t = torch.cat(input_list, dim=0)
                                 cond_list_t = torch.cat(cond_list, dim=0)
                                 img_list_t = torch.cat(img_list, dim=0)
                                 #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
 
                                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                                     cond_list_t,
                                     t,
                                     encoder_hidden_states=controlnet_prompt_embeds,
                                     controlnet_cond=img_list_t,
                                     conditioning_scale=conditioning_scale,
                                     guess_mode=guess_mode,
                                     return_dict=False,
                                     image_encoder_hidden_states=ram_encoder_hidden_states,
                                 )
 
                                 if guess_mode and do_classifier_free_guidance:
                                     # Inferred ControlNet only for the conditional batch.
                                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                                     # add 0 to the unconditional batch to keep it unchanged.
                                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
                                 # predict the noise residual
                                 model_out = self.unet(
                                     input_list_t,
                                     t,
                                     encoder_hidden_states=prompt_embeds,
                                     cross_attention_kwargs=cross_attention_kwargs,
                                     down_block_additional_residuals=down_block_res_samples,
                                     mid_block_additional_residual=mid_block_res_sample,
                                     return_dict=False,
                                     image_encoder_hidden_states=ram_encoder_hidden_states,
                                 )[0]
 
                                 #for sample_i in range(model_out.size(0)):
                                 #    noise_preds_row.append(model_out[sample_i].unsqueeze(0))
                                 input_list = []
                                 cond_list = []
                                 img_list = []
 
                                 noise_preds.append(model_out)
 
                     # Stitch noise predictions for all tiles
                     noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
                     contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
                     # Add each tile contribution to overall latents
                     for row in range(grid_rows):
                         for col in range(grid_cols):
                             if col < grid_cols-1 or row < grid_rows-1:
                                 # extract tile from input image
                                 ofs_x = max(row * tile_size-tile_overlap * row, 0)
                                 ofs_y = max(col * tile_size-tile_overlap * col, 0)
                                 # input tile area on total image
                             if row == grid_rows-1:
                                 ofs_x = w - tile_size
                             if col == grid_cols-1:
                                 ofs_y = h - tile_size
 
                             input_start_x = ofs_x
                             input_end_x = ofs_x + tile_size
                             input_start_y = ofs_y
                             input_end_y = ofs_y + tile_size
 
                             noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
                             contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
                     # Average overlapping areas with more than 1 contributor
                     noise_pred /= contributors
 
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                if use_KDS:
-                    # 1) Predict x0 and the per-step schedule terms
-                    beta_t = 1 - self.scheduler.alphas_cumprod[t]
-                    alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
-                    sigma_t = beta_t.sqrt()
-                    x0_pred = (latents - sigma_t * noise_pred) / alpha_t  # shape [2N, C, H, W]
-
-                    # 2) Split into unconditional vs. conditional
-                    x0_uncond, x0_cond = x0_pred.chunk(2, dim=0)  # each [N, C, H, W]
-
-                    # 3) Apply KDE steering *only* on the conditional batch
-                    m_shift_cond = kde_grad(x0_cond, bandwidth=bandwidth)  # [N, C, H, W]
-                    delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
-                    x0_cond_steer = x0_cond + delta_t * m_shift_cond  # steered conditional
-
-                    # 4) Recombine the latents: leave uncond untouched, use steered cond
-                    x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0)  # [2N, C, H, W]
-
-                    # 5) Recompute "noise" for DDIM step
-                    noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
-
-                    # 6) Determine prev alphas and form next latent per DDIM
-                    if i < len(timesteps) - 1:
-                        next_t = timesteps[i + 1]
-                        alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
-                    else:
-                        alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
-                    sigma_prev = (1 - alpha_prev**2).sqrt()
-
-                    latents = (
-                        alpha_prev * x0_steer
-                        + sigma_prev * noise_pred_kds
-                    ).detach().requires_grad_(True)
-                else:
-                    # compute the previous noisy sample x_t -> x_t-1
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
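In the removed KDS branch, `beta_t = 1 - self.scheduler.alphas_cumprod[t]` is our reconstruction of a line whose content was lost in this view; it is forced by the surviving `sigma_t = beta_t.sqrt()` pairing with `alpha_t = alphas_cumprod[t].sqrt()`. Separately, the tiling arithmetic above is easy to misread: the stride between tiles is `tile_size - tile_overlap`, and the last row/column is snapped to the image border, overlapping its neighbor by more than `tile_overlap`. A compact sketch (ours) of the same offset computation:

def tile_offsets(extent, tile_size, tile_overlap):
    # count tiles exactly as the grid_rows/grid_cols loops above do
    n, cur = 0, 0
    while cur < extent:
        cur = max(n * tile_size - tile_overlap * n, 0) + tile_size
        n += 1
    offsets = []
    for k in range(n):
        ofs = max(k * (tile_size - tile_overlap), 0)
        if k == n - 1:          # snap the final tile to the border
            ofs = extent - tile_size
        offsets.append(ofs)
    return offsets

print(tile_offsets(96, 64, 8))  # [0, 32]: tiles cover [0, 64) and [32, 96)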
@@ -1269,53 +1193,33 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
-        [... particle-scoring code lost in this view: `dists`, `uncond_latents`,
-        and `cond_latents` were computed here from the final latents ...]
-        # 3) best index
-        best_idx = dists.argmin().item()
-        # 4) select that latent (and its uncond pair)
-        best_uncond = uncond_latents[best_idx:best_idx+1]
-        best_cond = cond_latents[best_idx:best_idx+1]
-        latents = torch.cat([best_uncond, best_cond], dim=0)  # [2, C, H, W]
-
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
             torch.cuda.empty_cache()
 
         has_nsfw_concept = None
         if not output_type == "latent":
-            image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
             #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
-            image = latents.detach()
+            image = latents
             has_nsfw_concept = None
 
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)