Update pipelines/pipeline_seesr.py

pipelines/pipeline_seesr.py  (+204 -202)  CHANGED

@@ -1047,187 +1047,188 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
Before (old lines 1047-1233):

print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")

for i, t in enumerate(timesteps):
    # … old lines 1050-1054 (removed; their text did not survive the page extraction) …

    # expand the latents if we are doing classifier free guidance
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

    # controlnet(s) inference
    if guess_mode and do_classifier_free_guidance:
        # Infer ControlNet only for the conditional batch.
        controlnet_latent_model_input = latents
        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
        print("well unexpected")

    else:
        controlnet_latent_model_input = latent_model_input
        controlnet_prompt_embeds = prompt_embeds
        print("a possiblity")

    if h*w<=tile_size*tile_size: # tiled latent input
        down_block_res_samples, mid_block_res_sample = [None]*10, None

        print(f"controlnet 1 started with {controlnet_latent_model_input.shape}:{ram_encoder_hidden_states.shape}")

        down_block_res_samples, mid_block_res_sample = self.controlnet(
            controlnet_latent_model_input,
            t,
            encoder_hidden_states=controlnet_prompt_embeds,
            controlnet_cond=image,
            conditioning_scale=conditioning_scale,
            guess_mode=guess_mode,
            return_dict=False,
            image_encoder_hidden_states = ram_encoder_hidden_states,
        )

        if guess_mode and do_classifier_free_guidance:
            # … old lines 1089-1230 (removed; their text did not survive the page extraction) …

if use_KDS:
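The removed loop above duplicates the latent batch for classifier-free guidance and later recombines the two halves of the noise prediction with guidance_scale. A minimal, self-contained sketch of that pattern, not the pipeline's code; the random stand-in for the UNet output and the guidance_scale value are illustrative assumptions:

import torch

# Duplicate the latent batch so one forward pass yields both the unconditional
# and the conditional noise prediction, then recombine the two halves.
latents = torch.randn(1, 4, 8, 8)
guidance_scale = 5.5  # illustrative value, not taken from this commit

latent_model_input = torch.cat([latents] * 2)             # (2, 4, 8, 8)
noise_pred = torch.randn_like(latent_model_input)         # stand-in for the UNet output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # split back into the two halves
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)                                   # torch.Size([1, 4, 8, 8])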
@@ -1285,33 +1286,34 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
Before (old lines 1285-1317):

        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

# … old lines 1288-1295 (removed; their text did not survive the page extraction) …

if not output_type == "latent":
    image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
    #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
    image = latents.detach()
has_nsfw_concept = None

# … old lines 1302-1317 (removed; their text did not survive the page extraction) …
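The updated code for both hunks follows. In the first hunk, each denoising step is wrapped in torch.no_grad() and any timestep above start_steps is skipped. A standalone illustration of that skip pattern, assuming the pipeline's convention of descending integer timesteps; start_steps = 700 is an assumed value, not one taken from this commit:

import torch

# Denoising only begins once the (descending) timestep drops to start_steps,
# so earlier, noisier steps are passed over.
timesteps = torch.arange(999, -1, -20)   # descending timesteps: 999, 979, ..., 19
start_steps = 700

for i, t in enumerate(timesteps):
    if t > start_steps:
        continue  # skip this step entirely
    # ... ControlNet + UNet prediction and the scheduler update would run here ...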
After (new lines 1047-1234, first hunk):

print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")

for i, t in enumerate(timesteps):
    with torch.no_grad():
        # pass, if the timestep is larger than start_steps
        if t > start_steps:
            print(f'pass {t} steps.')
            continue

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # controlnet(s) inference
        if guess_mode and do_classifier_free_guidance:
            # Infer ControlNet only for the conditional batch.
            controlnet_latent_model_input = latents
            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
            print("well unexpected")

        else:
            controlnet_latent_model_input = latent_model_input
            controlnet_prompt_embeds = prompt_embeds
            print("a possiblity")

        if h*w<=tile_size*tile_size: # tiled latent input
            down_block_res_samples, mid_block_res_sample = [None]*10, None

            print(f"controlnet 1 started with {controlnet_latent_model_input.shape}:{ram_encoder_hidden_states.shape}")

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                controlnet_latent_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=conditioning_scale,
                guess_mode=guess_mode,
                return_dict=False,
                image_encoder_hidden_states = ram_encoder_hidden_states,
            )

            if guess_mode and do_classifier_free_guidance:
                # Infered ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            print(f"unet started with {latent_model_input.shape}:{prompt_embeds.shape}")
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
                image_encoder_hidden_states = ram_encoder_hidden_states,
            )[0]
        else:
            tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)
            tile_size = min(tile_size, min(h, w))
            tile_weights = self._gaussian_weights(tile_size, tile_size, batch_size)

            grid_rows = 0
            cur_x = 0
            while cur_x < latent_model_input.size(-1):
                cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
                grid_rows += 1

            grid_cols = 0
            cur_y = 0
            while cur_y < latent_model_input.size(-2):
                cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
                grid_cols += 1

            input_list = []
            cond_list = []
            img_list = []
            noise_preds = []
            for row in range(grid_rows):
                noise_preds_row = []
                for col in range(grid_cols):
                    if col < grid_cols-1 or row < grid_rows-1:
                        # extract tile from input image
                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
                    # input tile area on total image
                    if row == grid_rows-1:
                        ofs_x = w - tile_size
                    if col == grid_cols-1:
                        ofs_y = h - tile_size

                    input_start_x = ofs_x
                    input_end_x = ofs_x + tile_size
                    input_start_y = ofs_y
                    input_end_y = ofs_y + tile_size

                    # input tile dimensions
                    input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
                    input_list.append(input_tile)
                    cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
                    cond_list.append(cond_tile)
                    img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
                    img_list.append(img_tile)

                    if len(input_list) == batch_size or col == grid_cols-1:
                        input_list_t = torch.cat(input_list, dim=0)
                        cond_list_t = torch.cat(cond_list, dim=0)
                        img_list_t = torch.cat(img_list, dim=0)
                        #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
                        print(f"controlnet 2 started with {cond_list_t.shape}:{controlnet_prompt_embeds.shape}")

                        down_block_res_samples, mid_block_res_sample = self.controlnet(
                            cond_list_t,
                            t,
                            encoder_hidden_states=controlnet_prompt_embeds,
                            controlnet_cond=img_list_t,
                            conditioning_scale=conditioning_scale,
                            guess_mode=guess_mode,
                            return_dict=False,
                            image_encoder_hidden_states = ram_encoder_hidden_states,
                        )

                        if guess_mode and do_classifier_free_guidance:
                            # Infered ControlNet only for the conditional batch.
                            # To apply the output of ControlNet to both the unconditional and conditional batches,
                            # add 0 to the unconditional batch to keep it unchanged.
                            down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                            mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                        # predict the noise residual
                        print(f"unet started with {input_list_t.shape}:{prompt_embeds.shape}")
                        model_out = self.unet(
                            input_list_t,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            down_block_additional_residuals=down_block_res_samples,
                            mid_block_additional_residual=mid_block_res_sample,
                            return_dict=False,
                            image_encoder_hidden_states = ram_encoder_hidden_states,
                        )[0]

                        #for sample_i in range(model_out.size(0)):
                        #    noise_preds_row.append(model_out[sample_i].unsqueeze(0))
                        input_list = []
                        cond_list = []
                        img_list = []

                        noise_preds.append(model_out)

            # Stitch noise predictions for all tiles
            noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
            contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
            # Add each tile contribution to overall latents
            for row in range(grid_rows):
                for col in range(grid_cols):
                    if col < grid_cols-1 or row < grid_rows-1:
                        # extract tile from input image
                        ofs_x = max(row * tile_size-tile_overlap * row, 0)
                        ofs_y = max(col * tile_size-tile_overlap * col, 0)
                    # input tile area on total image
                    if row == grid_rows-1:
                        ofs_x = w - tile_size
                    if col == grid_cols-1:
                        ofs_y = h - tile_size

                    input_start_x = ofs_x
                    input_end_x = ofs_x + tile_size
                    input_start_y = ofs_y
                    input_end_y = ofs_y + tile_size

                    noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
                    contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
            # Average overlapping areas with more than 1 contributor
            noise_pred /= contributors

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    if use_KDS:
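In the tiled branch above, per-tile noise predictions are blended with Gaussian weights so that overlapping regions average smoothly instead of leaving seams. A minimal, self-contained sketch of that weighting-and-stitching idea; the gaussian_weights helper below is a simplified stand-in for the pipeline's _gaussian_weights, and the tile/latent sizes are illustrative:

import math
import torch

def gaussian_weights(tile_w: int, tile_h: int) -> torch.Tensor:
    # Simplified stand-in for the pipeline's _gaussian_weights: weights peak at
    # the tile centre and fall off toward the edges.
    var = 0.1
    xs = [math.exp(-((x - tile_w / 2) ** 2) / (2 * tile_w * tile_w * var)) for x in range(tile_w)]
    ys = [math.exp(-((y - tile_h / 2) ** 2) / (2 * tile_h * tile_h * var)) for y in range(tile_h)]
    return torch.tensor(ys).unsqueeze(1) * torch.tensor(xs).unsqueeze(0)  # (tile_h, tile_w)

# Toy stitching example: two horizontally overlapping 8x8 tiles of a 1x1x8x12 "latent".
tile = 8
noise_pred = torch.zeros(1, 1, 8, 12)
contributors = torch.zeros_like(noise_pred)
tile_weights = gaussian_weights(tile, tile)

tile_preds = {0: torch.zeros(1, 1, tile, tile), 4: torch.ones(1, 1, tile, tile)}
for ofs_x, pred in tile_preds.items():
    noise_pred[:, :, :, ofs_x:ofs_x + tile] += pred * tile_weights
    contributors[:, :, :, ofs_x:ofs_x + tile] += tile_weights

noise_pred /= contributors   # weighted average, as in the stitching loop above
print(noise_pred[0, 0, 0])   # columns ramp smoothly from 0 to 1 across the overlap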
After (new lines 1286-1319, second hunk):

        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

with torch.no_grad():
    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    has_nsfw_concept = None
    if not output_type == "latent":
        image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]#.flip(1)
        #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents.detach()
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
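The decode path above hands the result to self.image_processor.postprocess with a per-sample do_denormalize list. A small sketch of that step, assuming the pipeline's image_processor is diffusers' VaeImageProcessor as in the stock ControlNet pipeline; the random tensor stands in for the VAE decoder output and is assumed to lie in [-1, 1]:

import torch
from diffusers.image_processor import VaeImageProcessor

# do_denormalize controls which samples are mapped back from [-1, 1] to [0, 1]
# before conversion to the requested output_type.
image_processor = VaeImageProcessor()
decoded = torch.rand(1, 3, 64, 64) * 2 - 1            # stand-in for vae.decode(...) output
pil_images = image_processor.postprocess(decoded, output_type="pil", do_denormalize=[True])
print(pil_images[0].size)                             # (64, 64)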