File size: 1,984 Bytes
b5ce381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from scipy import integrate
from einops import repeat, rearrange
from ...util import append_dims


def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    """Return the j-th linear-multistep coefficient for the step from t[i] to t[i+1].

    The coefficient is the integral over [t[i], t[i+1]] of the j-th Lagrange
    basis polynomial built on the `order` most recent nodes t[i], t[i-1], ...
    (Adams-Bashforth-style weights for an LMS sampler).
    """
    # Not enough history yet to build a polynomial of this order.
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    def basis(tau):
        # Lagrange basis l_j(tau): product over all nodes except the j-th.
        value = 1.0
        for m in range(order):
            if m == j:
                continue
            value *= (tau - t[i - m]) / (t[i - j] - t[i - m])
        return value

    integral, _abserr = integrate.quad(basis, t[i], t[i + 1], epsrel=epsrel)
    return integral


def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Split the transition sigma_from -> sigma_to into a deterministic part and noise.

    Returns ``(sigma_down, sigma_up)`` where the sampler steps deterministically
    to ``sigma_down`` and then injects fresh noise of scale ``sigma_up``, so that
    ``sigma_down**2 + sigma_up**2 == sigma_to**2``.
    """
    # eta == 0 means a fully deterministic (DDIM-like) step: no noise added.
    if not eta:
        return sigma_to, 0.0
    variance_ratio = (sigma_from**2 - sigma_to**2) / sigma_from**2
    # Cap the injected noise at sigma_to so sigma_down stays real.
    sigma_up = torch.minimum(sigma_to, eta * (sigma_to**2 * variance_ratio) ** 0.5)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


def to_d(x, sigma, denoised):
    """Convert a denoiser output into the derivative d(x)/d(sigma) for the ODE."""
    residual = x - denoised
    # Broadcast sigma over the trailing dims of x before dividing.
    return residual / append_dims(sigma, x.ndim)


def to_neg_log_sigma(sigma):
    """Map a sigma tensor into negative-log-sigma space."""
    return -torch.log(sigma)


def to_sigma(neg_log_sigma):
    """Map a negative-log-sigma tensor back to sigma (inverse of to_neg_log_sigma)."""
    return torch.exp(-neg_log_sigma)


def chunk_inputs(
    input,
    cond,
    additional_model_inputs,
    sigma,
    sigma_next,
    start_idx,
    end_idx,
    num_frames=14,
):
    """Slice the frame window [start_idx, end_idx) out of the sampler state.

    Args:
        input: latent tensor; frames are indexed along dim 2
            (assumes layout (batch, channels, frames, ...) — TODO confirm).
        cond: conditioning dict; tensor values are sliced along dim 0,
            non-tensor values are passed through unchanged.
        additional_model_inputs: extra model kwargs, sliced the same way.
        sigma, sigma_next: per-frame sigma tensors, sliced along dim 0.
        start_idx, end_idx: half-open frame range to extract.
        num_frames: unused here; kept for interface compatibility with callers.

    Returns:
        Tuple of (input_chunk, sigma_chunk, sigma_next_chunk, cond_chunk,
        additional_model_inputs_chunk).
    """
    # Cast to float32 and clone so in-place edits to the chunk cannot
    # write back into the full latent's storage.
    input_chunk = input[:, :, start_idx:end_idx].to(torch.float32).clone()

    sigma_chunk = sigma[start_idx:end_idx].to(torch.float32)
    sigma_next_chunk = sigma_next[start_idx:end_idx].to(torch.float32)

    cond_chunk = {}
    for k, v in cond.items():
        if isinstance(v, torch.Tensor):
            cond_chunk[k] = v[start_idx:end_idx]
        else:
            cond_chunk[k] = v

    additional_model_inputs_chunk = {}
    for k, v in additional_model_inputs.items():
        if isinstance(v, torch.Tensor):
            # Bug fix: sliced tensors were previously written into cond_chunk,
            # polluting the conditioning dict and dropping them from this one.
            additional_model_inputs_chunk[k] = v[start_idx:end_idx]
        else:
            additional_model_inputs_chunk[k] = v

    return (
        input_chunk,
        sigma_chunk,
        sigma_next_chunk,
        cond_chunk,
        additional_model_inputs_chunk,
    )