ByteMorph-Demo / src /flux /sampling.py
bol
init
99738e0
raw
history blame
10.5 kB
import math
from typing import Callable
import torch
from einops import rearrange, repeat
from torch import Tensor
from .model import Flux
from .modules.conditioner import HFEmbedder
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> Tensor:
    """Draw reproducible standard-normal noise for ``num_samples`` latents.

    The result has shape ``(num_samples, 16, H, W)`` with
    ``H = 2 * ceil(height / 16)`` and ``W = 2 * ceil(width / 16)``.

    Args:
        num_samples: Size of the leading (batch) dimension.
        height: Height used to derive ``H``.
        width: Width used to derive ``W``.
        device: Device for both the random generator and the output.
        dtype: Dtype of the returned tensor.
        seed: Seed for a fresh ``torch.Generator``; the same seed, device and
            shape give the same tensor.
    """
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    # Doubling the ceil'd size keeps both spatial dims even (allow for packing).
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)
    latent_shape = (num_samples, 16, latent_h, latent_w)

    return torch.randn(latent_shape, generator=rng, device=device, dtype=dtype)
def _grid_position_ids(rows: int, cols: int, batch: int) -> Tensor:
    """Build ``(batch, rows * cols, 3)`` ids holding ``(0, row, col)`` per grid cell."""
    ids = torch.zeros(rows, cols, 3)
    ids[..., 1] = torch.arange(rows)[:, None]
    ids[..., 2] = torch.arange(cols)[None, :]
    return repeat(ids, "h w c -> b (h w) c", b=batch)


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], use_spatial_condition=False, use_share_weight_referencenet=False, share_position_embedding=False) -> dict[str, Tensor]:
    """Pack an image into 2x2 patch tokens and assemble the model inputs.

    Args:
        t5: Callable applied to the list of prompts; its output becomes ``txt``.
        clip: Callable applied to the list of prompts; its output becomes ``vec``.
        img: 4-D tensor ``(batch, channels, h, w)``; ``h`` and ``w`` must be
            even because the image is split into 2x2 patches.
        prompt: One prompt or a list of prompts. When ``img`` has batch size 1
            and a list is given, the batch size becomes ``len(prompt)`` and the
            image tokens are repeated accordingly.
        use_spatial_condition: Build ``img_ids`` for a sequence twice as long
            as the packed image (see ``share_position_embedding``).
        use_share_weight_referencenet: Same ``img_ids`` layout as
            ``use_spatial_condition``, and additionally return
            ``single_img_ids`` for a single ``(h // 2, w // 2)`` grid.
        share_position_embedding: Only used with one of the two flags above.
            If true, ``img_ids`` is the single-image grid concatenated with
            itself; otherwise it is one ``(h // 2, w)`` grid.

    Returns:
        Dict with ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``, all
        moved to ``img.device``; plus ``single_img_ids`` when
        ``use_share_weight_referencenet`` is set.
    """
    bs, _, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    rows, cols = h // 2, w // 2

    # Built whenever the reference-net flag is set. Previously it was skipped
    # when use_spatial_condition was also true, and the return statement then
    # failed with UnboundLocalError.
    single_img_ids = None
    if use_share_weight_referencenet:
        single_img_ids = _grid_position_ids(rows, cols, bs)

    if use_spatial_condition or use_share_weight_referencenet:
        # Both modes use the same id layout for the doubled sequence.
        if share_position_embedding:
            one_image_ids = _grid_position_ids(rows, cols, bs)
            img_ids = torch.cat([one_image_ids, one_image_ids], dim=1)
        else:
            img_ids = _grid_position_ids(rows, w, bs)
    else:
        img_ids = _grid_position_ids(rows, cols, bs)

    if isinstance(prompt, str):
        prompt = [prompt]

    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    inputs = {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }
    if use_share_weight_referencenet:
        inputs["single_img_ids"] = single_img_ids.to(img.device)
    return inputs
def time_shift(mu: float, sigma: float, t: Tensor) -> Tensor:
    """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` element-wise over ``t``."""
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)`` as a function of x."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    Args:
        num_steps: Number of sampling steps; one extra entry is added so the
            schedule ends at zero.
        image_seq_len: Sequence length used to pick the shift amount.
        base_shift: Value of ``mu`` on the default lower anchor of
            ``get_lin_function``.
        max_shift: Value of ``mu`` on the default upper anchor of
            ``get_lin_function``.
        shift: If false, the evenly spaced schedule is returned unchanged.
    """
    # extra step for zero
    uniform = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return uniform.tolist()

    # shifting the schedule to favor high timesteps for higher signal images;
    # mu is estimated by linear interpolation between two anchor points
    mu_for_length = get_lin_function(y1=base_shift, y2=max_shift)
    shifted = time_shift(mu_for_length(image_seq_len), 1.0, uniform)
    return shifted.tolist()
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs = 1,
    timestep_to_start_cfg=0,
    # ip-adapter parameters
    image_proj: Tensor=None,
    neg_image_proj: Tensor=None,
    ip_scale: Tensor | float = 1.0,
    neg_ip_scale: Tensor | float = 1.0,
    source_image: Tensor=None,
    use_share_weight_referencenet=False,
    single_img_ids=None,
    neg_single_img_ids=None,
    single_block_refnet=False,
    double_block_refnet=False,
):
    """Step ``img`` through consecutive pairs of ``timesteps`` using ``model``.

    For each pair ``(t_curr, t_prev)`` the model is called with the positive
    conditioning (``txt``, ``txt_ids``, ``vec``, ``image_proj``, ``ip_scale``,
    ``single_img_ids``). From step index ``timestep_to_start_cfg`` onwards it
    is called a second time with the ``neg_*`` inputs and the two outputs are
    blended as ``neg_pred + true_gs * (pred - neg_pred)``. The latent is then
    advanced by ``(t_prev - t_curr) * pred``.

    When ``source_image`` is given it is concatenated in front of ``img``
    along dim ``-2`` before every model call, and only the trailing half of
    that dimension is kept after the update. With
    ``use_share_weight_referencenet`` the prediction is first padded with an
    equally sized block of zeros in front (dim 1), so the source part is left
    unchanged by the update.

    Returns:
        The latent after the last step.
    """
    batch = img.shape[0]
    # this is ignored for schnell
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)
    has_source = source_image is not None  # spatial condition or refnet
    refnet_flags = dict(
        use_share_weight_referencenet=use_share_weight_referencenet,
        single_block_refnet=single_block_refnet,
        double_block_refnet=double_block_refnet,
    )

    step_pairs = zip(timesteps[:-1], timesteps[1:])
    for step, (t_curr, t_prev) in enumerate(step_pairs):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)
        if has_source:
            img = torch.cat([source_image, img], dim=-2)

        # Arguments common to the positive and the negative pass.
        shared = dict(
            img=img,
            img_ids=img_ids,
            timesteps=t_vec,
            guidance=guidance_vec,
            **refnet_flags,
        )
        pred = model(
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            image_proj=image_proj,
            ip_scale=ip_scale,
            single_img_ids=single_img_ids,
            **shared,
        )
        if step >= timestep_to_start_cfg:
            neg_pred = model(
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                image_proj=neg_image_proj,
                ip_scale=neg_ip_scale,
                single_img_ids=neg_single_img_ids,
                **shared,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        if use_share_weight_referencenet:
            # Zero update for the leading (source) part of the sequence.
            pred = torch.cat([torch.zeros_like(pred), pred], dim=1)

        img = img + (t_prev - t_curr) * pred

        if has_source:
            # Drop the source part again; keep the trailing half only.
            keep_from = img.shape[-2] // 2
            img = img[:, keep_from:, :]

    return img
def denoise_controlnet(
    model: Flux,
    controlnet: Callable,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    controlnet_cond,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    controlnet_gs=0.7,
    timestep_to_start_cfg=0,
    # ip-adapter parameters
    image_proj: Tensor | None = None,
    neg_image_proj: Tensor | None = None,
    ip_scale: Tensor | float = 1,
    neg_ip_scale: Tensor | float = 1,
):
    """Step ``img`` through consecutive pairs of ``timesteps`` with a controlnet.

    For each pair ``(t_curr, t_prev)``:

    1. ``controlnet`` is called with the current latent, ``controlnet_cond``
       and the positive conditioning. Its result is iterated and every item
       is multiplied by ``controlnet_gs``.
    2. ``model`` is called with the same inputs plus those scaled items as
       ``block_controlnet_hidden_states``.
    3. From step index ``timestep_to_start_cfg`` onwards, both calls are
       repeated with the ``neg_*`` inputs and the outputs are blended as
       ``neg_pred + true_gs * (pred - neg_pred)``.
    4. The latent is advanced by ``(t_prev - t_curr) * pred``.

    Args:
        model: Callable producing the prediction for the latent.
        controlnet: Callable invoked with keyword arguments ``img``,
            ``img_ids``, ``controlnet_cond``, ``txt``, ``txt_ids``, ``y``,
            ``timesteps`` and ``guidance``; must return an iterable whose
            items support multiplication by ``controlnet_gs``.

    Returns:
        The latent after the last step.
    """
    batch = img.shape[0]
    # this is ignored for schnell
    guidance_vec = torch.full((batch,), guidance, device=img.device, dtype=img.dtype)

    for step, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        t_vec = torch.full((batch,), t_curr, dtype=img.dtype, device=img.device)

        block_res_samples = controlnet(
            img=img,
            img_ids=img_ids,
            controlnet_cond=controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            # Distinct name: the original reused ``i``, shadowing the step counter.
            block_controlnet_hidden_states=[
                sample * controlnet_gs for sample in block_res_samples
            ],
            image_proj=image_proj,
            ip_scale=ip_scale,
        )

        if step >= timestep_to_start_cfg:
            neg_block_res_samples = controlnet(
                img=img,
                img_ids=img_ids,
                controlnet_cond=controlnet_cond,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            neg_pred = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                block_controlnet_hidden_states=[
                    sample * controlnet_gs for sample in neg_block_res_samples
                ],
                image_proj=neg_image_proj,
                ip_scale=neg_ip_scale,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        img = img + (t_prev - t_curr) * pred

    return img
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Rearrange packed 2x2 patch tokens back into a 4-D tensor.

    ``x`` is read as ``(b, h * w, c * 2 * 2)`` with ``h = ceil(height / 16)``
    and ``w = ceil(width / 16)``; the result is ``(b, c, 2 * h, 2 * w)``.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)