import torch
import numpy as np

from .utils import _broadcast_tensor, _extract_into_tensor
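
# The wrappers below adapt heterogeneous backbones (DiT, Open-Sora, Wan, FLUX)
# to a common __call__(x, t, y, kwargs) interface: each maps the step index t
# to the backbone's native timestep, runs the backbone (with classifier-free
# guidance where applicable), and returns a prediction usable by a
# flow-matching sampler. Each wrapper also exposes `fm_steps`, its timestep
# schedule mapped into [0, 1].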


class _WrappedModel_DiT:
    """Expose a DiT model plus spaced Gaussian diffusion as a flow-matching velocity field."""

    def __init__(self, model, diffusion, device=None, class_emb_null=None):
        self.model = model
        self.diffusion = diffusion
        self._predict_xstart_from_eps = diffusion._predict_xstart_from_eps
        # Sorted original timesteps retained by the spaced diffusion.
        self.diffusion_t_map = sorted(diffusion.use_timesteps)
        self.diffusion_t = np.array([self.diffusion_t_map[i] for i in range(diffusion.num_timesteps)])
        self.diffusion_sqrt_alpha_cumprod = np.array([diffusion.sqrt_alphas_cumprod[i] for i in range(diffusion.num_timesteps)])
        # Flow-matching time per diffusion step: t_fm = sqrt(1 - alpha_bar) / (sqrt(alpha_bar) + sqrt(1 - alpha_bar)).
        self.fm_steps = [
            (1 - self.diffusion_sqrt_alpha_cumprod[i] ** 2) ** 0.5
            / (self.diffusion_sqrt_alpha_cumprod[i] + (1 - self.diffusion_sqrt_alpha_cumprod[i] ** 2) ** 0.5)
            for i in range(len(self.diffusion_t))
        ]
        self.fm_steps = torch.tensor([0] + self.fm_steps, device=device)
        self.y_null = class_emb_null

    def __call__(self, x, t, y, kwargs):
        N = len(self.diffusion_t)
        B, C, H, W = x.shape
        diffusion_t = _extract_into_tensor(self.diffusion_t, t - 1, t.shape).long()
        t_fm = self.fm_steps[t]
        # Rescale the flow-matching sample x into the corresponding diffusion sample x_t.
        diffusion_x_tmp = _extract_into_tensor(self.diffusion.sqrt_alphas_cumprod, t - 1, x.shape) * x / (1 + 1e-4 - _broadcast_tensor(t_fm, x.shape))
        diffusion_x_tmp = diffusion_x_tmp.to(torch.float)
        diffusion_x = torch.where(_broadcast_tensor(t, x.shape) == N, x, diffusion_x_tmp)
        # Duplicate the batch with the null class embedding appended; only the
        # conditional half of the output is used below.
        y_null_batch = torch.cat([self.y_null[0].unsqueeze(0)] * B, dim=0)
        y_new = torch.cat([y, y_null_batch], 0)
        model_output = self.model(torch.cat([diffusion_x, diffusion_x], dim=0), torch.cat([diffusion_t, diffusion_t], dim=0), y_new, **kwargs)
        model_output = model_output[:B]
        # DiT stacks the noise prediction and variance channels; keep the noise half.
        model_output, _ = torch.split(model_output, C, dim=1)
        # Recover x0 from the predicted noise, then convert to a flow-matching velocity.
        x0_diffusion = self._predict_xstart_from_eps(x_t=diffusion_x, t=t - 1, eps=model_output)
        vt = (x - x0_diffusion) / _broadcast_tensor(t_fm, x.shape)
        return vt.to(diffusion_x.dtype)


class _WrappedModel_Sora:
    """Wrap a Sora-style video model (classifier-free guidance plus per-frame masking)."""

    def __init__(self, model, guidance_scale, y_null, timesteps, num_timesteps, mask_t):
        self.model = model
        self.guidance_scale = guidance_scale
        self.y_null = y_null
        # Prepend a terminal t = 0 step and reverse the schedule, so index 0 maps to timestep 0.
        self.timesteps = [torch.tensor([0], device=model.device)] + timesteps[::-1]
        self.timesteps = torch.cat(self.timesteps, dim=0)
        self.fm_steps = [step / num_timesteps for step in self.timesteps]
        self.mask_t = mask_t

    def __call__(self, x, t, y, kwargs):
        # Classifier-free guidance: conditional and null-text halves in one batch.
        y = torch.cat([y, self.y_null], dim=0)
        t_in = self.timesteps[t]
        x_in = torch.cat([x, x], dim=0)
        # Per-frame mask: True where a frame's mask timestep is at or above the current timestep.
        mask_t_upper = self.mask_t >= t_in.unsqueeze(1)
        kwargs["x_mask"] = mask_t_upper.repeat(2, 1)
        t_in = torch.cat([t_in, t_in], dim=0)
        with torch.no_grad():
            pred = self.model(x_in, t_in, y, **kwargs).chunk(2, dim=1)[0]
        pred_cond, pred_uncond = pred.chunk(2, dim=0)
        v_pred = pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)
        # Sign flip to match the sampler's velocity convention.
        return -v_pred


class _WrappedModel_Wan:
    """Wrap a Wan video model with classifier-free guidance behind the common interface."""

    def __init__(self, model, timesteps, num_timesteps, context_null, guide_scale):
        self.model = model
        self.context_null = context_null
        self.guide_scale = guide_scale
        # Append a terminal t = 0 step and flip, so index 0 maps to timestep 0.
        fm_steps = torch.cat([timesteps, torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        self.fm_steps = self.time_steps / num_timesteps

    def __call__(self, x, t, y, kwargs):
        self.time_steps = self.time_steps.to(t.device)
        t = self.time_steps[t]
        # Classifier-free guidance: separate conditional and null-context passes.
        noise_pred_cond = self.model(x, t=t, context=y, **kwargs)[0]
        noise_pred_uncond = self.model(x, t=t, context=self.context_null, **kwargs)[0]
        return noise_pred_uncond + self.guide_scale * (noise_pred_cond - noise_pred_uncond)


class _WrappedModel_FLUX:
    """Wrap a FLUX transformer (diffusers-style hidden_states/timestep interface)."""

    def __init__(self, model, timesteps, num_timesteps):
        self.model = model
        # Append a terminal t = 0 step and flip, so index 0 maps to timestep 0.
        fm_steps = torch.cat([timesteps, torch.zeros_like(timesteps[0]).view(1)])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        self.fm_steps = self.time_steps / num_timesteps

    def __call__(self, x, t, y, kwargs):
        t = self.time_steps[t]
        # FLUX expects timesteps scaled to [0, 1].
        t = t.expand(x.shape[0]).to(x.dtype) / 1000
        return self.model(hidden_states=x, timestep=t, **kwargs)[0]
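

# Minimal smoke-test sketch (hypothetical, not part of the pipeline): it wires
# _WrappedModel_FLUX to a stand-in transformer that only mimics the
# (hidden_states, timestep, **kwargs) -> tuple interface of the real FLUX model.
if __name__ == "__main__":
    class _DummyTransformer:
        def __call__(self, hidden_states, timestep, **kwargs):
            # Return a tuple to mimic a diffusers-style model output.
            return (torch.zeros_like(hidden_states),)

    num_timesteps = 1000
    timesteps = torch.linspace(1000, 250, steps=4)   # descending scheduler timesteps
    wrapped = _WrappedModel_FLUX(_DummyTransformer(), timesteps, num_timesteps)
    x = torch.randn(2, 16, 64)                       # placeholder latent tokens
    t = torch.tensor(3)                              # index into the flipped schedule (-> timestep 750)
    print(wrapped(x, t, None, {}).shape)             # torch.Size([2, 16, 64])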