|
import torch |
|
|
|
|
|
def sample_x0(x1):
    """Draw standard-normal noise shaped like the data point x1.

    Args:
        x1: data point; a tensor [batch, *dim], or a list/tuple of tensors
            (one noise tensor is drawn per element in that case)

    Returns:
        Noise with the same structure, shape, dtype, and device as x1.
    """
    if isinstance(x1, (list, tuple)):
        return [torch.randn_like(item) for item in x1]
    return torch.randn_like(x1)
|
|
|
def sample_timestep(x1):
    """Sample one timestep per batch element from a logit-normal distribution.

    A standard-normal draw is squashed through the sigmoid, giving values
    strictly inside (0, 1). The result is cast to the dtype/device of x1
    (of its first element when x1 is a list/tuple).
    """
    gauss = torch.normal(mean=0.0, std=1.0, size=(len(x1),))
    timesteps = torch.sigmoid(gauss)
    return timesteps.to(x1[0])
|
|
|
|
|
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Flow-matching loss for training the score model.

    Interpolates xt = t * x1 + (1 - t) * x0 between data x1 and noise x0
    and regresses the model output onto the target velocity ut = x1 - x0.

    Args:
        model: backbone model; could be score, noise, or velocity
        x1: data point; a tensor [batch, *dim] or a list/tuple of tensors
            (variable per-sample shapes)
        model_kwargs: additional keyword arguments forwarded to the model
        snr_type: kept for interface compatibility; currently unused here

    Returns:
        dict with key "loss": a [batch]-shaped tensor of per-sample mean
        squared errors between the model output and ut.
    """
    if model_kwargs is None:  # fix: identity check, not `== None`
        model_kwargs = {}

    B = len(x1)

    x0 = sample_x0(x1)
    t = sample_timestep(x1)

    if isinstance(x1, (list, tuple)):
        # Per-sample interpolation; each element may have a different shape,
        # so the batch cannot be vectorized into a single tensor op.
        xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)]
        ut = [x1[i] - x0[i] for i in range(B)]
    else:
        # Reshape t to [batch, 1, ..., 1] so it broadcasts over the
        # non-batch dimensions of x1.
        dims = [1] * (len(x1.size()) - 1)
        t_ = t.view(t.size(0), *dims)
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0

    model_output = model(xt, t, **model_kwargs)

    terms = {}

    if isinstance(x1, (list, tuple)):
        assert len(model_output) == len(ut) == len(x1)
        # Bug fix: the original wrapped this stack in a redundant outer
        # `for i in range(B)` loop that shadowed the comprehension's `i`
        # and rebuilt the identical tensor B times; one stack suffices.
        terms["loss"] = torch.stack(
            [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)

    return terms
|
|
|
|
|
def mean_flat(x):
    """Average x over every dimension except the leading batch dimension."""
    non_batch_dims = list(range(1, x.dim()))
    return x.mean(dim=non_batch_dims)
|
|