# Author: 刘虹雨 — update (commit 8ed2f16)
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Unlike the official and most other implementations, this is an RGB-input model (rather than BGR),
# which makes it work better with gradio's RGB protocol.
import sys
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent.parent))
from torch import nn
import torch
import numpy as np
from torchvision import transforms as T
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import json
from PIL import Image
import torchvision.transforms.functional as TF
from accelerate import Accelerator
from diffusers.models import AutoencoderKL
import os
image_resize = 1024
class DoubleConvBlock(nn.Module):
    """A VGG-style stage: a stack of 3x3 convolutions (each followed by ReLU)
    plus a 1x1 projection that maps the stage's features to a single edge channel.

    Args:
        input_channel: channels of the incoming feature map.
        output_channel: channels produced by every conv in the stack.
        layer_number: number of 3x3 convolutions in the stack.
    """

    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        # First conv adapts input_channel -> output_channel; the rest keep the width.
        widths = [input_channel] + [output_channel] * layer_number
        self.convs = torch.nn.Sequential(*(
            torch.nn.Conv2d(in_channels=c_in, out_channels=c_out,
                            kernel_size=(3, 3), stride=(1, 1), padding=1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ))
        # 1x1 projection to a single-channel side output.
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1,
                                          kernel_size=(1, 1), stride=(1, 1), padding=0)

    def forward(self, x, down_sampling=False):
        """Return (features, side_output); optionally 2x max-pool the input first."""
        h = torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2)) if down_sampling else x
        for conv in self.convs:
            h = torch.nn.functional.relu(conv(h))
        return h, self.projection(h)
class ControlNetHED_Apache2(nn.Module):
    """HED edge-detection backbone: five VGG-style stages, each emitting a
    1-channel side projection at progressively smaller spatial resolution.

    Forward input is an RGB image tensor; output is a 5-tuple of side
    projections (stage 1 at full resolution, stages 2-5 each downsampled 2x
    relative to the previous stage).
    """

    def __init__(self):
        super().__init__()
        # Learned per-channel input shift, broadcast over batch and spatial dims.
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def forward(self, x):
        h = x - self.norm
        projections = []
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for i, stage in enumerate(stages):
            # Only stage 1 keeps full resolution; later stages pool by 2x first.
            h, proj = stage(h, down_sampling=(i > 0))
            projections.append(proj)
        return tuple(projections)
class InternData(Dataset):
    """Dataset of images listed in data/InternData/partition/data_info.json.

    Each item is (image_tensor, relative_path) where the image is converted to
    RGB, resized so the short side equals `image_resize`, center-cropped to a
    square, and converted to a [0, 1] tensor.
    """

    def __init__(self):
        ####
        with open('data/InternData/partition/data_info.json', 'r') as f:
            self.j = json.load(f)
        self.transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')),
            T.Resize(image_resize),  # Image.BICUBIC
            T.CenterCrop(image_resize),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.j)

    def getdata(self, idx):
        """Load and transform the image at index idx; may raise on bad files."""
        rel_path = self.j[idx]['path']
        tensor = self.transform(Image.open("data/InternImgs/" + rel_path))
        return tensor, rel_path

    def __getitem__(self, idx):
        # Best-effort loading: on failure, resample a random index up to 20 times
        # so one corrupt file does not kill a long extraction run.
        attempts_left = 20
        while attempts_left > 0:
            try:
                return self.getdata(idx)
            except Exception as e:
                print(f"Error details: {str(e)}")
                idx = np.random.randint(len(self))
                attempts_left -= 1
        raise RuntimeError('Too many bad data.')
class HEDdetector(nn.Module):
    """Frozen HED edge detector, optionally followed by a frozen VAE encoder.

    With `feature=True` (default), forward() returns the VAE latent statistics
    of the edge map as a numpy array; with `feature=False`, it returns the raw
    edge-map tensor in [0, 1].
    """

    def __init__(self, feature=True, vae=None):
        super().__init__()
        # Load the pretrained HED backbone and freeze it — this module is
        # inference-only (no gradients, eval mode).
        self.model = ControlNetHED_Apache2()
        self.model.load_state_dict(torch.load('output/pretrained_models/ControlNetHED.pth', map_location='cpu'))
        self.model.eval()
        self.model.requires_grad_(False)
        if feature:
            # Reuse a caller-supplied VAE if given, otherwise load the
            # sd-vae-ft-ema checkpoint from disk. Frozen either way.
            if vae is None:
                self.vae = AutoencoderKL.from_pretrained("output/pretrained_models/sd-vae-ft-ema")
            else:
                self.vae = vae
            self.vae.eval()
            self.vae.requires_grad_(False)
        else:
            self.vae = None

    def forward(self, input_image):
        # input_image: float tensor in [0, 1] — assumed (B, 3, H, W); the model
        # expects pixel values in [0, 255], hence the *255 below. TODO confirm.
        B, C, H, W = input_image.shape
        with torch.inference_mode():
            edges = self.model(input_image * 255.)
            # Upsample every side output back to the input resolution and stack
            # them along the channel dim.
            edges = torch.cat([TF.resize(e, [H, W]) for e in edges], dim=1)
            # Sigmoid of the mean side output: 1 / (1 + exp(-x)).
            edge = 1 / (1 + torch.exp(-torch.mean(edges, dim=1, keepdim=True)))
            edge.clip_(0, 1)
            if self.vae:
                # Map the edge map to [-1, 1], replicate to 3 channels to match
                # the VAE's expected input, and encode.
                edge = TF.normalize(edge, [.5], [.5])
                edge = edge.repeat(1, 3, 1, 1)
                posterior = self.vae.encode(edge).latent_dist
                # Concatenate posterior mean and std — this is the saved feature.
                # NOTE: returns a numpy array here, but a torch tensor when
                # self.vae is None (return type depends on `feature`).
                edge = torch.cat([posterior.mean, posterior.std], dim=1).cpu().numpy()
            return edge
def main():
    """Extract HED edge VAE features for every dataset image and save one
    compressed .npz file per image, skipping files that already exist."""
    dataset = InternData()
    dataloader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=8, pin_memory=True)
    hed = HEDdetector()
    accelerator = Accelerator()
    hed, dataloader = accelerator.prepare(hed, dataloader)
    for img, path in tqdm(dataloader):
        # Use the accelerator's device instead of a hard-coded .cuda() so the
        # script also runs on CPU-only or non-CUDA accelerator setups.
        out = hed(img.to(accelerator.device))
        for p, o in zip(path, out):
            # NOTE(review): output root is 'data/InternalData' while the dataset
            # reads from 'data/InternData' — one of the two looks like a typo;
            # confirm the intended directory layout before relying on it.
            save = f'data/InternalData/hed_feature_{image_resize}/' + p.replace('.png', '.npz')
            if os.path.exists(save):
                continue
            os.makedirs(os.path.dirname(save), exist_ok=True)
            np.savez_compressed(save, o)
if __name__ == "__main__":
    main()