Spaces: Running on Zero
<fix> add tp ip attn enhancement control in forward.
app.py
CHANGED
@@ -264,6 +264,7 @@ def process_image_and_text(condition_image, target_prompt, condition_image_promp
         frame_gap=48,
         mixup=True,
         mixup_num_imgs=2,
+        enhance_tp=task in ['subject_driven', 'style_transfer'],
     ).frames

     gen_img = gen_img[:, 0:1, :, :, :]
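The app-side change only toggles the new behaviour for the two reference-conditioned tasks. Below is a minimal sketch of such a task-gated call; `pipe` and the `image`/`prompt` keyword names are assumptions for illustration, while `condition_image`, `target_prompt`, `task`, and the remaining keyword arguments are the ones visible in the diff.

# Sketch only: how the task-gated call in app.py could wire enhance_tp into the pipeline.
# `pipe` and the `image=` / `prompt=` keyword names are assumed, not taken from the diff.
gen_img = pipe(
    image=condition_image,      # assumed keyword for the conditioning image
    prompt=target_prompt,       # assumed keyword for the text prompt
    frame_gap=48,
    mixup=True,
    mixup_num_imgs=2,
    # Only the reference-conditioned tasks request the key enhancement.
    enhance_tp=task in ['subject_driven', 'style_transfer'],
).frames
gen_img = gen_img[:, 0:1, :, :, :]  # keep only the first frame of the generated clip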
models/hyvideo/transformer_hunyuan_video_i2v.py
CHANGED
@@ -64,6 +64,7 @@ class HunyuanVideoAttnProcessor2_0:
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         if attn.add_q_proj is None and encoder_hidden_states is not None:
             hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

@@ -154,7 +155,7 @@ class HunyuanVideoAttnProcessor2_0:
         k_lens = torch.tensor([sum([u[seg_start[seg]:seg_end[seg]].long().sum().item() for seg in segs]) for u in valid_indices for segs in k_segs],
                               dtype=torch.int32, device=valid_indices.device)
         query = torch.cat([u[i:j][v[i:j]] for u,v in zip(query, valid_indices) for i,j in zip(seg_start, seg_end)], dim=0)
-        if self.inference_subject_driven:
+        if self.inference_subject_driven or enhance_tp:
             key = torch.cat([torch.cat([ torch.cat([u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][:144], u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:] + 0.6 * u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]][144:].abs().mean()], dim=0) if segs == [0, 1, 2] and seg == 2 else u[seg_start[seg]:seg_end[seg]][v[seg_start[seg]:seg_end[seg]]] for seg in segs], dim=0) \
                             for u,v in zip(key, valid_indices) for segs in k_segs], dim=0)
         else:

@@ -756,6 +757,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

@@ -777,6 +779,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=image_rotary_emb,
+            enhance_tp=enhance_tp,
         )
         attn_output = torch.cat([attn_output, context_attn_output], dim=1)

@@ -841,6 +844,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         token_replace_emb: torch.Tensor = None,
         num_tokens: int = None,
+        enhance_tp: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # 1. Input normalization
         (

@@ -864,6 +868,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
             encoder_hidden_states=norm_encoder_hidden_states,
             attention_mask=attention_mask,
             image_rotary_emb=freqs_cis,
+            enhance_tp=enhance_tp,
         )

         # 3. Modulation and residual connection

@@ -1109,6 +1114,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         frame_gap: Union[int, None] = None,
+        enhance_tp: bool = False,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()

@@ -1181,6 +1187,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

             for block in self.single_transformer_blocks:

@@ -1193,6 +1200,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

         else:

@@ -1205,6 +1213,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

             for block in self.single_transformer_blocks:

@@ -1216,6 +1225,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                     image_rotary_emb,
                     token_replace_emb,
                     first_frame_num_tokens,
+                    enhance_tp,
                 )

         # 5. Output projection
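The transformer change is plumbing plus one relaxed condition: `enhance_tp` is accepted by `HunyuanVideoTransformer3DModel.forward`, threaded through both the dual-stream and single-stream token-replace blocks (in the gradient-checkpointing and regular paths alike), and finally into `HunyuanVideoAttnProcessor2_0`, where it now also activates the key-scaling branch that was previously gated only by `self.inference_subject_driven`. In that branch, when a query group attends over all three key segments (`segs == [0, 1, 2]`), the keys of the third segment beyond the first 144 tokens are offset by 0.6 times their mean absolute value. A standalone sketch of just that operation follows; the helper name and tensor shapes are illustrative and do not exist in the repository.

import torch

def enhance_segment_keys(key_seg: torch.Tensor, keep: int = 144, scale: float = 0.6) -> torch.Tensor:
    # Leave the first `keep` key tokens untouched and add scale * mean(|k|) to the rest,
    # mirroring the branch taken when `self.inference_subject_driven or enhance_tp` is true.
    head, tail = key_seg[:keep], key_seg[keep:]
    return torch.cat([head, tail + scale * tail.abs().mean()], dim=0)

# Illustrative shapes: 200 key tokens with 128 channels each.
k = torch.randn(200, 128)
k_boosted = enhance_segment_keys(k)
assert k_boosted.shape == k.shape

The offset perturbs the keys of that one segment relative to the others; the commit's contribution is exposing whether it is applied as a per-call switch instead of a processor-level setting.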
pipelines/pipeline_hunyuan_video_i2v.py
CHANGED
@@ -649,6 +649,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
         frame_gap: Union[int, None] = None,
         mixup: bool = False,
         mixup_num_imgs: Union[int, None] = None,
+        enhance_tp: bool = False,
     ):
         r"""
         The call function to the pipeline for generation.

@@ -899,6 +900,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                     frame_gap=int(frame_gap / 4) if frame_gap is not None else frame_gap,
+                    enhance_tp=enhance_tp,
                 )[0]

                 if do_true_cfg: