# File size: 5,154 Bytes
# 2eba94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os

# Must be set BEFORE OpenCV is loaded (presumably imported transitively via
# load_image) so that cv2 is allowed to read .exr files — TODO confirm the
# import chain.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

# Absolute directory of this file; used as the root for the HF model cache.
current_directory = os.path.dirname(os.path.abspath(__file__))


def get_rgb2x_demo():
    """Build the Gradio demo for the RGB -> X intrinsic-channel model.

    Loads the pretrained ``StableDiffusionAOVMatEstPipeline`` once onto CUDA,
    then wires a Blocks UI in which an uploaded photo (.exr/.png/.jpg/.jpeg)
    is decomposed into albedo, normal, roughness, metallic and irradiance
    images.

    Returns:
        gr.Blocks: the assembled (unlaunched) Gradio demo.
    """
    # Load the pipeline once at demo-construction time; fp16 weights on CUDA.
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda")
    # Swap in DDIM with zero-SNR beta rescaling and trailing timestep spacing.
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)
    # NOTE: the original called pipe.to("cuda") a second time here; the
    # pipeline is already on CUDA (see from_pretrained above), so it was a
    # no-op and has been removed.

    def callback(
        photo,
        seed,
        inference_step,
        num_samples,
    ):
        """Run the pipeline on one uploaded photo and return gallery tuples.

        Args:
            photo: gradio File object; ``photo.name`` is the temp-file path.
            seed: RNG seed (gradio sliders may deliver floats).
            inference_step: number of diffusion steps per AOV.
            num_samples: how many full sets of AOVs to generate.

        Returns:
            list[tuple]: ``(image, caption)`` pairs for the output gallery.

        Raises:
            gr.Error: if the uploaded file has an unsupported extension.
        """
        # manual_seed requires an int; sliders can pass floats — cast defensively.
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        path = photo.name
        if path.endswith(".exr"):
            photo = load_exr_image(path, tonemaping=True, clamp=True).to("cuda")
        elif path.endswith((".png", ".jpg", ".jpeg")):
            photo = load_ldr_image(path, from_srgb=True).to("cuda")
        else:
            # Previously an unrecognized extension fell through silently,
            # leaving `photo` as a File object and crashing below with an
            # opaque AttributeError. Fail fast with a user-visible message.
            raise gr.Error(f"Unsupported file type: {path}")

        # Resize so the longest side is max_side px (aspect ratio preserved),
        # then floor both dims to multiples of 8 as the diffusion model requires.
        old_height = photo.shape[1]
        old_width = photo.shape[2]
        ratio = old_height / old_width  # fixed local-name typo: was `radio`
        max_side = 1000
        if old_height > old_width:
            new_height = max_side
            new_width = int(new_height / ratio)
        else:
            new_width = max_side
            new_height = int(new_width * ratio)

        # Flooring is a no-op when already aligned, so no guard is needed.
        new_width = new_width // 8 * 8
        new_height = new_height // 8 * 8

        photo = torchvision.transforms.Resize((new_height, new_width))(photo)

        required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
        prompts = {
            "albedo": "Albedo (diffuse basecolor)",
            "normal": "Camera-space Normal",
            "roughness": "Roughness",
            "metallic": "Metallicness",
            "irradiance": "Irradiance (diffuse lighting)",
        }

        return_list = []
        for i in range(num_samples):
            for aov_name in required_aovs:
                prompt = prompts[aov_name]
                generated_image = pipe(
                    prompt=prompt,
                    photo=photo,
                    num_inference_steps=int(inference_step),
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]

                # Scale the output back to the photo's original resolution
                # so the gallery shows full-size results.
                generated_image = torchvision.transforms.Resize(
                    (old_height, old_width)
                )(generated_image)

                return_list.append((generated_image, f"Generated {aov_name} {i}"))

        return return_list

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])

                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    num_samples = gr.Slider(
                        label="Samples",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=1,
                    )

            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_gallery = gr.Gallery(
                    label="Output",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                )

        inputs = [
            photo,
            seed,
            inference_step,
            num_samples,
        ]
        run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)

    return block


if __name__ == "__main__":
    # Build the demo and serve it. max_size=1 bounds the request queue —
    # presumably to avoid piling up jobs on a single GPU pipeline; confirm
    # against deployment needs.
    demo = get_rgb2x_demo()
    demo.queue(max_size=1)
    demo.launch()