# EFDM / app.py
import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
import net
from function import (
    adaptive_instance_normalization,
    adaptive_mean_normalization,
    adaptive_std_normalization,
    exact_feature_distribution_matching,
    histogram_matching,
)
cfg = toml.load("config.toml")  # static configuration values
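# config.toml is assumed to provide at least the keys used below
# (illustrative values, not the repository's actual settings):
#
#   use_cuda = true
#   decoder_weight = "models/decoder.pth"
#   vgg_weight = "models/vgg_normalised.pth"
#   content_size = 512
#   style_size = 512
#   crop = true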
# Setup device
if torch.cuda.is_available() and cfg["use_cuda"]:
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Load pretrained models
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))
# Keep only the VGG layers up to relu4_1, the feature space in which the
# style statistics are matched
vgg = nn.Sequential(*list(vgg.children())[:31])
vgg = vgg.to(device)
decoder = decoder.to(device)
def transform(img, size, crop):
    """Convert a PIL image to a float tensor, optionally resizing its short
    side to `size` and center-cropping to `size` x `size`."""
    transform_list = []
if size > 0:
transform_list.append(transforms.Resize(size))
if crop:
transform_list.append(transforms.CenterCrop(size))
transform_list.append(transforms.ToTensor())
transform = transforms.Compose(transform_list)
return transform(img)
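# Example usage (hypothetical sizes): transform(img, 512, True) resizes the
# short side to 512 and center-crops to 512x512; transform(img, -1, False)
# only converts the image to a tensor at its original resolution.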
@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
"""Stylize function"""
style_type = style_type.lower()
# Step 1: convert image to PyTorch Tensor
if keep_resolution:
        style = style.resize(content.size, Image.LANCZOS)
if style_type == "efdm" and not keep_resolution:
content = transform(content, cfg["content_size"], cfg["crop"])
style = transform(style, cfg["style_size"], cfg["crop"])
else:
content = transform(content, -1, False)
style = transform(style, -1, False)
content = content.to(device).unsqueeze(0)
style = style.to(device).unsqueeze(0)
# Step 2: extract content feature and style feature
content_feat = vgg(content)
style_feat = vgg(style)
    # Step 3: match the content features to the style feature distribution
    # (illustrative sketches of two of these transforms follow this function)
transfer = {
"adain": adaptive_instance_normalization,
"adamean": adaptive_mean_normalization,
"adastd": adaptive_std_normalization,
"efdm": exact_feature_distribution_matching,
"hm": histogram_matching,
}[style_type]
feat = transfer(content_feat, style_feat)
# Step 4: content-style trade-off
feat = feat * alpha + content_feat * (1 - alpha)
# Step 5: decode to image
output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
output = transforms.ToPILImage()(output)
if torch.cuda.is_available():
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
return output
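
# Reference sketches of two of the dispatched transforms, included for
# illustration only; the demo always calls the implementations imported from
# `function` above. These are minimal versions written under the assumption
# that `function.py` follows the formulations in the EFDM paper
# (moment matching for AdaIN, sort-based matching for EFDM).
def _adain_sketch(content_feat, style_feat, eps=1e-5):
    """Match the per-channel mean and std of the content features to the
    style features' statistics."""
    b, c = content_feat.shape[:2]
    c_flat = content_feat.view(b, c, -1)
    s_flat = style_feat.view(b, c, -1)
    c_mean, c_std = c_flat.mean(-1, keepdim=True), c_flat.std(-1, keepdim=True) + eps
    s_mean, s_std = s_flat.mean(-1, keepdim=True), s_flat.std(-1, keepdim=True)
    return ((c_flat - c_mean) / c_std * s_std + s_mean).view_as(content_feat)


def _efdm_sketch(content_feat, style_feat):
    """Exact feature distribution matching via sorting: each content value is
    replaced by the style value of the same rank, transferring the full
    per-channel empirical distribution rather than just mean and std. Assumes
    both feature maps have the same number of spatial positions; the gradient
    pass-through used for training in the paper is omitted here."""
    b, c = content_feat.shape[:2]
    c_flat = content_feat.view(b, c, -1)
    s_flat = style_feat.view(b, c, -1)
    c_order = c_flat.argsort(dim=-1)   # ranks of the content activations
    s_sorted, _ = s_flat.sort(dim=-1)  # style activations in rank order
    out = torch.empty_like(c_flat)
    # The slot holding the k-th smallest content value receives the
    # k-th smallest style value.
    out.scatter_(-1, c_order, s_sorted)
    return out.view_as(content_feat)
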
# Add image examples
example_img_pairs = {
"examples/content/sailboat.jpg": "examples/style/sketch.png",
"examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
"examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
"examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
"examples/content/cornell.jpg": "examples/style/asheville.jpg",
}
# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
content_input = gr.Image(label="Content Image", type="pil")
style_input = gr.Image(label="Style Image", type="pil")
style_type = gr.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], value="EFDM", label="Method"
)
alpha_selector = gr.Slider(
    minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.Checkbox(value=True, label="Keep content image resolution")
iface = gr.Interface(
fn=style_transfer,
inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
outputs=["image"],
title=title,
description=description,
article=article,
theme="huggingface",
examples=[
[content, style, "EFDM", 1.0, True]
for content, style in example_img_pairs.items()
],
)
iface.queue().launch(debug=False)