Spaces:
Runtime error
Runtime error
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
def set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
    """Switch gradient tracking on or off for every parameter of *module*.

    Args:
        module: Module whose parameters (as yielded by ``module.parameters()``)
            are updated in place.
        requires_grad: ``True`` to enable gradient tracking, ``False`` to
            freeze the parameters.
    """
    for param in module.parameters():
        # In-place flag flip; the parameter tensors themselves are untouched.
        param.requires_grad_(requires_grad)
def compute_distance_transform(mask: torch.Tensor):
    """Compute a normalised L2 distance transform for each mask in a batch.

    Each mask is cast to ``uint8`` and inverted (``1 - m``) before being
    passed to ``cv2.distanceTransform``, so pixels where the mask equals 1
    become the zero pixels that distances are measured to. Distances are
    divided by half of the last dimension of ``mask`` and clipped to
    ``[0, 1]``.

    Args:
        mask: Mask tensor. Assumed to be shaped ``(batch, 1, H, W)`` with
            values in {0, 1} -- TODO confirm against callers.

    Returns:
        Tensor with a singleton dimension inserted at index 1, placed on the
        same device as ``mask``.
    """
    # Normalisation scale: half of the size of the last dimension.
    half_size = mask.shape[-1] / 2

    # OpenCV runs on CPU numpy arrays of 8-bit integers.
    masks_np = mask.squeeze(1).detach().cpu().numpy().astype(np.uint8)

    per_sample = []
    for single_mask in masks_np:
        distances = cv2.distanceTransform(
            (1 - single_mask),
            distanceType=cv2.DIST_L2,
            maskSize=cv2.DIST_MASK_3,
        )
        per_sample.append(torch.from_numpy(distances / half_size))

    stacked = torch.stack(per_sample).unsqueeze(1)
    return stacked.clip(0, 1).to(mask.device)
def default(x, d):
    """Return *x*, falling back to *d* only when *x* is ``None``.

    Falsy values other than ``None`` (``0``, ``""``, ``[]``, ...) are kept.
    """
    if x is None:
        return d
    return x
def get_custom_betas(
    beta_start: float,
    beta_end: float,
    warmup_frac: float = 0.3,
    num_train_timesteps: int = 1000,
) -> np.ndarray:
    """Build a linear beta schedule whose first part is a compressed ramp.

    The base schedule is ``num_train_timesteps`` values evenly spaced from
    ``beta_start`` to ``beta_end``. The first
    ``int(num_train_timesteps * warmup_frac)`` entries are then overwritten
    with a second linear ramp that covers the full
    ``beta_start``..``beta_end`` range within that shorter span. Entries
    after the warm-up keep their base-schedule values.

    Args:
        beta_start: First value of both the base schedule and the warm-up ramp.
        beta_end: Last value of both the base schedule and the warm-up ramp.
        warmup_frac: Fraction of the timesteps occupied by the warm-up ramp.
        num_train_timesteps: Total length of the returned schedule.

    Returns:
        A ``float32`` array of shape ``(num_train_timesteps,)``.
    """
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)

    # Honour the caller-supplied warmup_frac (it used to be overwritten with
    # a hard-coded 0.3 here, which made the parameter a no-op).
    warmup_time = int(num_train_timesteps * warmup_frac)
    warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)

    # Guard against warmup_frac > 1: never write past the end of the schedule.
    warmup_time = min(warmup_time, num_train_timesteps)
    betas[:warmup_time] = warmup_steps[:warmup_time]
    return betas