alexnasa committed
Commit 454560a · verified · 1 Parent(s): 548082b

Update pipelines/pipeline_seesr.py

Files changed (1): pipelines/pipeline_seesr.py (+204, -202)
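This commit wraps each denoising step and the post-loop decode/offload code in torch.no_grad(), so the sampling loop no longer records an autograd graph during plain inference; the use_KDS branch is left outside the guard, presumably because that sampler still needs gradients. A minimal sketch of the pattern (stand-in names, not the pipeline's actual signature):

    import torch

    def denoise(unet, scheduler, latents, timesteps, prompt_embeds):
        for t in timesteps:
            # Inside no_grad(), forward activations are freed right away,
            # so per-step memory stays flat during inference.
            with torch.no_grad():
                noise_pred = unet(latents, t, encoder_hidden_states=prompt_embeds).sample
            # The scheduler step sits outside the guard here, mirroring how
            # the commit keeps the gradient-based branch free to differentiate.
            latents = scheduler.step(noise_pred, t, latents).prev_sample
        return latents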
pipelines/pipeline_seesr.py CHANGED
@@ -1047,187 +1047,188 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
 
     for i, t in enumerate(timesteps):
-        # pass, if the timestep is larger than start_steps
-        if t > start_steps:
-            print(f'pass {t} steps.')
-            continue
-
-        # expand the latents if we are doing classifier free guidance
-        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-        # controlnet(s) inference
-        if guess_mode and do_classifier_free_guidance:
-            # Infer ControlNet only for the conditional batch.
-            controlnet_latent_model_input = latents
-            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
-            print("well unexpected")
-
-        else:
-            controlnet_latent_model_input = latent_model_input
-            controlnet_prompt_embeds = prompt_embeds
-            print("a possiblity")
-
-        if h*w<=tile_size*tile_size: # tiled latent input
-            down_block_res_samples, mid_block_res_sample = [None]*10, None
-
-            print(f"controlnet 1 started with {controlnet_latent_model_input.shape}:{ram_encoder_hidden_states.shape}")
-
-            down_block_res_samples, mid_block_res_sample = self.controlnet(
-                controlnet_latent_model_input,
-                t,
-                encoder_hidden_states=controlnet_prompt_embeds,
-                controlnet_cond=image,
-                conditioning_scale=conditioning_scale,
-                guess_mode=guess_mode,
-                return_dict=False,
-                image_encoder_hidden_states = ram_encoder_hidden_states,
-            )
-
-
-            if guess_mode and do_classifier_free_guidance:
-                # Infered ControlNet only for the conditional batch.
-                # To apply the output of ControlNet to both the unconditional and conditional batches,
-                # add 0 to the unconditional batch to keep it unchanged.
-                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-            # predict the noise residual
-            print(f"unet started with {latent_model_input.shape}:{prompt_embeds.shape}")
-            noise_pred = self.unet(
-                latent_model_input,
-                t,
-                encoder_hidden_states=prompt_embeds,
-                cross_attention_kwargs=cross_attention_kwargs,
-                down_block_additional_residuals=down_block_res_samples,
-                mid_block_additional_residual=mid_block_res_sample,
-                return_dict=False,
-                image_encoder_hidden_states = ram_encoder_hidden_states,
-            )[0]
-        else:
-            tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
-            tile_size = min(tile_size, min(h, w))
-            tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
-
-            grid_rows = 0
-            cur_x = 0
-            while cur_x < latent_model_input.size(-1):
-                cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
-                grid_rows += 1
-
-            grid_cols = 0
-            cur_y = 0
-            while cur_y < latent_model_input.size(-2):
-                cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
-                grid_cols += 1
-
-            input_list = []
-            cond_list = []
-            img_list = []
-            noise_preds = []
-            for row in range(grid_rows):
-                noise_preds_row = []
-                for col in range(grid_cols):
-                    if col < grid_cols-1 or row < grid_rows-1:
-                        # extract tile from input image
-                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
-                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
-                        # input tile area on total image
-                    if row == grid_rows-1:
-                        ofs_x = w - tile_size
-                    if col == grid_cols-1:
-                        ofs_y = h - tile_size
-
-                    input_start_x = ofs_x
-                    input_end_x = ofs_x + tile_size
-                    input_start_y = ofs_y
-                    input_end_y = ofs_y + tile_size
-
-                    # input tile dimensions
-                    input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
-                    input_list.append(input_tile)
-                    cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
-                    cond_list.append(cond_tile)
-                    img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
-                    img_list.append(img_tile)
-
-                    if len(input_list) == batch_size or col == grid_cols-1:
-                        input_list_t = torch.cat(input_list, dim=0)
-                        cond_list_t = torch.cat(cond_list, dim=0)
-                        img_list_t = torch.cat(img_list, dim=0)
-                        #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
-                        print(f"controlnet 2 started with {cond_list_t.shape}:{controlnet_prompt_embeds.shape}")
-
-                        down_block_res_samples, mid_block_res_sample = self.controlnet(
-                            cond_list_t,
-                            t,
-                            encoder_hidden_states=controlnet_prompt_embeds,
-                            controlnet_cond=img_list_t,
-                            conditioning_scale=conditioning_scale,
-                            guess_mode=guess_mode,
-                            return_dict=False,
-                            image_encoder_hidden_states = ram_encoder_hidden_states,
-                        )
-
-                        if guess_mode and do_classifier_free_guidance:
-                            # Infered ControlNet only for the conditional batch.
-                            # To apply the output of ControlNet to both the unconditional and conditional batches,
-                            # add 0 to the unconditional batch to keep it unchanged.
-                            down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                            mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-                        # predict the noise residual
-                        print(f"unet started with {input_list_t.shape}:{prompt_embeds.shape}")
-                        model_out = self.unet(
-                            input_list_t,
-                            t,
-                            encoder_hidden_states=prompt_embeds,
-                            cross_attention_kwargs=cross_attention_kwargs,
-                            down_block_additional_residuals=down_block_res_samples,
-                            mid_block_additional_residual=mid_block_res_sample,
-                            return_dict=False,
-                            image_encoder_hidden_states = ram_encoder_hidden_states,
-                        )[0]
-
-                        #for sample_i in range(model_out.size(0)):
-                        # noise_preds_row.append(model_out[sample_i].unsqueeze(0))
-                        input_list = []
-                        cond_list = []
-                        img_list = []
-
-                        noise_preds.append(model_out)
-
-            # Stitch noise predictions for all tiles
-            noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
-            contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
-            # Add each tile contribution to overall latents
-            for row in range(grid_rows):
-                for col in range(grid_cols):
-                    if col < grid_cols-1 or row < grid_rows-1:
-                        # extract tile from input image
-                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
-                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
-                        # input tile area on total image
-                    if row == grid_rows-1:
-                        ofs_x = w - tile_size
-                    if col == grid_cols-1:
-                        ofs_y = h - tile_size
-
-                    input_start_x = ofs_x
-                    input_end_x = ofs_x + tile_size
-                    input_start_y = ofs_y
-                    input_end_y = ofs_y + tile_size
-
-                    noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
-                    contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
-            # Average overlapping areas with more than 1 contributor
-            noise_pred /= contributors
-
-
-        # perform guidance
-        if do_classifier_free_guidance:
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        with torch.no_grad():
+            # pass, if the timestep is larger than start_steps
+            if t > start_steps:
+                print(f'pass {t} steps.')
+                continue
+
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # controlnet(s) inference
+            if guess_mode and do_classifier_free_guidance:
+                # Infer ControlNet only for the conditional batch.
+                controlnet_latent_model_input = latents
+                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                print("well unexpected")
+
+            else:
+                controlnet_latent_model_input = latent_model_input
+                controlnet_prompt_embeds = prompt_embeds
+                print("a possiblity")
+
+            if h*w<=tile_size*tile_size: # tiled latent input
+                down_block_res_samples, mid_block_res_sample = [None]*10, None
+
+                print(f"controlnet 1 started with {controlnet_latent_model_input.shape}:{ram_encoder_hidden_states.shape}")
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    controlnet_latent_model_input,
+                    t,
+                    encoder_hidden_states=controlnet_prompt_embeds,
+                    controlnet_cond=image,
+                    conditioning_scale=conditioning_scale,
+                    guess_mode=guess_mode,
+                    return_dict=False,
+                    image_encoder_hidden_states = ram_encoder_hidden_states,
+                )
+
+
+                if guess_mode and do_classifier_free_guidance:
+                    # Infered ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                # predict the noise residual
+                print(f"unet started with {latent_model_input.shape}:{prompt_embeds.shape}")
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    return_dict=False,
+                    image_encoder_hidden_states = ram_encoder_hidden_states,
+                )[0]
+            else:
+                tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
+                tile_size = min(tile_size, min(h, w))
+                tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
+
+                grid_rows = 0
+                cur_x = 0
+                while cur_x < latent_model_input.size(-1):
+                    cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
+                    grid_rows += 1
+
+                grid_cols = 0
+                cur_y = 0
+                while cur_y < latent_model_input.size(-2):
+                    cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
+                    grid_cols += 1
+
+                input_list = []
+                cond_list = []
+                img_list = []
+                noise_preds = []
+                for row in range(grid_rows):
+                    noise_preds_row = []
+                    for col in range(grid_cols):
+                        if col < grid_cols-1 or row < grid_rows-1:
+                            # extract tile from input image
+                            ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                            ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                            # input tile area on total image
+                        if row == grid_rows-1:
+                            ofs_x = w - tile_size
+                        if col == grid_cols-1:
+                            ofs_y = h - tile_size
+
+                        input_start_x = ofs_x
+                        input_end_x = ofs_x + tile_size
+                        input_start_y = ofs_y
+                        input_end_y = ofs_y + tile_size
+
+                        # input tile dimensions
+                        input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                        input_list.append(input_tile)
+                        cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                        cond_list.append(cond_tile)
+                        img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
+                        img_list.append(img_tile)
+
+                        if len(input_list) == batch_size or col == grid_cols-1:
+                            input_list_t = torch.cat(input_list, dim=0)
+                            cond_list_t = torch.cat(cond_list, dim=0)
+                            img_list_t = torch.cat(img_list, dim=0)
+                            #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
+                            print(f"controlnet 2 started with {cond_list_t.shape}:{controlnet_prompt_embeds.shape}")
+
+                            down_block_res_samples, mid_block_res_sample = self.controlnet(
+                                cond_list_t,
+                                t,
+                                encoder_hidden_states=controlnet_prompt_embeds,
+                                controlnet_cond=img_list_t,
+                                conditioning_scale=conditioning_scale,
+                                guess_mode=guess_mode,
+                                return_dict=False,
+                                image_encoder_hidden_states = ram_encoder_hidden_states,
+                            )
+
+                            if guess_mode and do_classifier_free_guidance:
+                                # Infered ControlNet only for the conditional batch.
+                                # To apply the output of ControlNet to both the unconditional and conditional batches,
+                                # add 0 to the unconditional batch to keep it unchanged.
+                                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                            # predict the noise residual
+                            print(f"unet started with {input_list_t.shape}:{prompt_embeds.shape}")
+                            model_out = self.unet(
+                                input_list_t,
+                                t,
+                                encoder_hidden_states=prompt_embeds,
+                                cross_attention_kwargs=cross_attention_kwargs,
+                                down_block_additional_residuals=down_block_res_samples,
+                                mid_block_additional_residual=mid_block_res_sample,
+                                return_dict=False,
+                                image_encoder_hidden_states = ram_encoder_hidden_states,
+                            )[0]
+
+                            #for sample_i in range(model_out.size(0)):
+                            # noise_preds_row.append(model_out[sample_i].unsqueeze(0))
+                            input_list = []
+                            cond_list = []
+                            img_list = []
+
+                            noise_preds.append(model_out)
+
+                # Stitch noise predictions for all tiles
+                noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                # Add each tile contribution to overall latents
+                for row in range(grid_rows):
+                    for col in range(grid_cols):
+                        if col < grid_cols-1 or row < grid_rows-1:
+                            # extract tile from input image
+                            ofs_x = max(row * tile_size-tile_overlap * row, 0)
+                            ofs_y = max(col * tile_size-tile_overlap * col, 0)
+                            # input tile area on total image
+                        if row == grid_rows-1:
+                            ofs_x = w - tile_size
+                        if col == grid_cols-1:
+                            ofs_y = h - tile_size
+
+                        input_start_x = ofs_x
+                        input_end_x = ofs_x + tile_size
+                        input_start_y = ofs_y
+                        input_end_y = ofs_y + tile_size
+
+                        noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
+                        contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
+                # Average overlapping areas with more than 1 contributor
+                noise_pred /= contributors
+
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
         if use_KDS:
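For the tiled branch above, the while loops size the grid by stepping a tile of tile_size with stride tile_size - tile_overlap until the latent is covered; with assumed values tile_size=64, tile_overlap=16 and a 96x96 latent, that gives 2 tiles per axis (stride 48), and the last row/column is snapped to the far edge (ofs = 96 - 64 = 32). Each tile's noise prediction is then blended with a Gaussian weight map and the accumulated weights are divided out, which is what hides the seams. A self-contained sketch, with gaussian_weights standing in for the pipeline's _gaussian_weights:

    import torch

    def gaussian_weights(tile_w, tile_h):
        # 2D bump, strongest at the tile centre, positive everywhere
        xs = torch.linspace(-1.0, 1.0, tile_w)
        ys = torch.linspace(-1.0, 1.0, tile_h)
        wx = torch.exp(-0.5 * (xs / 0.5) ** 2)
        wy = torch.exp(-0.5 * (ys / 0.5) ** 2)
        return wy[:, None] * wx[None, :]          # shape (tile_h, tile_w)

    tile, h, w = 64, 96, 96                       # assumed latent sizes
    weights = gaussian_weights(tile, tile)
    noise_pred = torch.zeros(1, 4, h, w)
    contributors = torch.zeros_like(noise_pred)
    for oy in (0, h - tile):                      # per-axis offsets: 0 and 32
        for ox in (0, w - tile):
            noise_tile = torch.randn(1, 4, tile, tile)   # stand-in UNet output
            noise_pred[..., oy:oy + tile, ox:ox + tile] += noise_tile * weights
            contributors[..., oy:oy + tile, ox:ox + tile] += weights
    noise_pred /= contributors                    # overlaps become weighted averages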
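The guidance combine at the end of the step is the standard classifier-free guidance formula: the batch holds the unconditional and text-conditioned halves, and the final prediction moves the unconditional estimate toward the conditioned one by guidance_scale. Worked in isolation with assumed shapes:

    import torch

    guidance_scale = 5.5                        # assumed value
    noise_pred = torch.randn(2, 4, 64, 64)      # [uncond, text] halves of a CFG batch
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred.shape)                     # torch.Size([1, 4, 64, 64])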
@@ -1285,33 +1286,34 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         if callback is not None and i % callback_steps == 0:
             callback(i, t, latents)
 
-    # If we do sequential model offloading, let's offload unet and controlnet
-    # manually for max memory savings
-    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-        self.unet.to("cpu")
-        self.controlnet.to("cpu")
-        torch.cuda.empty_cache()
-
-    has_nsfw_concept = None
-    if not output_type == "latent":
-        image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
-        #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-    else:
-        image = latents.detach()
-        has_nsfw_concept = None
-
-    if has_nsfw_concept is None:
-        do_denormalize = [True] * image.shape[0]
-    else:
-        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
-    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-    # Offload last model to CPU
-    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-        self.final_offload_hook.offload()
-
-    if not return_dict:
-        return (image, has_nsfw_concept)
-
-    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+    with torch.no_grad():
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        has_nsfw_concept = None
+        if not output_type == "latent":
+            image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
+            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents.detach()
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
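The second hunk guards the finishing steps the same way. The decode divides the latents by the VAE's scaling_factor, undoing the scaling applied when they were produced, and image_processor.postprocess maps the [-1, 1] tensor to the requested output type. A hedged sketch of that tail end, assuming an already-constructed pipe with diffusers-style vae and image_processor attributes:

    import torch

    @torch.no_grad()
    def finish(pipe, latents, output_type="pil"):
        # Undo the latent scaling applied at encode time, then decode to pixels.
        image = pipe.vae.decode(
            latents.detach() / pipe.vae.config.scaling_factor,
            return_dict=False,
        )[0]
        # Map [-1, 1] tensors to PIL images / numpy arrays as requested.
        return pipe.image_processor.postprocess(image, output_type=output_type)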
 