import math
from typing import Callable

import torch
import torch.nn as nn

class ModulateDiT(nn.Module):
    """Modulation layer for DiT."""
    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **factory_kwargs
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor:
        x_out = self.linear(self.act(x))

        if condition_type == "token_replace":
            # Also produce modulation parameters for the replaced first-frame tokens.
            x_token_replace_out = self.linear(self.act(token_replace_vec))
            return x_out, x_token_replace_out
        else:
            return x_out

def modulate(x, shift=None, scale=None, condition_type=None,
             tr_shift=None, tr_scale=None,
             frist_frame_token_num=None):
    """Modulate x by shift and scale.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.
        condition_type (str, optional): if "token_replace", the first
            `frist_frame_token_num` tokens are modulated with `tr_shift` and
            `tr_scale` instead. Defaults to None.
        tr_shift (torch.Tensor, optional): shift tensor for the first-frame tokens.
        tr_scale (torch.Tensor, optional): scale tensor for the first-frame tokens.
        frist_frame_token_num (int, optional): number of first-frame tokens.

    Returns:
        torch.Tensor: the output tensor after modulation.
    """
    if condition_type == "token_replace":
        x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return torch.concat((x_zero, x_orig), dim=1)
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def modulate_(x, shift=None, scale=None):
    """In-place variant of `modulate` (the shift-only branch still returns a new tensor)."""
    if scale is None and shift is None:
        return x
    elif shift is None:
        scale = (scale + 1).unsqueeze(1)
        return x.mul_(scale)
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        scale = (scale + 1).unsqueeze(1)
        # x * (1 + scale) + shift, written into x.
        torch.addcmul(shift.unsqueeze(1), x, scale, out=x)
        return x
    
def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
    """Apply a gate to x.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to pass the gate through tanh first. Defaults to False.
        condition_type (str, optional): if "token_replace", the first
            `frist_frame_token_num` tokens are gated with `tr_gate` instead. Defaults to None.
        tr_gate (torch.Tensor, optional): gate tensor for the first-frame tokens.
        frist_frame_token_num (int, optional): number of first-frame tokens.

    Returns:
        torch.Tensor: the output tensor after applying the gate.
    """
    if condition_type == "token_replace":
        if gate is None:
            return x
        if tanh:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh()
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh()
            x = torch.concat((x_zero, x_orig), dim=1)
            return x
        else:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1)
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1)
            x = torch.concat((x_zero, x_orig), dim=1)
            return x
    else:
        if gate is None:
            return x
        if tanh:
            return x * gate.unsqueeze(1).tanh()
        else:
            return x * gate.unsqueeze(1)
        
def apply_gate_and_accumulate_(accumulator, x, gate=None, tanh=False):
    """In-place gated accumulation: accumulator += x * gate (with tanh applied to the gate if requested)."""
    if gate is None:
        return accumulator
    if tanh:
        return accumulator.addcmul_(x, gate.unsqueeze(1).tanh())
    else:
        return accumulator.addcmul_(x, gate.unsqueeze(1))
    
def ckpt_wrapper(module):
    """Wrap `module`'s forward as a plain function (e.g. for use with torch.utils.checkpoint.checkpoint)."""
    def ckpt_forward(*inputs):
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
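
if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API above: it assumes the usual
    # adaLN-style convention in which ModulateDiT(hidden_size, factor=3) emits
    # concatenated (shift, scale, gate) parameters from a conditioning vector.
    # All shapes below are illustrative only.
    torch.manual_seed(0)
    batch, seq_len, hidden = 2, 16, 64
    cond = torch.randn(batch, hidden)        # e.g. a timestep/text embedding
    x = torch.randn(batch, seq_len, hidden)  # token sequence

    mod = ModulateDiT(hidden_size=hidden, factor=3, act_layer=nn.SiLU)

    with torch.no_grad():
        shift, scale, gate = mod(cond).chunk(3, dim=-1)

        # Out-of-place modulation followed by gating.
        y = apply_gate(modulate(x, shift=shift, scale=scale), gate=gate)

        # In-place counterparts.
        x_inplace = x.clone()
        modulate_(x_inplace, shift=shift, scale=scale)
        acc = torch.zeros_like(x)
        apply_gate_and_accumulate_(acc, x_inplace, gate=gate)

    # Because the modulation linear layer is zero-initialized, shift/scale/gate
    # are all zero here: modulation is the identity and the gate zeroes the output.
    print(torch.allclose(x_inplace, x))                 # True
    print(torch.allclose(y, torch.zeros_like(y)))       # True
    print(torch.allclose(acc, torch.zeros_like(acc)))   # True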