# matgen/rgb2x/gradio_demo_rgb2x.py
# Author: jingyangcarl
# Commit message: update code
# Commit: 2eba94f
import os
# Enable OpenCV's OpenEXR codec via its environment flag.
# NOTE(review): this is set before the remaining imports, presumably so that
# OpenCV (likely imported indirectly by load_image) sees it at import time.
# Keep this line above the imports below; confirm against load_image.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
# Absolute directory containing this script; joined with "model_cache" to
# form the cache_dir passed to from_pretrained.
current_directory = os.path.dirname(os.path.abspath(__file__))
# Lower-case file extensions that are read through ``load_ldr_image``.
_LDR_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _compute_inference_size(height, width, max_side=1000, multiple=8):
    """Return the ``(height, width)`` the photo is resized to for inference.

    The longer side is set to ``max_side`` and the shorter side follows the
    original aspect ratio (small images are therefore scaled up). Both sides
    are then rounded down to a multiple of ``multiple``.
    NOTE(review): the multiple-of-8 constraint is presumably required by the
    diffusion pipeline -- confirm against pipeline_rgb2x.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        max_side: Target length of the longer side.
        multiple: Both returned sides are multiples of this value.

    Returns:
        Tuple ``(new_height, new_width)``; each side is at least ``multiple``.
    """
    aspect = height / width
    if height > width:
        new_height = max_side
        new_width = int(new_height / aspect)
    else:
        new_width = max_side
        new_height = int(new_width * aspect)
    # Floor to the required multiple, but never let an extremely elongated
    # image collapse to a zero-sized dimension.
    new_height = max(multiple, new_height // multiple * multiple)
    new_width = max(multiple, new_width // multiple * multiple)
    return new_height, new_width


def _load_photo(photo):
    """Load the uploaded file onto the GPU, choosing the loader by extension.

    Args:
        photo: Value delivered by the ``gr.File`` input: an object exposing
            the file path as ``.name``, a plain path string, or ``None`` when
            nothing was uploaded.

    Returns:
        The result of ``load_exr_image`` / ``load_ldr_image`` moved to CUDA.

    Raises:
        gr.Error: If no file was uploaded or the extension is unsupported.
    """
    if photo is None:
        raise gr.Error("Please upload a photo before pressing Run.")
    # Accept both file wrappers (``.name``) and plain path strings.
    path = getattr(photo, "name", photo)
    # Compare case-insensitively so e.g. "IMG.JPG" or "scene.EXR" is accepted.
    lowered = str(path).lower()
    if lowered.endswith(".exr"):
        return load_exr_image(path, tonemaping=True, clamp=True).to("cuda")
    if lowered.endswith(_LDR_EXTENSIONS):
        return load_ldr_image(path, from_srgb=True).to("cuda")
    raise gr.Error(
        "Unsupported file type. Please upload an .exr, .png, .jpg or .jpeg image."
    )


def get_rgb2x_demo():
    """Build the Gradio UI for the RGB -> X (intrinsic channel) demo.

    Loads the "zheng95z/rgb-to-x" pipeline in float16 onto CUDA (weights are
    cached in ``model_cache`` next to this script), swaps in a DDIM scheduler,
    and wires a "Run" button to a callback that generates one image per
    intrinsic channel per requested sample.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Load pipeline
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)

    # Channels to estimate, in output order, and the text prompt used for each.
    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    def callback(
        photo,
        seed,
        inference_step,
        num_samples,
    ):
        """Run the pipeline on the uploaded photo.

        Args:
            photo: Uploaded file from the ``gr.File`` input.
            seed: Seed for the CUDA random generator.
            inference_step: Number of denoising steps per generated image.
            num_samples: How many times every channel is generated.

        Returns:
            List of ``(image, caption)`` tuples for the output gallery,
            ``num_samples * len(required_aovs)`` entries long.

        Raises:
            gr.Error: If no file was uploaded or its type is unsupported.
        """
        # Slider values may arrive as floats; these APIs need real ints.
        seed = int(seed)
        inference_step = int(inference_step)
        num_samples = int(num_samples)

        generator = torch.Generator(device="cuda").manual_seed(seed)
        photo = _load_photo(photo)

        # Resize (not crop) to the working resolution; the outputs are
        # resized back to the original resolution afterwards.
        old_height = photo.shape[1]
        old_width = photo.shape[2]
        new_height, new_width = _compute_inference_size(old_height, old_width)
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)
        restore_size = torchvision.transforms.Resize((old_height, old_width))

        return_list = []
        for i in range(num_samples):
            for aov_name in required_aovs:
                generated_image = pipe(
                    prompt=prompts[aov_name],
                    photo=photo,
                    num_inference_steps=inference_step,
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]
                generated_image = restore_size(generated_image)
                return_list.append((generated_image, f"Generated {aov_name} {i}"))
        return return_list

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                # ".jpeg" is listed because the callback accepts it as well.
                photo = gr.File(
                    label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"]
                )
                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    num_samples = gr.Slider(
                        label="Samples",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=1,
                    )
            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_gallery = gr.Gallery(
                    label="Output",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                )
        inputs = [
            photo,
            seed,
            inference_step,
            num_samples,
        ]
        run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
    return block
# Script entry point: build the demo and serve it.
if __name__ == "__main__":
    demo = get_rgb2x_demo()
    # Enable the request queue with a maximum size of 1, then start the server.
    demo.queue(max_size=1).launch()