Pusheen commited on
Commit
912b880
·
verified ·
1 Parent(s): 087de1a

Update gligen/ldm/models/diffusion/loss.py

Browse files
gligen/ldm/models/diffusion/loss.py CHANGED
@@ -242,7 +242,7 @@ def caculate_loss_att_fixed_cnt(attn_maps_mid, attn_maps_up, attn_maps_down, bbo
242
 
243
  return total_loss/obj_number
244
 
245
- def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ):
246
  attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
247
  # attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
248
  # attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
@@ -290,7 +290,7 @@ def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, obje
290
  # 选中物体对应位置(例如[6])的map,然后reshape到[4, 16, 16]
291
 
292
  # print(attn_map[:, :, obj_position].shape)
293
- ca_map_obj = attn_map[:, :, obj_position].sum(-1)
294
 
295
  # print(ca_map_obj.shape)
296
  if smooth_att:
 
242
 
243
  return total_loss/obj_number
244
 
245
+ def caculate_loss_LoCo(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, t, res=16, smooth_att = False,sigma=0.5,kernel_size=3 ):
246
  attn16 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
247
  # attn32 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 32)
248
  # attn64 = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, 64)
 
290
  # 选中物体对应位置(例如[6])的map,然后reshape到[4, 16, 16]
291
 
292
  # print(attn_map[:, :, obj_position].shape)
293
+ ca_map_obj = attn_map[:, :, obj_position].mean(-1)
294
 
295
  # print(ca_map_obj.shape)
296
  if smooth_att: