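"""Gradio demo: arbitrary style transfer with EFDM and related feature-matching methods."""
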
import gradio as gr
import toml
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

import net
from function import (
    adaptive_instance_normalization,
    adaptive_mean_normalization,
    adaptive_std_normalization,
    exact_feature_distribution_matching,
    histogram_matching,
)

cfg = toml.load("config.toml")  # static configuration values
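
# config.toml is expected to define at least the keys read below. A minimal
# example (the weight paths are placeholders for your own checkpoint files):
#
#   use_cuda = true
#   decoder_weight = "models/decoder.pth"
#   vgg_weight = "models/vgg_normalised.pth"
#   content_size = 512
#   style_size = 512
#   crop = false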

# Setup device
if torch.cuda.is_available() and cfg["use_cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load pretrained models
decoder = net.decoder
vgg = net.vgg

decoder.eval()
vgg.eval()

decoder.load_state_dict(torch.load(cfg["decoder_weight"], map_location=device))
vgg.load_state_dict(torch.load(cfg["vgg_weight"], map_location=device))
vgg = nn.Sequential(*list(vgg.children())[:31])  # keep encoder layers up to relu4_1

vgg = vgg.to(device)
decoder = decoder.to(device)


def transform(img, size, crop):
    """Convert a PIL image to a [0, 1] tensor, optionally resizing and center-cropping."""
    transform_list = []
    if size > 0:
        transform_list.append(transforms.Resize(size))
        if crop:
            transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)(img)


@torch.inference_mode()
def style_transfer(content, style, style_type, alpha, keep_resolution):
    """Stylize function"""
    style_type = style_type.lower()

    # Step 1: convert images to PyTorch tensors
    if keep_resolution:
        # Match the style image to the content resolution (LANCZOS resampling)
        style = style.resize(content.size, Image.LANCZOS)

    if style_type == "efdm" and not keep_resolution:
        content = transform(content, cfg["content_size"], cfg["crop"])
        style = transform(style, cfg["style_size"], cfg["crop"])
    else:
        content = transform(content, -1, False)
        style = transform(style, -1, False)

    content = content.to(device).unsqueeze(0)
    style = style.to(device).unsqueeze(0)

    # Step 2: extract content feature and style feature
    content_feat = vgg(content)
    style_feat = vgg(style)

    # Step 3: perform style transfer
    transfer = {
        "adain": adaptive_instance_normalization,
        "adamean": adaptive_mean_normalization,
        "adastd": adaptive_std_normalization,
        "efdm": exact_feature_distribution_matching,
        "hm": histogram_matching,
    }[style_type]
    feat = transfer(content_feat, style_feat)

    # Step 4: content-style trade-off
    # (alpha = 1.0 keeps the fully stylized features; alpha = 0.0 reproduces the content)
    feat = feat * alpha + content_feat * (1 - alpha)

    # Step 5: decode to image
    output = decoder(feat).cpu().squeeze(0).clamp_(0, 1)
    output = transforms.ToPILImage()(output)

    # Release cached GPU memory so repeated requests don't accumulate allocations
    if torch.cuda.is_available():
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

    return output
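

# For reference, a minimal sketch of what `exact_feature_distribution_matching`
# in function.py presumably computes (hypothetical helper, not called by the app;
# assumes content and style feature maps have the same shape): per channel, the
# sorted content activations are replaced by the sorted style activations, so the
# empirical feature distributions match exactly.
def _efdm_sketch(content_feat, style_feat):
    b, c, h, w = content_feat.size()
    content_flat = content_feat.view(b, c, -1)
    style_flat = style_feat.view(b, c, -1)
    _, content_index = content_flat.sort(dim=-1)   # ranks of the content values
    style_value, _ = style_flat.sort(dim=-1)       # style values in ascending order
    inverse_index = content_index.argsort(dim=-1)  # sorted position of each element
    # Straight-through estimator: the forward pass outputs the style values,
    # while gradients (during training) flow through the content features
    matched = content_flat + (style_value.gather(-1, inverse_index) - content_flat).detach()
    return matched.view(b, c, h, w)


# AdaIN-style matching, for comparison, aligns only channel-wise mean and std
# (again a hypothetical sketch, not the app's imported implementation):
def _adain_sketch(content_feat, style_feat, eps=1e-5):
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    return (content_feat - c_mean) / c_std * s_std + s_mean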


# Add image examples
example_img_pairs = {
    "examples/content/sailboat.jpg": "examples/style/sketch.png",
    "examples/content/granatum.jpg": "examples/style/flowers_in_a_turquoise_vase.jpg",
    "examples/content/einstein.jpeg": "examples/style/polasticot2.jpeg",
    "examples/content/paris.jpeg": "examples/style/vangogh.jpeg",
    "examples/content/cornell.jpg": "examples/style/asheville.jpg",
}

# Customize interface
title = "Style Transfer with EFDM"
description = """
Gradio demo for neural style transfer using exact feature distribution matching
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2203.07740'>Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization</a></p>"
content_input = gr.Image(label="Content Image", type="pil")
style_input = gr.Image(label="Style Image", type="pil")
style_type = gr.Radio(
    ["EFDM", "AdaIN", "AdaMean", "AdaStd", "HM"], value="EFDM", label="Method"
)
alpha_selector = gr.Slider(
    minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Content-Style trade-off"
)
keep_resolution = gr.Checkbox(
    value=True, label="Keep content image resolution"
)

iface = gr.Interface(
    fn=style_transfer,
    inputs=[content_input, style_input, style_type, alpha_selector, keep_resolution],
    outputs=["image"],
    title=title,
    description=description,
    article=article,
    theme="huggingface",
    examples=[
        [content, style, "EFDM", 1.0, True]
        for content, style in example_img_pairs.items()
    ],
)
iface.queue().launch(debug=False)