alexnasa committed on
Commit acd710e · verified · 1 Parent(s): 39bbec4

Update pipelines/pipeline_seesr.py

Files changed (1):
  1. pipelines/pipeline_seesr.py +199 -295
pipelines/pipeline_seesr.py CHANGED
@@ -22,7 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from torch.nn.functional import unfold, fold
 from torchvision.utils import save_image
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
@@ -96,31 +95,7 @@ EXAMPLE_DOC_STRING = """
         ... ).images[0]
         ```
 """
-def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
-    # x0: (N, C, H, W) in float32
-    N, C, H, W = x0.shape
-    patches = unfold(
-        x0, kernel_size=patch_size, stride=patch_size//2
-    ) # (N, C*ps*ps, M)
-    P, M = patches.shape[1], patches.shape[2]
-    p_i = patches.unsqueeze(1) # (N,1,P,M)
-    p_j = patches.unsqueeze(0) # (1,N,P,M)
-    diff = p_j - p_i # (N,N,P,M)
-    # Gaussian weights
-    w = torch.exp((-0.5 / bandwidth**2) *
-                  (diff.square().sum(dim=2))) # (N,N,M)
-    # mean-shift numerator & normalizer
-    num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
-    denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
-    mshift = num / denom # (N,P,M)
-    # fold back
-    grad = fold(
-        mshift / bandwidth**2,
-        output_size=(H, W),
-        kernel_size=patch_size,
-        stride=patch_size//2
-    ) # (N, C, H, W)
-    return grad
+
 
 class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
@@ -803,6 +778,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
 
     @perfcount
+    @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
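The new `@torch.no_grad()` decorator disables autograd for the entire `__call__`, which is what makes the per-block `with torch.no_grad():` wrappers (removed below) and the KDS gradient bookkeeping unnecessary. A minimal illustration of the equivalence:

```python
import torch

@torch.no_grad()
def step_decorated(x):
    return x * 2            # autograd is disabled for the whole function

def step_wrapped(x):
    with torch.no_grad():   # the per-block form this commit removes
        return x * 2

x = torch.ones(3, requires_grad=True)
print(step_decorated(x).requires_grad, step_wrapped(x).requires_grad)  # False False
```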
@@ -832,12 +808,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         ram_encoder_hidden_states=None,
         latent_tiled_size=320,
         latent_tiled_overlap=4,
-        num_particles: Optional[int] = 4,
-        gamma_0: Optional[float] = 0.1, # base steering strength
-        use_KDS = True,
-        patch_size = 16,
-        bandwidth = 0.1,
-        args=None,
+        args=None
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1025,17 +996,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        if use_KDS:
-            # 1) update batch_size to account for the new particles
-            batch_size = batch_size * num_particles
-
-            # 2) now repeat latents/images/prompts
-            latents = latents.repeat_interleave(num_particles, dim=0)
-            image = image.repeat_interleave(num_particles, dim=0)
-            ram_encoder_hidden_states = ram_encoder_hidden_states.repeat_interleave(num_particles, dim=0)
-            prompt_embeds = prompt_embeds.repeat_interleave(num_particles, dim=0)
-            latents.requires_grad_(True)
-
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
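The removed block above turned each batch entry into `num_particles` ensemble members by repeating it along the batch dimension. A small illustration of what `repeat_interleave` did here (shapes are hypothetical):

```python
import torch

num_particles = 4
latents = torch.randn(2, 4, 64, 64)                          # hypothetical batch of 2 latents
particles = latents.repeat_interleave(num_particles, dim=0)  # each sample repeated 4x, in order
print(particles.shape)                                       # torch.Size([8, 4, 64, 64])
print(torch.equal(particles[0], particles[3]))               # True: rows 0..3 are copies of sample 0
```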
@@ -1048,220 +1008,184 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
 
             for i, t in enumerate(timesteps):
-                with torch.no_grad():
-                    # pass, if the timestep is larger than start_steps
-                    if t > start_steps:
-                        print(f'pass {t} steps.')
-                        continue
-
-                    # expand the latents if we are doing classifier free guidance
-                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                    # controlnet(s) inference
-                    if guess_mode and do_classifier_free_guidance:
-                        # Infer ControlNet only for the conditional batch.
-                        controlnet_latent_model_input = latents
-                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-                    else:
-                        controlnet_latent_model_input = latent_model_input
-                        controlnet_prompt_embeds = prompt_embeds
-
-                    if h*w<=tile_size*tile_size: # tiled latent input
-                        down_block_res_samples, mid_block_res_sample = [None]*10, None
-
-                        down_block_res_samples, mid_block_res_sample = self.controlnet(
-                            controlnet_latent_model_input,
-                            t,
-                            encoder_hidden_states=controlnet_prompt_embeds,
-                            controlnet_cond=image,
-                            conditioning_scale=conditioning_scale,
-                            guess_mode=guess_mode,
-                            return_dict=False,
-                            image_encoder_hidden_states = ram_encoder_hidden_states,
-                        )
-
-
-                        if guess_mode and do_classifier_free_guidance:
-                            # Infered ControlNet only for the conditional batch.
-                            # To apply the output of ControlNet to both the unconditional and conditional batches,
-                            # add 0 to the unconditional batch to keep it unchanged.
-                            down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                            mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-                        # predict the noise residual
-                        noise_pred = self.unet(
-                            latent_model_input,
-                            t,
-                            encoder_hidden_states=prompt_embeds,
-                            cross_attention_kwargs=cross_attention_kwargs,
-                            down_block_additional_residuals=down_block_res_samples,
-                            mid_block_additional_residual=mid_block_res_sample,
-                            return_dict=False,
-                            image_encoder_hidden_states = ram_encoder_hidden_states,
-                        )[0]
-                    else:
-                        tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
-                        tile_size = min(tile_size, min(h, w))
-                        tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
-
-                        grid_rows = 0
-                        cur_x = 0
-                        while cur_x < latent_model_input.size(-1):
-                            cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
-                            grid_rows += 1
-
-                        grid_cols = 0
-                        cur_y = 0
-                        while cur_y < latent_model_input.size(-2):
-                            cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
-                            grid_cols += 1
-
-                        input_list = []
-                        cond_list = []
-                        img_list = []
-                        noise_preds = []
-                        for row in range(grid_rows):
-                            noise_preds_row = []
-                            for col in range(grid_cols):
-                                if col < grid_cols-1 or row < grid_rows-1:
-                                    # extract tile from input image
-                                    ofs_x = max(row * tile_size-tile_overlap * row, 0)
-                                    ofs_y = max(col * tile_size-tile_overlap * col, 0)
-                                    # input tile area on total image
-                                if row == grid_rows-1:
-                                    ofs_x = w - tile_size
-                                if col == grid_cols-1:
-                                    ofs_y = h - tile_size
-
-                                input_start_x = ofs_x
-                                input_end_x = ofs_x + tile_size
-                                input_start_y = ofs_y
-                                input_end_y = ofs_y + tile_size
-
-                                # input tile dimensions
-                                input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
-                                input_list.append(input_tile)
-                                cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
-                                cond_list.append(cond_tile)
-                                img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
-                                img_list.append(img_tile)
-
-                                if len(input_list) == batch_size or col == grid_cols-1:
-                                    input_list_t = torch.cat(input_list, dim=0)
-                                    cond_list_t = torch.cat(cond_list, dim=0)
-                                    img_list_t = torch.cat(img_list, dim=0)
-                                    #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
-
-                                    down_block_res_samples, mid_block_res_sample = self.controlnet(
-                                        cond_list_t,
-                                        t,
-                                        encoder_hidden_states=controlnet_prompt_embeds,
-                                        controlnet_cond=img_list_t,
-                                        conditioning_scale=conditioning_scale,
-                                        guess_mode=guess_mode,
-                                        return_dict=False,
-                                        image_encoder_hidden_states = ram_encoder_hidden_states,
-                                    )
-
-                                    if guess_mode and do_classifier_free_guidance:
-                                        # Infered ControlNet only for the conditional batch.
-                                        # To apply the output of ControlNet to both the unconditional and conditional batches,
-                                        # add 0 to the unconditional batch to keep it unchanged.
-                                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-                                    # predict the noise residual
-                                    model_out = self.unet(
-                                        input_list_t,
-                                        t,
-                                        encoder_hidden_states=prompt_embeds,
-                                        cross_attention_kwargs=cross_attention_kwargs,
-                                        down_block_additional_residuals=down_block_res_samples,
-                                        mid_block_additional_residual=mid_block_res_sample,
-                                        return_dict=False,
-                                        image_encoder_hidden_states = ram_encoder_hidden_states,
-                                    )[0]
-
-                                    #for sample_i in range(model_out.size(0)):
-                                    #    noise_preds_row.append(model_out[sample_i].unsqueeze(0))
-                                    input_list = []
-                                    cond_list = []
-                                    img_list = []
-
-                                    noise_preds.append(model_out)
-
-                        # Stitch noise predictions for all tiles
-                        noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
-                        contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
-                        # Add each tile contribution to overall latents
-                        for row in range(grid_rows):
-                            for col in range(grid_cols):
-                                if col < grid_cols-1 or row < grid_rows-1:
-                                    # extract tile from input image
-                                    ofs_x = max(row * tile_size-tile_overlap * row, 0)
-                                    ofs_y = max(col * tile_size-tile_overlap * col, 0)
-                                    # input tile area on total image
-                                if row == grid_rows-1:
-                                    ofs_x = w - tile_size
-                                if col == grid_cols-1:
-                                    ofs_y = h - tile_size
-
-                                input_start_x = ofs_x
-                                input_end_x = ofs_x + tile_size
-                                input_start_y = ofs_y
-                                input_end_y = ofs_y + tile_size
-
-                                noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
-                                contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
-                        # Average overlapping areas with more than 1 contributor
-                        noise_pred /= contributors
-
-
-                    # perform guidance
-                    if do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-
-                if use_KDS:
-
-                    # 2) Compute x₀ prediction for all particles
-                    beta_t = 1 - self.scheduler.alphas_cumprod[t]
-                    alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
-                    sigma_t = beta_t.sqrt()
-                    x0_pred = (latents - sigma_t * noise_pred) / alpha_t # shape [2N, C, H, W]
-
-                    # — split into unconditional vs. conditional
-                    x0_uncond, x0_cond = x0_pred.chunk(2, dim=0) # each [N, C, H, W]
-
-                    # 3) Apply KDE steering *only* on the conditional batch
-                    m_shift_cond = kde_grad(x0_cond, bandwidth=bandwidth) # [N, C, H, W]
-                    delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
-                    x0_cond_steer = x0_cond + delta_t * m_shift_cond # steered conditional
-
-                    # 4) Recombine the latents: leave uncond untouched, use steered cond
-                    x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0) # [2N, C, H, W]
-
-                    # 5) Recompute “noise” for DDIM step
-                    noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
-
-                    # 6) Determine prev alphas and form next latent per DDIM
-                    if i < len(timesteps) - 1:
-                        next_t = timesteps[i + 1]
-                        alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
-                    else:
-                        alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
-                    sigma_prev = (1 - alpha_prev**2).sqrt()
-
-                    latents = (
-                        alpha_prev * x0_steer
-                        + sigma_prev * noise_pred_kds
-                    ).detach().requires_grad_(True)
-                else:
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                # pass, if the timestep is larger than start_steps
+                if t > start_steps:
+                    print(f'pass {t} steps.')
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
+                if h*w<=tile_size*tile_size: # tiled latent input
+                    down_block_res_samples, mid_block_res_sample = [None]*10, None
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        controlnet_latent_model_input,
+                        t,
+                        encoder_hidden_states=controlnet_prompt_embeds,
+                        controlnet_cond=image,
+                        conditioning_scale=conditioning_scale,
+                        guess_mode=guess_mode,
+                        return_dict=False,
+                        image_encoder_hidden_states = ram_encoder_hidden_states,
+                    )
+
+
+                    if guess_mode and do_classifier_free_guidance:
+                        # Infered ControlNet only for the conditional batch.
+                        # To apply the output of ControlNet to both the unconditional and conditional batches,
+                        # add 0 to the unconditional batch to keep it unchanged.
+                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        down_block_additional_residuals=down_block_res_samples,
+                        mid_block_additional_residual=mid_block_res_sample,
+                        return_dict=False,
+                        image_encoder_hidden_states = ram_encoder_hidden_states,
+                    )[0]
+                else:
+                    tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
+                    tile_size = min(tile_size, min(h, w))
+                    tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
+
+                    grid_rows = 0
+                    cur_x = 0
+                    while cur_x < latent_model_input.size(-1):
+                        cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
+                        grid_rows += 1
+
+                    grid_cols = 0
+                    cur_y = 0
+                    while cur_y < latent_model_input.size(-2):
+                        cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
+                        grid_cols += 1
+
+                    input_list = []
+                    cond_list = []
+                    img_list = []
+                    noise_preds = []
+                    for row in range(grid_rows):
+                        noise_preds_row = []
+                        for col in range(grid_cols):
+                            if col < grid_cols-1 or row < grid_rows-1:
+                                # extract tile from input image
+                                ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                                ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                                # input tile area on total image
+                            if row == grid_rows-1:
+                                ofs_x = w - tile_size
+                            if col == grid_cols-1:
+                                ofs_y = h - tile_size
+
+                            input_start_x = ofs_x
+                            input_end_x = ofs_x + tile_size
+                            input_start_y = ofs_y
+                            input_end_y = ofs_y + tile_size
+
+                            # input tile dimensions
+                            input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                            input_list.append(input_tile)
+                            cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                            cond_list.append(cond_tile)
+                            img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
+                            img_list.append(img_tile)
+
+                            if len(input_list) == batch_size or col == grid_cols-1:
+                                input_list_t = torch.cat(input_list, dim=0)
+                                cond_list_t = torch.cat(cond_list, dim=0)
+                                img_list_t = torch.cat(img_list, dim=0)
+                                #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
+
+                                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                                    cond_list_t,
+                                    t,
+                                    encoder_hidden_states=controlnet_prompt_embeds,
+                                    controlnet_cond=img_list_t,
+                                    conditioning_scale=conditioning_scale,
+                                    guess_mode=guess_mode,
+                                    return_dict=False,
+                                    image_encoder_hidden_states = ram_encoder_hidden_states,
+                                )
+
+                                if guess_mode and do_classifier_free_guidance:
+                                    # Infered ControlNet only for the conditional batch.
+                                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                                    # add 0 to the unconditional batch to keep it unchanged.
+                                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                                # predict the noise residual
+                                model_out = self.unet(
+                                    input_list_t,
+                                    t,
+                                    encoder_hidden_states=prompt_embeds,
+                                    cross_attention_kwargs=cross_attention_kwargs,
+                                    down_block_additional_residuals=down_block_res_samples,
+                                    mid_block_additional_residual=mid_block_res_sample,
+                                    return_dict=False,
+                                    image_encoder_hidden_states = ram_encoder_hidden_states,
+                                )[0]
+
+                                #for sample_i in range(model_out.size(0)):
+                                #    noise_preds_row.append(model_out[sample_i].unsqueeze(0))
+                                input_list = []
+                                cond_list = []
+                                img_list = []
+
+                                noise_preds.append(model_out)
+
+                    # Stitch noise predictions for all tiles
+                    noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                    contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                    # Add each tile contribution to overall latents
+                    for row in range(grid_rows):
+                        for col in range(grid_cols):
+                            if col < grid_cols-1 or row < grid_rows-1:
+                                # extract tile from input image
+                                ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                                ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                                # input tile area on total image
+                            if row == grid_rows-1:
+                                ofs_x = w - tile_size
+                            if col == grid_cols-1:
+                                ofs_y = h - tile_size
+
+                            input_start_x = ofs_x
+                            input_end_x = ofs_x + tile_size
+                            input_start_y = ofs_y
+                            input_end_y = ofs_y + tile_size
+
+                            noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
+                            contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
+                    # Average overlapping areas with more than 1 contributor
+                    noise_pred /= contributors
+
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
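The tiled branch itself is unchanged apart from the shallower indentation and `_gaussian_weights(tile_size, tile_size, 1)` replacing the `batch_size` argument: per-tile noise predictions are blended with Gaussian weights and normalized by a `contributors` map. A self-contained sketch of that stitching scheme (the function and names are illustrative, not part of the pipeline):

```python
import torch

def stitch(tiles, offsets, tile_weights, out_shape):
    """Blend overlapping tile predictions with per-pixel weights."""
    out = torch.zeros(out_shape)
    contributors = torch.zeros(out_shape)
    for tile, (oy, ox) in zip(tiles, offsets):
        h, w = tile.shape[-2:]
        out[..., oy:oy + h, ox:ox + w] += tile * tile_weights
        contributors[..., oy:oy + h, ox:ox + w] += tile_weights
    return out / contributors  # average wherever tiles overlap

# two 4x4 tiles overlapping by two columns on a 4x6 canvas
tile = torch.ones(1, 1, 4, 4)
weights = torch.ones(4, 4)  # stand-in for the Gaussian weights from _gaussian_weights
blended = stitch([tile, 2 * tile], [(0, 0), (0, 2)], weights, (1, 1, 4, 6))
print(blended[0, 0, 0])  # -> 1.0, 1.0, 1.5, 1.5, 2.0, 2.0
```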
@@ -1269,53 +1193,33 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
-        with torch.no_grad():
-
-            if use_KDS:
-                # Final-latent selection (once!)
-                # latents shape: [2*N, C, H, W]
-                uncond_latents, cond_latents = latents.chunk(2, dim=0) # each [N, C, H, W]
-                # 1) ensemble mean
-                mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
-                # 2) distances
-                dists = ((cond_latents - mean_cond)
-                         .view(cond_latents.size(0), -1)
-                         .pow(2)
-                         .sum(dim=1)) # [N]
-                # 3) best index
-                best_idx = dists.argmin().item()
-                # 4) select that latent (and its uncond pair)
-                best_uncond = uncond_latents[best_idx:best_idx+1]
-                best_cond = cond_latents [best_idx:best_idx+1]
-                latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
-
-            # If we do sequential model offloading, let's offload unet and controlnet
-            # manually for max memory savings
-            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-                self.unet.to("cpu")
-                self.controlnet.to("cpu")
-                torch.cuda.empty_cache()
-
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        has_nsfw_concept = None
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
+            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
             has_nsfw_concept = None
-            if not output_type == "latent":
-                image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
-                #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-            else:
-                image = latents.detach()
-                has_nsfw_concept = None
 
-            if has_nsfw_concept is None:
-                do_denormalize = [True] * image.shape[0]
-            else:
-                do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
-            # Offload last model to CPU
-            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-                self.final_offload_hook.offload()
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
 
-            if not return_dict:
-                return (image, has_nsfw_concept)
+        if not return_dict:
+            return (image, has_nsfw_concept)
 
-            return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
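The deleted post-loop block collapsed the KDS particle ensemble by keeping the conditional latent closest, in squared L2 distance, to the ensemble mean, together with its unconditional counterpart. A compact sketch of that selection rule (the tensors here are stand-ins):

```python
import torch

uncond_latents = torch.randn(4, 4, 64, 64)  # hypothetical ensemble of 4 particles
cond_latents = torch.randn(4, 4, 64, 64)

mean_cond = cond_latents.mean(dim=0, keepdim=True)               # ensemble mean, (1, C, H, W)
dists = (cond_latents - mean_cond).flatten(1).pow(2).sum(dim=1)  # squared L2 per particle, (N,)
best_idx = dists.argmin().item()
latents = torch.cat([uncond_latents[best_idx:best_idx + 1],
                     cond_latents[best_idx:best_idx + 1]], dim=0)  # (2, C, H, W)
print(best_idx, latents.shape)
```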