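# Inference wrapper around the DiffusionEdge latent-diffusion edge detector.
# Exposes `prepare_args` for building a config namespace and `DiffusionEdge`
# for whole-image or sliding-window edge-map sampling.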
import numpy as np
import yaml
import argparse
import math
import torch
import torch.nn.functional as F  # used below for interpolate/pad; may also be re-exported by the wildcard imports
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import *
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL
# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import *
from fvcore.common.config import CfgNode
from pathlib import Path
def load_conf(config_file, conf=None):
    # Avoid a shared mutable default argument; start from a fresh dict unless one is passed in.
    conf = {} if conf is None else conf
    with open(config_file) as f:
        exp_conf = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in exp_conf.items():
        conf[k] = v
    return conf
def prepare_args(ckpt_path, sampling_timesteps=1):
    return argparse.Namespace(
        cfg=load_conf(Path(__file__).parent / "default.yaml"),
        pre_weight=ckpt_path,
        sampling_timesteps=sampling_timesteps,
    )
class DiffusionEdge:
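    """Thin inference wrapper around the DiffusionEdge latent diffusion model.

    Builds the first-stage autoencoder and conditional UNet from the config in
    ``args.cfg``, loads the pretrained weights from ``args.pre_weight``, and
    exposes ``__call__`` for edge-map sampling.
    """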
    def __init__(self, args) -> None:
        self.cfg = CfgNode(args.cfg)
        # Fix the random seeds so sampling is reproducible across runs.
        torch.manual_seed(42)
        np.random.seed(42)
        model_cfg = self.cfg.model
        first_stage_cfg = model_cfg.first_stage
        # First-stage autoencoder used by the latent diffusion model.
        first_stage_model = AutoencoderKL(
            ddconfig=first_stage_cfg.ddconfig,
            lossconfig=first_stage_cfg.lossconfig,
            embed_dim=first_stage_cfg.embed_dim,
            ckpt_path=first_stage_cfg.ckpt_path,
        )
        if model_cfg.model_name == 'cond_unet':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
            unet_cfg = model_cfg.unet
            unet = Unet(dim=unet_cfg.dim,
                        channels=unet_cfg.channels,
                        dim_mults=unet_cfg.dim_mults,
                        learned_variance=unet_cfg.get('learned_variance', False),
                        out_mul=unet_cfg.out_mul,
                        cond_in_dim=unet_cfg.cond_in_dim,
                        cond_dim=unet_cfg.cond_dim,
                        cond_dim_mults=unet_cfg.cond_dim_mults,
                        window_sizes1=unet_cfg.window_sizes1,
                        window_sizes2=unet_cfg.window_sizes2,
                        fourier_scale=unet_cfg.fourier_scale,
                        cfg=unet_cfg,
                        )
        else:
            raise NotImplementedError(f'{model_cfg.model_name} is not supported!')
        if model_cfg.model_type == 'const_sde':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
        else:
            raise NotImplementedError(f'{model_cfg.model_type} is not supported!')
        self.model = LatentDiffusion(
            model=unet,
            auto_encoder=first_stage_model,
            train_sample=model_cfg.train_sample,
            image_size=model_cfg.image_size,
            timesteps=model_cfg.timesteps,
            sampling_timesteps=args.sampling_timesteps,
            loss_type=model_cfg.loss_type,
            objective=model_cfg.objective,
            scale_factor=model_cfg.scale_factor,
            scale_by_std=model_cfg.scale_by_std,
            scale_by_softsign=model_cfg.scale_by_softsign,
            default_scale=model_cfg.get('default_scale', False),
            input_keys=model_cfg.input_keys,
            ckpt_path=model_cfg.ckpt_path,
            ignore_keys=model_cfg.ignore_keys,
            only_model=model_cfg.only_model,
            start_dist=model_cfg.start_dist,
            perceptual_weight=model_cfg.perceptual_weight,
            use_l1=model_cfg.get('use_l1', True),
            cfg=model_cfg,
        )
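        # Load pretrained sampler weights. When the config asks for EMA weights,
        # the "ema_model." prefix is stripped so the keys match the bare model.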
        self.cfg.sampler.ckpt_path = args.pre_weight
        data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
        if self.cfg.sampler.use_ema:
            sd = data['ema']
            new_sd = {}
            for k in sd.keys():
                if k.startswith("ema_model."):
                    new_k = k[10:]  # remove the "ema_model." prefix
                    new_sd[new_k] = sd[k]
            sd = new_sd
            self.model.load_state_dict(sd)
        else:
            self.model.load_state_dict(data['model'])
        if 'scale_factor' in data['model']:
            self.model.scale_factor = data['model']['scale_factor']
        self.model.eval()
        self.device = "cpu"
    def to(self, device):
        self.model.to(device)
        self.device = device
        return self
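    # `image` is expected as an NCHW float tensor in [0, 1]; it is mapped to
    # [-1, 1] via normalize_to_neg_one_to_one before being used as conditioning.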
    def __call__(self, image, batch_size=8):
        image = normalize_to_neg_one_to_one(image).to(self.device)
        mask = None
        if self.cfg.sampler.sample_type == 'whole':
            return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
        elif self.cfg.sampler.sample_type == 'slide':
            return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
                                     stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)
        else:
            raise NotImplementedError(f'{self.cfg.sampler.sample_type} is not supported!')
    def whole_sample(self, inputs, raw_size, mask=None):
        # Sample on a fixed 416x416 grid, then resize the logits back to the original resolution.
        inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
        seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
        seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
        return seg_logits
    def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
        """Inference by sliding window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch is decoded
        without padding.

        Args:
            inputs (Tensor): Batch of images with shape NxCxHxW.
            crop_size (tuple[int, int]): Height and width of each window.
            stride (tuple[int, int]): Vertical and horizontal step between
                neighbouring windows.
            mask (Tensor, optional): Optional conditioning mask forwarded to
                the sampler.
            bs (int): Number of windows sampled per forward pass.

        Returns:
            Tensor: Segmentation (edge) logits for each input image.
        """
        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = 1
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out1 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out2 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        crop_imgs = []
        x1s = []
        x2s = []
        y1s = []
        y2s = []
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                crop_imgs.append(crop_img)
                x1s.append(x1)
                x2s.append(x2)
                y1s.append(y1)
                y2s.append(y2)
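        # Run the diffusion sampler over all collected windows in chunks of `bs`
        # to keep peak memory bounded.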
        crop_imgs = torch.cat(crop_imgs, dim=0)
        crop_seg_logits_list = []
        num_windows = crop_imgs.shape[0]
        num_chunks = math.ceil(num_windows / bs)
        for i in range(num_chunks):
            # Slicing past the end of the tensor is safe, so the last (possibly
            # smaller) chunk needs no special case.
            crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...]
            crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask)
            crop_seg_logits_list.append(crop_seg_logits)
        crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)
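        # Paste each window's logits back into the full-resolution canvas and
        # average overlapping regions via count_mat. Note: this per-window loop
        # appears to assume a single image per batch (N == 1), since each window
        # logit is broadcast across the batch dimension.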
        for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat
        return seg_logits
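# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the checkpoint path is a placeholder and
# this example is not part of the upstream module):
#
#   args = prepare_args("/path/to/diffusion_edge_checkpoint.pt", sampling_timesteps=1)
#   detector = DiffusionEdge(args).to("cuda" if torch.cuda.is_available() else "cpu")
#   # `image`: NCHW float tensor with values in [0, 1]
#   edges = detector(image, batch_size=8)
# ---------------------------------------------------------------------------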