Spaces:
Runtime error
Runtime error
Update gligen/ldm/models/diffusion/loss.py
Browse files
gligen/ldm/models/diffusion/loss.py
CHANGED
@@ -242,7 +242,7 @@ def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bbo
|
|
242 |
|
243 |
return total_loss/obj_number
|
244 |
|
245 |
-
def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att =
|
246 |
attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
|
247 |
# attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
|
248 |
# attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
|
@@ -290,7 +290,7 @@ def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, obje
|
|
290 |
# 选中物体对应位置(例如[6])的map,然后reshape到[4, 16, 16]
|
291 |
|
292 |
# print(attn_map[:, :, obj_position].shape)
|
293 |
-
ca_map_obj = attn_map[:, :, obj_position].
|
294 |
|
295 |
# print(ca_map_obj.shape)
|
296 |
if smooth_att:
|
|
|
242 |
|
243 |
return total_loss/obj_number
|
244 |
|
245 |
+
def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = False,sigma=0.5,kernel_size=3 ):
|
246 |
attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
|
247 |
# attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
|
248 |
# attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
|
|
|
290 |
# 选中物体对应位置(例如[6])的map,然后reshape到[4, 16, 16]
|
291 |
|
292 |
# print(attn_map[:, :, obj_position].shape)
|
293 |
+
ca_map_obj = attn_map[:, :, obj_position].mean(-1)
|
294 |
|
295 |
# print(ca_map_obj.shape)
|
296 |
if smooth_att:
|