File size: 1,198 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Optional

import diffusers
import torch


def patch_time_text_image_embedding_forward() -> None:
    _patch_time_text_image_embedding_forward()


def _patch_time_text_image_embedding_forward() -> None:
    diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
        _patched_WanTimeTextImageEmbedding_forward
    )


def _patched_WanTimeTextImageEmbedding_forward(
    self,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
    # Some code has been removed compared to original implementation in Diffusers
    # Also, timestep is typed as that of encoder_hidden_states
    timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
    timestep_proj = self.time_proj(self.act_fn(temb))

    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image