import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
import net
from function import *
cfg = toml.load("config.toml") # static variables
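# config.toml is assumed to provide at least the keys used below; the values
# shown here are illustrative placeholders, not the repo's actual settings:
#   use_cuda = true
#   decoder_weight = "decoder.pth"
#   vgg_weight = "vgg_normalised.pth"
#   content_size = 512
#   style_size = 512
#   crop = true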
# Setup device
if torch.cuda.is_available() and cfg["use_cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load pretrained models
decoder = net.decoder
vgg = net.vgg
decoder.eval()
vgg.eval()
# map_location lets the weights load on CPU-only machines as well
decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))
vgg = nn.Sequential(*list(vgg.children())[:31])  # keep encoder layers up to relu4_1
vgg = vgg.to(device)
decoder = decoder.to(device)
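# Illustrative sanity check: with the encoder truncated at relu4_1, a
# 1x3x256x256 input yields a 1x512x32x32 feature map (spatial size / 8):
#   feat = vgg(torch.zeros(1, 3, 256, 256, device=device))
#   assert feat.shape == (1, 512, 32, 32)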
def transform(img, size, crop):
    """Resize (and optionally center-crop) img, then convert it to a tensor."""
    transform_list = []
    if size > 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform(img)
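# e.g. transform(img, 512, True) resizes the shorter side to 512 px and
# center-crops to 512x512, while transform(img, -1, False) only converts
# the image to a tensor, keeping its original resolution.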
@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize a content image with the statistics of a style image."""
    style_type = style_type.lower()
    # Step 1: convert images to PyTorch tensors
    if keep_resolution:
        # Match the style image to the content resolution
        # (Image.LANCZOS; the ANTIALIAS alias was removed in Pillow 10)
        style = style.resize(content.size, Image.LANCZOS)
    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)
    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)
    # Step 2: extract content and style features with the VGG encoder
    content_feat = vgg(content)
    style_feat = vgg(style)
    # Step 3: perform style transfer in feature space
    transfer = {
        "adain": adaptive_instance_normalization,  # match channel-wise mean and std
        "adamean": adaptive_mean_normalization,  # match channel-wise mean only
        "adastd": adaptive_std_normalization,  # match channel-wise std only
        "efdm": exact_feature_distribution_matching,  # match the full distribution via sorting
        "hm": histogram_matching,  # classic per-channel histogram matching
    }[style_type]
    feat = transfer(content_feat, style_feat)
    # Step 4: content-style trade-off (alpha=1 fully stylized, alpha=0 pure content)
    feat = feat * alpha + content_feat * (1 - alpha)
    # Step 5: decode the feature map back to an image
    output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)
    # Release cached GPU memory between requests
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
    return output
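# Example direct call outside the Gradio UI (file names are illustrative):
#   out = style_transfer(Image.open("content.jpg").convert("RGB"),
#                        Image.open("style.jpg").convert("RGB"),
#                        "EFDM", 1.0, True)
#   out.save("stylized.jpg")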
# Add image examples
example_img_pairs = {
    "examples/content/sailboat.jpg": "examples/style/sketch.png",
    "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
    "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
    "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
    "examples/content/cornell.jpg": "examples/style/asheville.jpg",
}
# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer via exact feature distribution matching (EFDM), with AdaIN, AdaMean, AdaStd, and histogram matching (HM) as alternative methods.
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
content_input = gr.Image(label="Content Image", sources=["upload"], type="pil")
style_input = gr.Image(label="Style Image", sources=["upload"], type="pil")
style_type = gr.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], value="EFDM", label="Method"
)
alpha_selector = gr.Slider(
    minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.Checkbox(value=True, label="Keep content image resolution")
iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    examples=[
        [content, style, "EFDM", 1.0, True]
        for content, style in example_img_pairs.items()
    ],
)
iface.queue().launch(debug=False)