import numpy as np
import torch

from .utils import _broadcast_tensor, _extract_into_tensor
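
# The two helpers imported from .utils are used throughout this file. A sketch
# of their assumed semantics (the actual implementations live in .utils):
#   _extract_into_tensor(arr, timesteps, broadcast_shape)
#       indexes a 1-D array by a batch of integer timesteps and broadcasts
#       the result to `broadcast_shape`.
#   _broadcast_tensor(x, shape)
#       reshapes a [B] tensor to [B, 1, ..., 1] so it broadcasts against a
#       tensor of the given shape.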

class _WrappedModel_DiT:
    def __init__(self, model, diffusion, device=None, class_emb_null=None):
        self.model = model
        self.diffusion = diffusion
        self._predict_xstart_from_eps = diffusion._predict_xstart_from_eps
        # Sorted list of the original DDPM timesteps actually used by the
        # (possibly respaced) diffusion process.
        self.diffusion_t_map = sorted(diffusion.use_timesteps)
        self.diffusion_t = np.array(self.diffusion_t_map[:diffusion.num_timesteps])
        self.diffusion_sqrt_alpha_cumprod = np.array(
            [diffusion.sqrt_alphas_cumprod[i] for i in range(diffusion.num_timesteps)]
        )
        # Map each diffusion step to a flow-matching time:
        #   t_fm = sqrt(1 - alpha_bar) / (sqrt(alpha_bar) + sqrt(1 - alpha_bar)).
        self.fm_steps = [
            (1 - a**2) ** 0.5 / (a + (1 - a**2) ** 0.5)
            for a in self.diffusion_sqrt_alpha_cumprod
        ]
        # Prepend t_fm = 0 so that the integer step t indexes fm_steps directly.
        self.fm_steps = torch.tensor([0] + self.fm_steps, device=device)
        self.y_null = class_emb_null
    def __call__(self, x, t, y, kwargs):
        N = len(self.diffusion_t)
        B, C, H, W = x.shape
        # Look up the original DDPM timestep for the current integer step t
        # (t runs from N down to 1; index t - 1 into the sorted timestep map).
        diffusion_t = _extract_into_tensor(self.diffusion_t, t - 1, t.shape).long()
        t_fm = self.fm_steps[t]
        # Rescale the flow-matching state into the diffusion parameterization:
        # sqrt(alpha_bar) / (1 - t_fm) equals sqrt(alpha_bar) + sqrt(1 - alpha_bar);
        # the 1e-4 guards against division by zero at t_fm = 1.
        diffusion_x_tmp = _extract_into_tensor(self.diffusion.sqrt_alphas_cumprod, t - 1, x.shape) * x / (1 + 1e-4 - _broadcast_tensor(t_fm, x.shape))
        diffusion_x_tmp = diffusion_x_tmp.to(torch.float)
        # At the first step (t == N, pure noise) the state needs no rescaling.
        diffusion_x = torch.where(_broadcast_tensor(t, x.shape) == N, x, diffusion_x_tmp)
        # Classifier-free-guidance batch layout: conditional and null-class
        # halves are concatenated, though only the conditional half is kept.
        y_null_batch = torch.cat([self.y_null[0].unsqueeze(0)] * B, dim=0)
        y_new = torch.cat([y, y_null_batch], 0)
        model_output = self.model(torch.cat([diffusion_x, diffusion_x], dim=0), torch.cat([diffusion_t, diffusion_t], dim=0), y_new, **kwargs)
        model_output = model_output[:B]
        # DiT predicts (eps, sigma) stacked along the channel dim; keep eps.
        model_output, _ = torch.split(model_output, C, dim=1)
        x0_diffusion = self._predict_xstart_from_eps(x_t=diffusion_x, t=t - 1, eps=model_output)
        # Flow-matching velocity: v_t = (x_t - x_0) / t_fm.
        vt = (x - x0_diffusion) / _broadcast_tensor(t_fm, x.shape)
        vt = vt.to(diffusion_x.dtype)
        return vt
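
# A minimal, self-contained sketch (hypothetical, for illustration only) of the
# timestep mapping used above: a DDPM step with cumulative product alpha_bar is
# sent to the flow-matching time
#   t_fm = sqrt(1 - alpha_bar) / (sqrt(alpha_bar) + sqrt(1 - alpha_bar)),
# which is 0 at alpha_bar = 1 (clean data) and 1 at alpha_bar = 0 (pure noise).
def _demo_fm_time_mapping(num_timesteps=10):
    # Toy linear-beta schedule; the real schedule comes from `diffusion`.
    betas = np.linspace(1e-4, 0.02, num_timesteps)
    sqrt_alpha_bar = np.sqrt(np.cumprod(1.0 - betas))
    sqrt_one_minus = np.sqrt(1 - sqrt_alpha_bar**2)
    t_fm = sqrt_one_minus / (sqrt_alpha_bar + sqrt_one_minus)
    return t_fm  # monotonically increasing values in (0, 1)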

class _WrappedModel_Sora:
    def __init__(self, model, guidance_scale, y_null, timesteps, num_timesteps, mask_t):
        self.model = model
        self.guidance_scale = guidance_scale
        self.y_null = y_null
        # Reverse the sampling schedule and prepend t = 0 so the integer step
        # t indexes directly into the per-step timestep tensor.
        self.timesteps = [torch.tensor([0], device=model.device)] + timesteps[::-1]
        self.timesteps = torch.cat(self.timesteps, dim=0)
        # Normalized flow-matching times in [0, 1].
        self.fm_steps = [x / num_timesteps for x in self.timesteps]
        self.mask_t = mask_t
    def __call__(self, x, t, y, kwargs):
        # Classifier-free guidance: stack conditional and null-text batches.
        y = torch.cat([y, self.y_null], dim=0)
        t_in = self.timesteps[t]
        x_in = torch.cat([x, x], dim=0)
        # Temporal mask: frames whose threshold has been reached at this
        # timestep are marked active; duplicated for the CFG batch.
        mask_t_upper = self.mask_t >= t_in.unsqueeze(1)
        kwargs["x_mask"] = mask_t_upper.repeat(2, 1)
        t_in = torch.cat([t_in, t_in], dim=0)
        with torch.no_grad():
            # The model output stacks (value, sigma) along the channel dim;
            # keep only the first half.
            pred = self.model(x_in, t_in, y, **kwargs).chunk(2, dim=1)[0]
        pred_cond, pred_uncond = pred.chunk(2, dim=0)
        v_pred = pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)
        # Sign flip to match the solver's integration direction.
        return -v_pred
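
# A small, hypothetical sketch of how the temporal mask above behaves: given
# per-frame thresholds `mask_t` of shape [B, T] and a current timestep of
# shape [B], frames whose threshold is at or above the timestep come out True.
def _demo_x_mask():
    mask_t = torch.tensor([[100.0, 300.0, 500.0]])  # [B=1, T=3] frame thresholds
    t_in = torch.tensor([300.0])                    # current timestep, shape [B]
    mask = mask_t >= t_in.unsqueeze(1)              # [[False, True, True]]
    return mask.repeat(2, 1)                        # duplicated for the CFG batch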

class _WrappedModel_Wan:
    def __init__(self, model, timesteps, num_timesteps, context_null, guide_scale):
        self.model = model
        self.context_null = context_null
        self.guide_scale = guide_scale
        # Append t = 0 and flip so the integer step t indexes from 0 (data)
        # up to num_timesteps (noise).
        fm_steps = torch.cat([timesteps, torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        # Normalized flow-matching times in [0, 1].
        self.fm_steps = self.time_steps / num_timesteps
    def __call__(self, x, t, y, kwargs):
        self.time_steps = self.time_steps.to(t.device)
        t = self.time_steps[t]
        # Two forward passes for classifier-free guidance: text-conditioned
        # and null-context.
        noise_pred_cond = self.model(x, t=t, context=y, **kwargs)[0]
        noise_pred_uncond = self.model(x, t=t, context=self.context_null, **kwargs)[0]
        noise_pred = noise_pred_uncond + self.guide_scale * (noise_pred_cond - noise_pred_uncond)
        return noise_pred
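
# Classifier-free guidance in isolation (a self-contained toy example; the
# random tensors stand in for real model outputs):
def _demo_cfg(guide_scale=5.0):
    pred_cond = torch.randn(1, 4, 8, 8)    # text-conditioned prediction
    pred_uncond = torch.randn(1, 4, 8, 8)  # null-context prediction
    # Extrapolate from the unconditional output toward the conditional one.
    return pred_uncond + guide_scale * (pred_cond - pred_uncond)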

class _WrappedModel_FLUX:
    def __init__(self, model, timesteps, num_timesteps):
        self.model = model
        # Append t = 0 and flip, as in the other wrappers, so the integer
        # step t indexes from data (0) toward noise (num_timesteps).
        fm_steps = torch.cat([timesteps, torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        self.fm_steps = self.time_steps / num_timesteps

    def __call__(self, x, t, y, kwargs):
        t = self.time_steps[t]
        # FLUX expects a per-sample timestep normalized to [0, 1].
        t = t.expand(x.shape[0]).to(x.dtype) / 1000
        pred = self.model(hidden_states=x, timestep=t, **kwargs)[0]
        return pred
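
# How a solver might consume these wrappers (a hypothetical sketch; the real
# sampler, its step direction, and its state initialization live elsewhere in
# this repo). Each wrapper maps an integer step t to a velocity, and
# `fm_steps[t]` gives the matching flow time, so a plain Euler integrator
# looks roughly like:
def _demo_euler_loop(wrapped_model, x, y, kwargs, num_steps):
    # Integrate from noise (t = num_steps) toward data (t = 0).
    for i in range(num_steps, 0, -1):
        t = torch.full((x.shape[0],), i, device=x.device, dtype=torch.long)
        vt = wrapped_model(x, t, y, kwargs)
        dt = wrapped_model.fm_steps[i] - wrapped_model.fm_steps[i - 1]
        x = x - dt * vt
    return x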