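"""Inference wrapper for the DiffusionEdge latent-diffusion edge detector.

Builds the first-stage AutoencoderKL and the conditional U-Net from a YAML
config, loads pretrained (optionally EMA) weights, and exposes whole-image
and sliding-window sampling through a simple callable interface.
"""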
import numpy as np
import yaml
import argparse
import math
import torch
import torch.nn.functional as F
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import *
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL
# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet
from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import *
from fvcore.common.config import CfgNode
from pathlib import Path


def load_conf(config_file, conf=None):
    # Use None instead of a mutable default argument, which would leak
    # merged keys across calls.
    if conf is None:
        conf = {}
    with open(config_file) as f:
        exp_conf = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in exp_conf.items():
        conf[k] = v
    return conf


def prepare_args(ckpt_path, sampling_timesteps=1):
    return argparse.Namespace(
        cfg=load_conf(Path(__file__).parent / "default.yaml"),
        pre_weight=ckpt_path,
        sampling_timesteps=sampling_timesteps,
    )


class DiffusionEdge:
    def __init__(self, args) -> None:
        self.cfg = CfgNode(args.cfg)
        torch.manual_seed(42)
        np.random.seed(42)
        model_cfg = self.cfg.model
        first_stage_cfg = model_cfg.first_stage
        first_stage_model = AutoencoderKL(
            ddconfig=first_stage_cfg.ddconfig,
            lossconfig=first_stage_cfg.lossconfig,
            embed_dim=first_stage_cfg.embed_dim,
            ckpt_path=first_stage_cfg.ckpt_path,
        )
        if model_cfg.model_name == 'cond_unet':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet
            unet_cfg = model_cfg.unet
            unet = Unet(
                dim=unet_cfg.dim,
                channels=unet_cfg.channels,
                dim_mults=unet_cfg.dim_mults,
                learned_variance=unet_cfg.get('learned_variance', False),
                out_mul=unet_cfg.out_mul,
                cond_in_dim=unet_cfg.cond_in_dim,
                cond_dim=unet_cfg.cond_dim,
                cond_dim_mults=unet_cfg.cond_dim_mults,
                window_sizes1=unet_cfg.window_sizes1,
                window_sizes2=unet_cfg.window_sizes2,
                fourier_scale=unet_cfg.fourier_scale,
                cfg=unet_cfg,
            )
        else:
            raise NotImplementedError
        if model_cfg.model_type == 'const_sde':
            from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion
        else:
            raise NotImplementedError(f'{model_cfg.model_type} is not supported!')
        self.model = LatentDiffusion(
            model=unet,
            auto_encoder=first_stage_model,
            train_sample=model_cfg.train_sample,
            image_size=model_cfg.image_size,
            timesteps=model_cfg.timesteps,
            sampling_timesteps=args.sampling_timesteps,
            loss_type=model_cfg.loss_type,
            objective=model_cfg.objective,
            scale_factor=model_cfg.scale_factor,
            scale_by_std=model_cfg.scale_by_std,
            scale_by_softsign=model_cfg.scale_by_softsign,
            default_scale=model_cfg.get('default_scale', False),
            input_keys=model_cfg.input_keys,
            ckpt_path=model_cfg.ckpt_path,
            ignore_keys=model_cfg.ignore_keys,
            only_model=model_cfg.only_model,
            start_dist=model_cfg.start_dist,
            perceptual_weight=model_cfg.perceptual_weight,
            use_l1=model_cfg.get('use_l1', True),
            cfg=model_cfg,
        )
        self.cfg.sampler.ckpt_path = args.pre_weight
        data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu")
        if self.cfg.sampler.use_ema:
            # EMA weights are stored under an "ema_model." prefix; strip it
            # before loading.
            sd = data['ema']
            new_sd = {}
            for k in sd.keys():
                if k.startswith("ema_model."):
                    new_k = k[len("ema_model."):]
                    new_sd[new_k] = sd[k]
            sd = new_sd
            self.model.load_state_dict(sd)
        else:
            self.model.load_state_dict(data['model'])
        if 'scale_factor' in data['model']:
            self.model.scale_factor = data['model']['scale_factor']
        self.model.eval()
        self.device = "cpu"

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, image, batch_size=8):
        image = normalize_to_neg_one_to_one(image).to(self.device)
        mask = None
        if self.cfg.sampler.sample_type == 'whole':
            return self.whole_sample(image, raw_size=image.shape[2:], mask=mask)
        elif self.cfg.sampler.sample_type == 'slide':
            return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]),
                                     stride=self.cfg.sampler.stride, mask=mask, bs=batch_size)
        else:
            raise NotImplementedError(f'{self.cfg.sampler.sample_type} is not supported!')

    def whole_sample(self, inputs, raw_size, mask=None):
        # Sample at a fixed 416x416 resolution, then resize the logits back
        # to the original image size.
        inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True)
        seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask)
        seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True)
        return seg_logits

    def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8):
        """Inference by sliding window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch is decoded
        without padding.

        Args:
            inputs (Tensor): tensor of shape NxCxHxW containing all images
                in the batch.
            crop_size (tuple[int, int]): (height, width) of each window.
            stride (tuple[int, int]): (height, width) step between windows.
            mask (Tensor, optional): conditioning mask passed through to the
                sampler.
            bs (int): number of windows sampled per forward pass.

        Returns:
            Tensor: the seg_logits from the model for each input image,
            averaged over overlapping windows.
        """
        h_stride, w_stride = stride
        h_crop, w_crop = crop_size
        batch_size, _, h_img, w_img = inputs.size()
        out_channels = 1
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out1 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        # aux_out2 = inputs.new_zeros((batch_size, out_channels, h_img, w_img))
        count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img))
        crop_imgs = []
        x1s, x2s, y1s, y2s = [], [], [], []
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back inside the image so every crop keeps
                # the full crop size.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = inputs[:, :, y1:y2, x1:x2]
                crop_imgs.append(crop_img)
                x1s.append(x1)
                x2s.append(x2)
                y1s.append(y1)
                y2s.append(y2)
        crop_imgs = torch.cat(crop_imgs, dim=0)
        # Sample the windows in chunks of at most `bs` to bound memory use;
        # the last chunk may be smaller (slicing past the end is safe).
        crop_seg_logits_list = []
        num_windows = crop_imgs.shape[0]
        for i in range(math.ceil(num_windows / bs)):
            crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...]
            crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask)
            crop_seg_logits_list.append(crop_seg_logits)
        crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0)
        # Paste each window's logits back at its location, then average the
        # overlapping regions.
        for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s):
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        seg_logits = preds / count_mat
        return seg_logits
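

# A minimal usage sketch (assumptions: the checkpoint path below is a
# placeholder, and the input is a float NCHW tensor in [0, 1], inferred from
# the normalize_to_neg_one_to_one() call above).
if __name__ == "__main__":
    ckpt = "weights/diffusion_edge.pt"  # hypothetical checkpoint location
    args = prepare_args(ckpt, sampling_timesteps=1)
    detector = DiffusionEdge(args).to("cuda" if torch.cuda.is_available() else "cpu")
    image = torch.rand(1, 3, 512, 512)  # dummy RGB image in [0, 1]
    with torch.no_grad():
        edge_logits = detector(image, batch_size=8)
    print(edge_logits.shape)  # one edge map per input image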