import numpy as np
import yaml
import argparse
import math
import torch
import torch.nn.functional as F  # used by whole_sample/slide_sample (F.interpolate, F.pad)
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import *
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL
# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import *
from fvcore.common.config import CfgNode
from pathlib import Path


def load_conf(config_file, conf={}):
    with open(config_file) as f:
        exp_conf = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in exp_conf.items():
        conf[k] = v
    return conf


def prepare_args(ckpt_path, sampling_timesteps=1):
    return argparse.Namespace(
        cfg=load_conf(Path(__file__).parent / "default.yaml"),
        pre_weight=ckpt_path,
        sampling_timesteps=sampling_timesteps,
    )
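
# The config loaded from default.yaml is expected to expose at least the keys
# accessed below (a sketch inferred from this wrapper's attribute accesses,
# not the complete schema shipped with the checkpoint):
#   model.model_name ('cond_unet'), model.model_type ('const_sde'),
#   model.first_stage.{ddconfig, lossconfig, embed_dim, ckpt_path},
#   model.unet.{dim, channels, dim_mults, out_mul, cond_in_dim, cond_dim,
#               cond_dim_mults, window_sizes1, window_sizes2, fourier_scale},
#   model.{train_sample, image_size, timesteps, loss_type, objective,
#          scale_factor, scale_by_std, scale_by_softsign, input_keys,
#          ckpt_path, ignore_keys, only_model, start_dist, perceptual_weight},
#   sampler.{sample_type, stride, use_ema}.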


class DiffusionEdge:
    def __init__(self, args) -> None:
        self.cfg = CfgNode(args.cfg)
        # Fix the RNG seeds so repeated runs sample the same edge maps.
        torch.manual_seed(42)
        np.random.seed(42)
        model_cfg = self.cfg.model
        first_stage_cfg = model_cfg.first_stage
        # First stage: KL autoencoder that maps edge maps to and from latent space.
        first_stage_model = AutoencoderKL(
            ddconfig=first_stage_cfg.ddconfig,
            lossconfig=first_stage_cfg.lossconfig,
            embed_dim=first_stage_cfg.embed_dim,
            ckpt_path=first_stage_cfg.ckpt_path,
        )
        if model_cfg.model_name == 'cond_unet':
            # The conditional U-Net variant replaces the unconditional Unet imported above.
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
            unet_cfg = model_cfg.unet
            unet = Unet(
                dim=unet_cfg.dim,
                channels=unet_cfg.channels,
                dim_mults=unet_cfg.dim_mults,
                learned_variance=unet_cfg.get('learned_variance', False),
                out_mul=unet_cfg.out_mul,
                cond_in_dim=unet_cfg.cond_in_dim,
                cond_dim=unet_cfg.cond_dim,
                cond_dim_mults=unet_cfg.cond_dim_mults,
                window_sizes1=unet_cfg.window_sizes1,
                window_sizes2=unet_cfg.window_sizes2,
                fourier_scale=unet_cfg.fourier_scale,
                cfg=unet_cfg,
            )
        else:
            raise NotImplementedError
        if model_cfg.model_type == 'const_sde':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
        else:
            raise NotImplementedError(f'{model_cfg.model_type} is not supported!')
        self.model = LatentDiffusion(
            model=unet,
            auto_encoder=first_stage_model,
            train_sample=model_cfg.train_sample,
            image_size=model_cfg.image_size,
            timesteps=model_cfg.timesteps,
            sampling_timesteps=args.sampling_timesteps,
            loss_type=model_cfg.loss_type,
            objective=model_cfg.objective,
            scale_factor=model_cfg.scale_factor,
            scale_by_std=model_cfg.scale_by_std,
            scale_by_softsign=model_cfg.scale_by_softsign,
            default_scale=model_cfg.get('default_scale', False),
            input_keys=model_cfg.input_keys,
            ckpt_path=model_cfg.ckpt_path,
            ignore_keys=model_cfg.ignore_keys,
            only_model=model_cfg.only_model,
            start_dist=model_cfg.start_dist,
            perceptual_weight=model_cfg.perceptual_weight,
            use_l1=model_cfg.get('use_l1', True),
            cfg=model_cfg,
        )
        # Load the pretrained sampler weights, preferring the EMA copy when configured.
        self.cfg.sampler.ckpt_path = args.pre_weight
        data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
        if self.cfg.sampler.use_ema:
            sd = data['ema']
            new_sd = {}
            for k in sd.keys():
                if k.startswith("ema_model."):
                    new_k = k[10:]  # strip the "ema_model." prefix (10 characters)
                    new_sd[new_k] = sd[k]
            sd = new_sd
            self.model.load_state_dict(sd)
        else:
            self.model.load_state_dict(data['model'])
        if 'scale_factor' in data['model']:
            self.model.scale_factor = data['model']['scale_factor']
        self.model.eval()
        self.device = "cpu"

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, image, batch_size=8):
        # Expects a float image tensor in [0, 1]; rescale to [-1, 1] for the diffusion model.
        image = normalize_to_neg_one_to_one(image).to(self.device)
        mask = None
        if self.cfg.sampler.sample_type == 'whole':
            return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
        elif self.cfg.sampler.sample_type == 'slide':
            return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
                                     stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)

    def whole_sample(self, inputs, raw_size, mask=None):
        # Resize to the fixed 416x416 working resolution, sample once, then resize
        # the predicted logits back to the original size.
        inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
        seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
        seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
        return seg_logits

    def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
        """Inference by sliding window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.

        Args:
            inputs (Tensor): input images of shape NxCxHxW.
            crop_size (tuple[int]): (h_crop, w_crop) size of each window.
            stride (tuple[int]): (h_stride, w_stride) step between adjacent windows.
            mask (Tensor, optional): optional mask forwarded to the sampler.
            bs (int): maximum number of windows sampled per forward pass.

        Returns:
            Tensor: seg_logits, the edge logits for each input image.
        """
        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = 1
        # Number of windows needed to cover the image in each direction.
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
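        # Worked example (hypothetical numbers, not taken from the config): a 321x481
        # input with crop_size=(320, 320) and stride=(240, 240) gives
        #   h_grids = max(321 - 320 + 239, 0) // 240 + 1 = 2
        #   w_grids = max(481 - 320 + 239, 0) // 240 + 1 = 2
        # i.e. 4 overlapping windows whose predictions are averaged below.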
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out1 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out2 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        crop_imgs = []
        x1s = []
        x2s = []
        y1s = []
        y2s = []
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so every crop keeps the full crop size.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                crop_imgs.append(crop_img)
                x1s.append(x1)
                x2s.append(x2)
                y1s.append(y1)
                y2s.append(y2)
        crop_imgs = torch.cat(crop_imgs, dim=0)
        crop_seg_logits_list = []
        num_windows = crop_imgs.shape[0]
        # Sample the windows in chunks of at most `bs` to bound memory use.
        length = math.ceil(num_windows / bs)
        for i in range(length):
            if i == length - 1:
                crop_imgs_temp = crop_imgs[bs * i:num_windows, ...]
            else:
                crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...]
            crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask)
            crop_seg_logits_list.append(crop_seg_logits)
        crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)
        # Paste each window's logits back at its location and average overlapping regions.
        for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat
        return seg_logits
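

# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint path is a placeholder; point it at a real DiffusionEdge checkpoint
# and pass a float RGB tensor scaled to [0, 1] with shape (N, 3, H, W).
if __name__ == "__main__":
    args = prepare_args(ckpt_path="path/to/diffusion_edge_checkpoint.pt", sampling_timesteps=1)
    detector = DiffusionEdge(args).to("cuda" if torch.cuda.is_available() else "cpu")
    dummy = torch.rand(1, 3, 480, 320)  # stand-in for a real image in [0, 1]
    with torch.no_grad():
        edge_logits = detector(dummy, batch_size=8)
    print(edge_logits.shape)  # expected: (1, 1, 480, 320), single-channel edge logits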