Spaces:
Runtime error
Runtime error
Commit
·
798ede0
1
Parent(s):
77ade98
Update utils.py
Browse files
utils.py
CHANGED
@@ -44,7 +44,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
|
|
44 |
print (padding_mask)
|
45 |
padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
|
46 |
print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
|
47 |
-
c['c_concat'] = c['c_concat'] * (~padding_mask) # Zero out the corresponding features
|
48 |
|
49 |
if pos_map is not None:
|
50 |
print (pos_map.shape, c['c_concat'].shape)
|
|
|
44 |
print (padding_mask)
|
45 |
padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
|
46 |
print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
|
47 |
+
c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
|
48 |
|
49 |
if pos_map is not None:
|
50 |
print (pos_map.shape, c['c_concat'].shape)
|