Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -1122,25 +1122,12 @@ class StreamMultiDiffusion(nn.Module):
|
|
| 1122 |
else:
|
| 1123 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
| 1124 |
|
| 1125 |
-
ns = []
|
| 1126 |
-
c1, c2, c3 = 0, 0, 0
|
| 1127 |
-
for n, p in self.unet.named_parameters():
|
| 1128 |
-
if p.data.dtype == torch.float:
|
| 1129 |
-
c1 += 1
|
| 1130 |
-
ns.append(n)
|
| 1131 |
-
elif p.data.dtype == torch.half:
|
| 1132 |
-
c2 += 1
|
| 1133 |
-
else:
|
| 1134 |
-
c3 += 1
|
| 1135 |
-
print(c1, c2, c3)
|
| 1136 |
-
print(ns)
|
| 1137 |
model_pred = self.unet(
|
| 1138 |
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
| 1139 |
t_list, # (B,)
|
| 1140 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
| 1141 |
return_dict=False,
|
| 1142 |
)[0] # (B, 4, h, w)
|
| 1143 |
-
print('222222222222222', model_pred.dtype)
|
| 1144 |
|
| 1145 |
if self.bootstrap_steps[0] > 0:
|
| 1146 |
# Uncentering.
|
|
@@ -1151,6 +1138,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
| 1151 |
bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
|
| 1152 |
else:
|
| 1153 |
bootstrap_mask_ = bootstrap_mask
|
|
|
|
| 1154 |
model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
|
| 1155 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
|
| 1156 |
|
|
@@ -1235,7 +1223,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
| 1235 |
self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
|
| 1236 |
|
| 1237 |
x_0_pred_batch = self.unet_step(latent)
|
| 1238 |
-
|
| 1239 |
latent = x_0_pred_batch[-1:]
|
| 1240 |
self.x_t_latent_buffer = (
|
| 1241 |
self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
|
|
|
|
| 1122 |
else:
|
| 1123 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
| 1124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1125 |
model_pred = self.unet(
|
| 1126 |
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
| 1127 |
t_list, # (B,)
|
| 1128 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
| 1129 |
return_dict=False,
|
| 1130 |
)[0] # (B, 4, h, w)
|
|
|
|
| 1131 |
|
| 1132 |
if self.bootstrap_steps[0] > 0:
|
| 1133 |
# Uncentering.
|
|
|
|
| 1138 |
bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
|
| 1139 |
else:
|
| 1140 |
bootstrap_mask_ = bootstrap_mask
|
| 1141 |
+
print('2222222222222222222222222222222222222', model_pred.shape, bootstrap_mask_)
|
| 1142 |
model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
|
| 1143 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
|
| 1144 |
|
|
|
|
| 1223 |
self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
|
| 1224 |
|
| 1225 |
x_0_pred_batch = self.unet_step(latent)
|
| 1226 |
+
print('111111111111111111111111111111111')
|
| 1227 |
latent = x_0_pred_batch[-1:]
|
| 1228 |
self.x_t_latent_buffer = (
|
| 1229 |
self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
|