from typing import Callable
import torch
import torch.nn as nn
import math


class ModulateDiT(nn.Module):
    """Modulation layer for DiT."""

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **factory_kwargs
        )
        # Zero-initialize the modulation so the block starts out as an identity mapping.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None):
        x_out = self.linear(self.act(x))
        if condition_type == "token_replace":
            # Also compute modulation parameters for the replaced (first-frame) tokens.
            x_token_replace_out = self.linear(self.act(token_replace_vec))
            return x_out, x_token_replace_out
        else:
            return x_out
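

# Illustrative usage sketch (not part of the original module). In a typical DiT
# block, a ModulateDiT with factor=6 maps the conditioning vector to shift/scale/gate
# pairs for the attention and MLP branches; the factor value and the chunking below
# are assumptions for demonstration only.
def _example_modulate_dit_usage():
    hidden_size = 64
    mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
    vec = torch.randn(2, hidden_size)  # e.g. a timestep/conditioning embedding
    # Split the (B, 6 * hidden_size) output into six (B, hidden_size) chunks.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod(vec).chunk(6, dim=-1)
    return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp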


def modulate(x, shift=None, scale=None, condition_type=None,
             tr_shift=None, tr_scale=None,
             frist_frame_token_num=None):
    """Modulate ``x`` by shift and scale.

    Args:
        x (torch.Tensor): input tensor of shape (B, L, D).
        shift (torch.Tensor, optional): shift tensor. Defaults to None.
        scale (torch.Tensor, optional): scale tensor. Defaults to None.
        condition_type (str, optional): if "token_replace", the first
            ``frist_frame_token_num`` tokens are modulated with ``tr_shift`` and
            ``tr_scale`` instead. Defaults to None.
        tr_shift (torch.Tensor, optional): shift tensor for the replaced tokens.
        tr_scale (torch.Tensor, optional): scale tensor for the replaced tokens.
        frist_frame_token_num (int, optional): number of first-frame tokens.

    Returns:
        torch.Tensor: the output tensor after modulation.
    """
    if condition_type == "token_replace":
        # First-frame tokens use their own shift/scale; the remaining tokens use the regular ones.
        x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
        x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return torch.concat((x_zero, x_orig), dim=1)
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
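

# Illustrative shape sketch (not part of the original module): `modulate` broadcasts
# per-sample (B, D) shift/scale tensors over the token dimension of a (B, L, D)
# input via unsqueeze(1). The shapes below are arbitrary examples.
def _example_modulate_shapes():
    x = torch.randn(2, 16, 64)  # (batch, tokens, hidden)
    shift = torch.zeros(2, 64)
    scale = torch.zeros(2, 64)
    out = modulate(x, shift=shift, scale=scale)
    assert out.shape == x.shape
    assert torch.allclose(out, x)  # zero shift/scale is the identity
    return out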


def modulate_(x, shift=None, scale=None):
    """In-place variant of :func:`modulate`: the result is written into ``x``."""
    if scale is None and shift is None:
        return x
    elif shift is None:
        scale = scale + 1
        scale = scale.unsqueeze(1)
        return x.mul_(scale)
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        # Equivalent to: x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        scale = scale + 1
        scale = scale.unsqueeze(1)
        torch.addcmul(shift.unsqueeze(1), x, scale, out=x)
        return x
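

# Illustrative equivalence sketch (not part of the original module): modulate_
# writes the result into `x` via addcmul and should match the out-of-place
# `modulate` numerically.
def _example_modulate_inplace_equivalence():
    x = torch.randn(2, 16, 64)
    shift = torch.randn(2, 64)
    scale = torch.randn(2, 64)
    expected = modulate(x, shift=shift, scale=scale)
    out = modulate_(x.clone(), shift=shift, scale=scale)
    assert torch.allclose(out, expected)
    return out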


def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
    """Apply a broadcast gate to ``x``.

    Args:
        x (torch.Tensor): input tensor of shape (B, L, D).
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to pass the gate through tanh. Defaults to False.
        condition_type (str, optional): if "token_replace", the first
            ``frist_frame_token_num`` tokens are gated with ``tr_gate`` instead.
            Defaults to None.
        tr_gate (torch.Tensor, optional): gate tensor for the replaced tokens.
        frist_frame_token_num (int, optional): number of first-frame tokens.

    Returns:
        torch.Tensor: the output tensor after applying the gate.
    """
    if condition_type == "token_replace":
        if gate is None:
            return x
        if tanh:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh()
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh()
            return torch.concat((x_zero, x_orig), dim=1)
        else:
            x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1)
            x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1)
            return torch.concat((x_zero, x_orig), dim=1)
    else:
        if gate is None:
            return x
        if tanh:
            return x * gate.unsqueeze(1).tanh()
        else:
            return x * gate.unsqueeze(1)
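

# Illustrative usage sketch (not part of the original module): apply_gate is the
# gated residual update of a DiT block, e.g. x = x + apply_gate(attn_out, gate),
# with the (B, D) gate broadcast over the token dimension. Shapes are examples only.
def _example_apply_gate():
    x = torch.randn(2, 16, 64)
    attn_out = torch.randn(2, 16, 64)
    gate = torch.zeros(2, 64)  # a zero gate silences the residual branch
    out = x + apply_gate(attn_out, gate=gate)
    assert torch.allclose(out, x)
    return out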


def apply_gate_and_accumulate_(accumulator, x, gate=None, tanh=False):
    """Gate ``x`` and add it to ``accumulator`` in place (fused residual update)."""
    if gate is None:
        return accumulator
    if tanh:
        return accumulator.addcmul_(x, gate.unsqueeze(1).tanh())
    else:
        return accumulator.addcmul_(x, gate.unsqueeze(1))
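

# Illustrative equivalence sketch (not part of the original module): the in-place
# accumulate variant fuses the residual add and should match
# `accumulator + apply_gate(x, gate=gate)`.
def _example_apply_gate_and_accumulate():
    accumulator = torch.randn(2, 16, 64)
    x = torch.randn(2, 16, 64)
    gate = torch.randn(2, 64)
    expected = accumulator + apply_gate(x, gate=gate)
    out = apply_gate_and_accumulate_(accumulator.clone(), x, gate=gate)
    assert torch.allclose(out, expected)
    return out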


def ckpt_wrapper(module):
    """Wrap ``module`` so it can be passed to ``torch.utils.checkpoint.checkpoint``."""

    def ckpt_forward(*inputs):
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
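

# Illustrative usage sketch (not part of the original module): ckpt_wrapper is meant
# to be paired with torch.utils.checkpoint so a block's activations are recomputed
# during the backward pass. The nn.Linear block below is a stand-in, not part of
# this file.
def _example_ckpt_wrapper_usage():
    from torch.utils.checkpoint import checkpoint

    block = nn.Linear(64, 64)
    x = torch.randn(2, 64, requires_grad=True)
    # Recompute block(x) in backward instead of storing its activations.
    out = checkpoint(ckpt_wrapper(block), x, use_reentrant=False)
    out.sum().backward()
    return x.grad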