# File size: 1,171 Bytes
# 2568013
import torch
from einops import reduce
from jaxtyping import Float, Int64
from torch import Tensor
def sample_discrete_distribution(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Draw `num_samples` bucket indices per batch element via inverse-CDF sampling.

    The (possibly unnormalized) `pdf` is normalized along its last axis, a CDF is
    built, and uniform draws are located in it with a binary search. Returns the
    sampled indices together with the normalized density of each sampled bucket.
    """
    *batch_dims, num_buckets = pdf.shape

    # Normalize so each row sums to ~1; eps guards against all-zero rows.
    totals = pdf.sum(dim=-1, keepdim=True)
    probabilities = pdf / (eps + totals)

    # Inverse-transform sampling: place each uniform draw within the CDF.
    cdf = probabilities.cumsum(dim=-1)
    uniforms = torch.rand((*batch_dims, num_samples), device=pdf.device)
    index = torch.searchsorted(cdf, uniforms, right=True)

    # The final CDF entry can fall slightly below 1 (eps in the denominator),
    # so a draw above it must be clamped back to the last valid bucket.
    index = index.clip(max=num_buckets - 1)

    return index, probabilities.gather(dim=-1, index=index)
def gather_discrete_topk(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Select the `num_samples` highest-density buckets per batch element.

    Deterministic counterpart to `sample_discrete_distribution`: instead of
    sampling, it returns the top-k bucket indices and their normalized densities.
    """
    # Normalize rows to sum to ~1; eps avoids division by zero on all-zero rows.
    probabilities = pdf / (eps + pdf.sum(dim=-1, keepdim=True))

    # Ranking the raw pdf gives the same order as ranking the normalized one,
    # since normalization is a positive per-row scaling.
    index = pdf.topk(k=num_samples, dim=-1).indices

    return index, probabilities.gather(dim=-1, index=index)
# |