Fabrice-TIERCELIN committed
Commit 8b838b9 · verified · 1 Parent(s): 2e62c1a

Update diffusers_helper/models/hunyuan_video_packed.py

diffusers_helper/models/hunyuan_video_packed.py CHANGED
@@ -122,21 +122,17 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq
         x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
         return x

-    B, L, H, C = q.shape
-
-    q = q.flatten(0, 1)
-    k = k.flatten(0, 1)
-    v = v.flatten(0, 1)
-
+    batch_size = q.shape[0]
+    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
+    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
+    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
     if sageattn_varlen is not None:
         x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     elif flash_attn_varlen_func is not None:
         x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
     else:
         raise NotImplementedError('No Attn Installed!')
-
-    x = x.unflatten(0, (B, L))
-
+    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
     return x


@@ -362,7 +358,7 @@ class HunyuanVideoIndividualTokenRefiner(nn.Module):
             batch_size = attention_mask.shape[0]
             seq_len = attention_mask.shape[1]
             attention_mask = attention_mask.to(hidden_states.device).bool()
-            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, -1, seq_len, -1)
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
             self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
             self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
             self_attn_mask[:, :, :, 0] = True

@@ -930,22 +926,23 @@ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterM
         encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
         encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)

-        if batch_size == 1:
-            # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
-            # If they are not same, then their impls are wrong. Ours are always the correct one.
-            text_len = encoder_attention_mask.sum().item()
-            encoder_hidden_states = encoder_hidden_states[:, :text_len]
-            attention_mask = None, None, None, None
-        else:
-            img_seq_len = hidden_states.shape[1]
-            txt_seq_len = encoder_hidden_states.shape[1]
-
-            cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
-            cu_seqlens_kv = cu_seqlens_q
-            max_seqlen_q = img_seq_len + txt_seq_len
-            max_seqlen_kv = max_seqlen_q
-
-            attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+        with torch.no_grad():
+            if batch_size == 1:
+                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
+                # If they are not same, then their impls are wrong. Ours are always the correct one.
+                text_len = encoder_attention_mask.sum().item()
+                encoder_hidden_states = encoder_hidden_states[:, :text_len]
+                attention_mask = None, None, None, None
+            else:
+                img_seq_len = hidden_states.shape[1]
+                txt_seq_len = encoder_hidden_states.shape[1]
+
+                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
+                cu_seqlens_kv = cu_seqlens_q
+                max_seqlen_q = img_seq_len + txt_seq_len
+                max_seqlen_kv = max_seqlen_q
+
+                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv

         if self.enable_teacache:
             modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
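
For reference, a minimal standalone sketch (not part of the commit) of why the reshaping changes above are value-preserving: an explicit view over a contiguous tensor matches flatten(0, 1) / unflatten(0, (B, L)), and repeat produces the same boolean mask as expand while materializing it in memory. All shapes and the toy attention_mask below are invented for illustration.

# Sketch only: toy shapes, B=2 sequences, L=16 tokens, H=4 heads, C=8 channels.
import torch

B, L, H, C = 2, 16, 4, 8
q = torch.randn(B, L, H, C)

# Packing for the var-len kernels: flatten(0, 1) and an explicit view over a
# contiguous tensor produce the same (B*L, H, C) layout.
packed_flatten = q.flatten(0, 1)
packed_view = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
assert torch.equal(packed_flatten, packed_view)

# Unpacking back to (B, L, H, C): unflatten(0, (B, L)) and view(B, L, ...)
# agree as long as the packed length is still B * L tokens.
x = packed_view
assert torch.equal(x.unflatten(0, (B, L)), x.view(B, L, *x.shape[1:]))

# Token-refiner mask: expand() returns a broadcast (non-materialized) view,
# repeat() copies memory; the boolean values are identical either way.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool()
batch_size, seq_len = attention_mask.shape
m_expand = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, -1, seq_len, -1)
m_repeat = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
assert torch.equal(m_expand, m_repeat)

The practical difference is that expand returns a broadcast view while repeat allocates a real copy, and view requires a contiguous tensor where flatten works on any layout.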