Cbrown9's picture
Upload 24 files
3569ac3 verified
from PIL import Image
from IPython.display import display
import torch as th
import torch.nn as nn
from clip.model_creation import create_clip_model
from download import load_checkpoint
from create_model import (
create_model_and_diffusion,
model_and_diffusion_defaults,
model_and_diffusion_defaults_upsampler,
)
from tokenizer.simple_tokenizer import SimpleTokenizer
has_cuda = th.cuda.is_available()
device = th.device('cpu' if not has_cuda else 'cuda')
options = model_and_diffusion_defaults()
options['use_fp16'] = has_cuda
options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(x.numel() for x in model.parameters()))
options_up = model_and_diffusion_defaults_upsampler()
options_up['use_fp16'] = has_cuda
options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))
def show_images(batch: th.Tensor):
""" Display a batch of images inline. """
scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()
reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
image = Image.fromarray(reshaped.numpy())
image.show()
prompt = "an oil painting of a corgi"
batch_size = 1
guidance_scale = 3.0
# Tune this parameter to control the sharpness of 256x256 images.
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
# Create the text tokens to feed to the model.
tokens = model.tokenizer.encode(prompt)
tokens, mask = model.tokenizer.padded_tokens_and_mask(
tokens, options['text_ctx']
)
# Create the classifier-free guidance tokens (empty)
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
[], options['text_ctx']
)
# Pack the tokens together into model kwargs.
model_kwargs = dict(
tokens=th.tensor(
[tokens] * batch_size + [uncond_tokens] * batch_size, device=device
),
mask=th.tensor(
[mask] * batch_size + [uncond_mask] * batch_size,
dtype=th.bool,
device=device,
),
)
# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // 2]
combined = th.cat([half, half], dim=0)
model_out = model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = th.cat([half_eps, half_eps], dim=0)
return th.cat([eps, rest], dim=1)
# Sample from the base model.
model.del_cache()
samples = diffusion.p_sample_loop(
model_fn,
(full_batch_size, 3, options["image_size"], options["image_size"]),
device=device,
clip_denoised=True,
progress=True,
model_kwargs=model_kwargs,
cond_fn=None,
)[:batch_size]
model.del_cache()
# Show the output
show_images(samples)
tokens = model_up.tokenizer.encode(prompt)
tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
tokens, options_up['text_ctx']
)
# Create the model conditioning dict.
model_kwargs = dict(
# Low-res image to upsample.
low_res=((samples+1)*127.5).round()/127.5 - 1,
# Text tokens
tokens=th.tensor(
[tokens] * batch_size, device=device
),
mask=th.tensor(
[mask] * batch_size,
dtype=th.bool,
device=device,
),
)
# Sample from the base model.
model_up.del_cache()
up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
up_samples = diffusion_up.ddim_sample_loop(
model_up,
up_shape,
noise=th.randn(up_shape, device=device) * upsample_temp,
device=device,
clip_denoised=True,
progress=True,
model_kwargs=model_kwargs,
cond_fn=None,
)[:batch_size]
model_up.del_cache()
# Show the output
show_images(up_samples)