MH0386's picture
Upload folder using huggingface_hub
3e165b2 verified
from abc import ABC, abstractmethod
import numpy as np
import torch as th
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: The name of the sampler.
:param diffusion: The diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the goal.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the goal.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: The number of timesteps.
:param device: The torch device to save to.
:return: A tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, num_timesteps):
self._weights = np.ones([num_timesteps])
def weights(self):
return self._weights