|
import torch as th |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from nn import timestep_embedding |
|
from unet import UNetModel |
|
from transformer import LayerNorm, Transformer, convert_module_to_f16 |
|
|
|
|
|
class Text2ImUNet(UNetModel):
    """
    A UNetModel that conditions on text with an encoding transformer.

    Expects an extra kwarg `tokens` of text.

    :param text_ctx: number of text tokens to expect.
    :param xf_width: width of the transformer. A falsy value disables text
                     conditioning: no transformer is built and the parent is
                     created with ``encoder_channels=None``.
    :param xf_layers: depth of the transformer.
    :param xf_heads: heads in the transformer.
    :param xf_final_ln: use a LayerNorm after the output layer.
    :param tokenizer: the text tokenizer for sampling/vocab size.
    :param cache_text_emb: if True, get_text_emb() stores its detached outputs
                           and returns them on later calls until del_cache().
    :param xf_ar: if truthy, an un-embedding Linear layer (xf_width -> vocab
                  size) is created as ``self.unemb``. It is not used by any
                  method of this class.
    :param xf_padding: if True, a learned padding embedding replaces the text
                       embedding at positions where `mask` is False.
    :param share_unemb: tie ``self.unemb.weight`` to the token embedding
                        weight (only consulted when `xf_ar` is truthy).
    """

    def __init__(
        self,
        text_ctx,
        xf_width,
        xf_layers,
        xf_heads,
        xf_final_ln,
        tokenizer,
        *args,
        cache_text_emb=False,
        xf_ar=0.0,
        xf_padding=False,
        share_unemb=False,
        **kwargs,
    ):
        # These plain attributes are assigned before the parent constructor
        # runs, exactly as in the original code.
        # NOTE(review): whether UNetModel.__init__ reads any of them is not
        # visible from this file, so the ordering is deliberately untouched.
        self.text_ctx = text_ctx
        self.xf_width = xf_width
        self.xf_ar = xf_ar
        self.xf_padding = xf_padding
        self.tokenizer = tokenizer

        # A falsy width (0 / None) means "no text encoder".
        super().__init__(
            *args, **kwargs, encoder_channels=xf_width if xf_width else None
        )

        if self.xf_width:
            self.transformer = Transformer(
                text_ctx,
                xf_width,
                xf_layers,
                xf_heads,
            )
            self.final_ln = LayerNorm(xf_width) if xf_final_ln else None

            self.token_embedding = nn.Embedding(self.tokenizer.n_vocab, xf_width)
            # Zero-initialised rather than th.empty(): uninitialised memory
            # may contain NaN/inf, which would poison the model whenever no
            # checkpoint overwrites these parameters. Zeros consume no RNG.
            self.positional_embedding = nn.Parameter(
                th.zeros(text_ctx, xf_width, dtype=th.float32)
            )
            self.transformer_proj = nn.Linear(xf_width, self.model_channels * 4)

            if self.xf_padding:
                self.padding_embedding = nn.Parameter(
                    th.zeros(text_ctx, xf_width, dtype=th.float32)
                )
            if self.xf_ar:
                self.unemb = nn.Linear(xf_width, self.tokenizer.n_vocab)
                if share_unemb:
                    self.unemb.weight = self.token_embedding.weight

        self.cache_text_emb = cache_text_emb
        self.cache = None

    def convert_to_fp16(self):
        """
        Convert the model (including the text encoder, if any) to float16.
        """
        super().convert_to_fp16()
        if self.xf_width:
            self.transformer.apply(convert_module_to_f16)
            # Module.to() converts the module's parameters in place.
            self.transformer_proj.to(th.float16)
            self.token_embedding.to(th.float16)
            # NOTE(review): Tensor.to() is out-of-place, so the two Parameter
            # calls below do not change the stored dtype (it stays float32).
            # They are kept as-is because get_text_emb() casts the summed
            # input to self.dtype before the transformer; really converting
            # the parameters would change numerics.
            self.positional_embedding.to(th.float16)
            if self.xf_padding:
                self.padding_embedding.to(th.float16)
            if self.xf_ar:
                self.unemb.to(th.float16)

    def get_text_emb(self, tokens, mask):
        """
        Encode text tokens with the transformer.

        :param tokens: integer token tensor; cast with ``.long()`` before the
                       embedding lookup.
        :param mask: boolean tensor, required when `xf_padding` is enabled;
                     positions where it is False use the padding embedding.
        :return: a dict with ``xf_proj`` (projection of the last sequence
                 position) and ``xf_out`` (transformer output with its last
                 two axes swapped). On a cache hit the cache dict itself is
                 returned, which additionally holds ``tokens``.
        """
        assert tokens is not None

        if self.cache_text_emb and self.cache is not None:
            cached_tokens = self.cache["tokens"]
            # Compare shapes first: a bare `==` would either raise an opaque
            # broadcasting error or, worse, broadcast and accept a batch of a
            # different size against stale embeddings.
            # NOTE: the cache is keyed on `tokens` only; `mask` is not checked.
            assert (
                tokens.shape == cached_tokens.shape
                and (tokens == cached_tokens).all()
            ), f"Tokens {tokens.cpu().numpy().tolist()} do not match cache {self.cache['tokens'].cpu().numpy().tolist()}"
            return self.cache

        xf_in = self.token_embedding(tokens.long())
        xf_in = xf_in + self.positional_embedding[None]
        if self.xf_padding:
            assert mask is not None
            xf_in = th.where(mask[..., None], xf_in, self.padding_embedding[None])
        xf_out = self.transformer(xf_in.to(self.dtype))
        if self.final_ln is not None:
            xf_out = self.final_ln(xf_out)
        # Global text feature: the representation at the last position.
        xf_proj = self.transformer_proj(xf_out[:, -1])
        # Swap the last two axes; presumably NLC -> NCL for the UNet's
        # attention layers — TODO confirm against the unet module.
        xf_out = xf_out.permute(0, 2, 1)

        outputs = dict(xf_proj=xf_proj, xf_out=xf_out)

        if self.cache_text_emb:
            # Cached tensors are detached so they carry no autograd graph.
            self.cache = dict(
                tokens=tokens,
                xf_proj=xf_proj.detach(),
                xf_out=xf_out.detach(),
            )

        return outputs

    def del_cache(self):
        """
        Drop any cached text embeddings so the next call recomputes them.
        """
        self.cache = None

    def forward(self, x, timesteps, tokens=None, mask=None):
        """
        Apply the model to an input batch.

        :param x: input tensor.
        :param timesteps: batch of timesteps, embedded via timestep_embedding.
        :param tokens: text tokens; required when the text encoder is enabled.
        :param mask: token mask, forwarded to get_text_emb().
        :return: the output of ``self.out`` applied to the final activations.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.xf_width:
            text_outputs = self.get_text_emb(tokens, mask)
            xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
            # Add the pooled text feature to the timestep embedding, matching
            # emb's dtype and device.
            emb = emb + xf_proj.to(emb)
        else:
            xf_out = None
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, xf_out)
            hs.append(h)
        h = self.middle_block(h, emb, xf_out)
        for module in self.output_blocks:
            # Skip connection: concatenate the matching input activation
            # (most recent first) along dim 1.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, xf_out)
        h = h.type(x.dtype)
        h = self.out(h)
        return h
|
|
|
|
|
class SuperResText2ImUNet(Text2ImUNet):
    """
    A text2im model that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, *args, **kwargs):
        # forward() concatenates the upsampled low-res image to the input
        # along dim 1, so the network must accept twice the channels.
        args, kwargs = self._scale_in_channels(args, kwargs, factor=2)
        super().__init__(*args, **kwargs)

    @staticmethod
    def _scale_in_channels(args, kwargs, factor, extra=0):
        """
        Return ``(args, kwargs)`` with ``in_channels`` replaced by
        ``in_channels * factor + extra``.

        The previous positional fallback rewrote ``args[1]``, which for this
        class hierarchy is ``xf_width`` (the second positional parameter of
        Text2ImUNet.__init__), not ``in_channels``. The position is now looked
        up from the constructor signatures instead of being hard-coded.

        :raises TypeError: if ``in_channels`` was not supplied at all.
        """
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * factor + extra
            return args, kwargs

        import inspect

        def positional_names(init):
            params = inspect.signature(init).parameters.values()
            names = [
                p.name
                for p in params
                if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            ]
            return names[1:]  # drop ``self``

        # Positionals are consumed first by Text2ImUNet, then by UNetModel.
        names = positional_names(Text2ImUNet.__init__) + positional_names(
            UNetModel.__init__
        )
        try:
            index = names.index("in_channels")
        except ValueError:
            index = len(args)  # not a positional parameter: force the error
        if index >= len(args):
            raise TypeError(
                "in_channels must be supplied, either by keyword or positionally"
            )
        args = list(args)
        args[index] = args[index] * factor + extra
        return args, kwargs

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Run the model on `x` conditioned on a low-resolution image.

        :param x: 4-D input tensor; its last two dimensions set the target
                  size for upsampling.
        :param timesteps: batch of timesteps.
        :param low_res: low-resolution conditioning image (required).
        :param kwargs: forwarded to Text2ImUNet.forward (e.g. tokens, mask).
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            # Fail early with a clear message instead of an opaque error
            # from F.interpolate(None, ...).
            raise ValueError("low_res must be provided for super-resolution")
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
|
|
|
|
|
class InpaintText2ImUNet(Text2ImUNet):
    """
    A text2im model which can perform inpainting.

    Accepts the extra kwargs `inpaint_image` and `inpaint_mask`; both default
    to all-zero tensors when omitted.
    """

    def __init__(self, *args, **kwargs):
        # forward() concatenates the masked image (same channels as x) and a
        # single-channel mask to the input, hence in_channels * 2 + 1.
        args, kwargs = self._scale_in_channels(args, kwargs, factor=2, extra=1)
        super().__init__(*args, **kwargs)

    @staticmethod
    def _scale_in_channels(args, kwargs, factor, extra=0):
        """
        Return ``(args, kwargs)`` with ``in_channels`` replaced by
        ``in_channels * factor + extra``.

        The previous positional fallback rewrote ``args[1]``, which for this
        class hierarchy is ``xf_width`` (the second positional parameter of
        Text2ImUNet.__init__), not ``in_channels``. The position is now looked
        up from the constructor signatures instead of being hard-coded.

        :raises TypeError: if ``in_channels`` was not supplied at all.
        """
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * factor + extra
            return args, kwargs

        import inspect

        def positional_names(init):
            params = inspect.signature(init).parameters.values()
            names = [
                p.name
                for p in params
                if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            ]
            return names[1:]  # drop ``self``

        # Positionals are consumed first by Text2ImUNet, then by UNetModel.
        names = positional_names(Text2ImUNet.__init__) + positional_names(
            UNetModel.__init__
        )
        try:
            index = names.index("in_channels")
        except ValueError:
            index = len(args)  # not a positional parameter: force the error
        if index >= len(args):
            raise TypeError(
                "in_channels must be supplied, either by keyword or positionally"
            )
        args = list(args)
        args[index] = args[index] * factor + extra
        return args, kwargs

    def forward(self, x, timesteps, inpaint_image=None, inpaint_mask=None, **kwargs):
        """
        Run the model on `x` together with the inpainting conditioning.

        :param x: input tensor.
        :param timesteps: batch of timesteps.
        :param inpaint_image: image to inpaint; zeros shaped like `x` if None.
        :param inpaint_mask: mask multiplied into `inpaint_image`; zeros
                             shaped like ``x[:, :1]`` if None.
        :param kwargs: forwarded to Text2ImUNet.forward (e.g. tokens, mask).
        """
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        # The image is multiplied by the mask before concatenation, so only
        # the regions where the mask is non-zero reach the network.
        conditioned = th.cat([x, inpaint_image * inpaint_mask, inpaint_mask], dim=1)
        return super().forward(conditioned, timesteps, **kwargs)
|
|
|
|
|
class SuperResInpaintText2ImUnet(Text2ImUNet):
    """
    A text2im model which can perform both upsampling and inpainting.

    Expects the extra kwarg `low_res` and optionally `inpaint_image` and
    `inpaint_mask` (both default to all-zero tensors).
    """

    def __init__(self, *args, **kwargs):
        # forward() concatenates x, the masked image, a single-channel mask
        # and the upsampled low-res image, hence in_channels * 3 + 1.
        args, kwargs = self._scale_in_channels(args, kwargs, factor=3, extra=1)
        super().__init__(*args, **kwargs)

    @staticmethod
    def _scale_in_channels(args, kwargs, factor, extra=0):
        """
        Return ``(args, kwargs)`` with ``in_channels`` replaced by
        ``in_channels * factor + extra``.

        The previous positional fallback rewrote ``args[1]``, which for this
        class hierarchy is ``xf_width`` (the second positional parameter of
        Text2ImUNet.__init__), not ``in_channels``. The position is now looked
        up from the constructor signatures instead of being hard-coded.

        :raises TypeError: if ``in_channels`` was not supplied at all.
        """
        if "in_channels" in kwargs:
            kwargs = dict(kwargs)
            kwargs["in_channels"] = kwargs["in_channels"] * factor + extra
            return args, kwargs

        import inspect

        def positional_names(init):
            params = inspect.signature(init).parameters.values()
            names = [
                p.name
                for p in params
                if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
            ]
            return names[1:]  # drop ``self``

        # Positionals are consumed first by Text2ImUNet, then by UNetModel.
        names = positional_names(Text2ImUNet.__init__) + positional_names(
            UNetModel.__init__
        )
        try:
            index = names.index("in_channels")
        except ValueError:
            index = len(args)  # not a positional parameter: force the error
        if index >= len(args):
            raise TypeError(
                "in_channels must be supplied, either by keyword or positionally"
            )
        args = list(args)
        args[index] = args[index] * factor + extra
        return args, kwargs

    def forward(
        self,
        x,
        timesteps,
        inpaint_image=None,
        inpaint_mask=None,
        low_res=None,
        **kwargs,
    ):
        """
        Run the model on `x` with inpainting and low-resolution conditioning.

        :param x: 4-D input tensor; its last two dimensions set the target
                  size for upsampling `low_res`.
        :param timesteps: batch of timesteps.
        :param inpaint_image: image to inpaint; zeros shaped like `x` if None.
        :param inpaint_mask: mask multiplied into `inpaint_image`; zeros
                             shaped like ``x[:, :1]`` if None.
        :param low_res: low-resolution conditioning image (required).
        :param kwargs: forwarded to Text2ImUNet.forward (e.g. tokens, mask).
        :raises ValueError: if `low_res` is not provided.
        """
        if low_res is None:
            # Fail early with a clear message instead of an opaque error
            # from F.interpolate(None, ...).
            raise ValueError("low_res must be provided for super-resolution")
        if inpaint_image is None:
            inpaint_image = th.zeros_like(x)
        if inpaint_mask is None:
            inpaint_mask = th.zeros_like(x[:, :1])
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(
            low_res, (new_height, new_width), mode="bilinear", align_corners=False
        )
        conditioned = th.cat(
            [x, inpaint_image * inpaint_mask, inpaint_mask, upsampled], dim=1
        )
        return super().forward(conditioned, timesteps, **kwargs)