from typing import Optional

import diffusers
import torch


def patch_time_text_image_embedding_forward() -> None:
    _patch_time_text_image_embedding_forward()


def _patch_time_text_image_embedding_forward() -> None:
    # Monkey-patch the class-level forward so every WanTimeTextImageEmbedding
    # instance uses the patched implementation below.
    diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
        _patched_WanTimeTextImageEmbedding_forward
    )


def _patched_WanTimeTextImageEmbedding_forward(
    self,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
    # Some code has been removed compared to the original Diffusers implementation;
    # in addition, the timestep projection is cast to the same dtype as
    # encoder_hidden_states.
    timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
    timestep_proj = self.time_proj(self.act_fn(temb))

    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
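

# A minimal usage sketch, not part of the original file: the assumption is that the
# patch is applied before the Wan transformer is constructed, so that every
# WanTimeTextImageEmbedding created afterwards uses the forward defined above.
# The checkpoint id below is only illustrative.
if __name__ == "__main__":
    from diffusers import WanTransformer3DModel
    from diffusers.models.transformers.transformer_wan import WanTimeTextImageEmbedding

    patch_time_text_image_embedding_forward()
    # After patching, the class-level forward points at the patched function.
    assert WanTimeTextImageEmbedding.forward is _patched_WanTimeTextImageEmbedding_forward

    transformer = WanTransformer3DModel.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative checkpoint id
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )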