from typing import Callable

import torch
import torch.nn as nn
|
|
class ModulateDiT(nn.Module):
    """Projects a conditioning vector to `factor` modulation vectors.

    The projection is zero-initialized (the adaLN-Zero scheme from DiT), so the
    modulation output starts at exactly zero and each block begins training as
    an identity mapping.
    """

    def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        # A single linear layer emits all `factor` modulation vectors,
        # concatenated along the last dimension.
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)

        # Zero-init weight and bias so the initial modulation is exactly zero.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.act(x))
|
|
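# Usage sketch (illustrative, not part of the original module): with factor=6,
# ModulateDiT yields the six adaLN-Zero parameters of a typical DiT block.
# The names and the factor=6 split below are assumptions.
def _demo_modulate_dit() -> None:
    mod = ModulateDiT(hidden_size=64, factor=6, act_layer=nn.SiLU)
    cond = torch.randn(2, 64)  # per-sample conditioning vector
    # Zero-initialized projection: every modulation parameter starts at zero.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod(cond).chunk(6, dim=-1)
    assert shift_msa.shape == (2, 64)
    assert gate_mlp.abs().sum().item() == 0.0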
|
|
def modulate(x, shift=None, scale=None):
    """Apply FiLM-style modulation: ``x * (1 + scale) + shift``.

    `shift` / `scale` may be 2-D (batch, dim) while `x` is 3-D (batch, seq, dim);
    2-D parameters are unsqueezed to broadcast over the sequence dimension.
    Either parameter may be None, in which case that term is skipped.
    """
    if x.ndim == 3:
        # Unsqueeze only 2-D params; leave already-broadcastable ones untouched.
        shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else shift
        scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else scale
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale)
    elif scale is None:
        return x + shift
    else:
        return x * (1 + scale) + shift
|
|
def apply_gate(x, gate=None, tanh=False):
    """Scale `x` by a (optionally tanh-squashed) gate; a None gate is a no-op.

    A 2-D (batch, dim) gate is unsqueezed to broadcast over a 3-D
    (batch, seq, dim) input.
    """
    if gate is None:
        return x
    if gate.ndim == 2 and x.ndim == 3:
        gate = gate.unsqueeze(1)
    if tanh:
        return x * gate.tanh()
    return x * gate
|
|
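# Usage sketch (illustrative): how modulate() and apply_gate() typically
# compose in a DiT-style block, with 2-D (batch, dim) parameters broadcast
# over a 3-D (batch, seq, dim) activation stream.
def _demo_modulation() -> None:
    b, s, d = 2, 16, 64
    x = torch.randn(b, s, d)
    shift, scale, gate = torch.randn(b, d), torch.randn(b, d), torch.randn(b, d)
    h = modulate(x, shift=shift, scale=scale)  # x * (1 + scale) + shift, broadcast over seq
    out = x + apply_gate(h, gate=gate)  # gated residual branch
    assert out.shape == (b, s, d)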
|
|
def ckpt_wrapper(module):
    """Wrap an nn.Module as a plain function.

    The returned closure is typically passed to
    torch.utils.checkpoint.checkpoint for activation checkpointing.
    """

    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
|
|
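# Usage sketch (an assumption based on the wrapper's name, not original code):
# ckpt_wrapper adapts an nn.Module to torch.utils.checkpoint.checkpoint, which
# trades compute for memory by re-running the block during the backward pass.
def _demo_checkpointing() -> None:
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    x = torch.randn(2, 64, requires_grad=True)
    # use_reentrant=False is the non-reentrant mode recommended in recent PyTorch.
    y = checkpoint(ckpt_wrapper(block), x, use_reentrant=False)
    y.sum().backward()
    assert x.grad is not None


if __name__ == "__main__":
    _demo_modulate_dit()
    _demo_modulation()
    _demo_checkpointing()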