Spaces:
Sleeping
Sleeping
File size: 5,154 Bytes
2eba94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
# Set before the remaining imports run.
# NOTE(review): presumably OpenCV reads this flag when it is first imported
# (likely via load_image) to enable .exr support — confirm.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
# Directory containing this file; used as the base of the model cache path.
current_directory = os.path.dirname(os.path.abspath(__file__))
def _fit_to_max_side(height, width, max_side=1000, multiple=8):
    """Compute the working resolution for an image of the given size.

    The longer side is set to ``max_side`` (smaller images are scaled up too),
    the other side follows the aspect ratio, and both sides are then rounded
    down to a multiple of ``multiple``.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        max_side: Target length of the longer side.
        multiple: Grid both sides are floored to.

    Returns:
        A ``(new_height, new_width)`` tuple of ints, each at least ``multiple``.
    """
    ratio = height / width
    if height > width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    # Floor to the grid; clamp so an extreme aspect ratio never yields a
    # zero-sized side.
    new_height = max(multiple, new_height // multiple * multiple)
    new_width = max(multiple, new_width // multiple * multiple)
    return new_height, new_width


def get_rgb2x_demo():
    """Build the Gradio Blocks UI for the RGB -> X demo.

    Loads the "zheng95z/rgb-to-x" pipeline in float16 on CUDA, replaces its
    scheduler with a DDIM scheduler, and wires a file upload plus
    seed / steps / samples sliders to a callback that produces one gallery
    entry per AOV and per sample.

    Returns:
        The constructed ``gr.Blocks`` instance (not launched).
    """
    # Load pipeline
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    def callback(
        photo,
        seed,
        inference_step,
        num_samples,
    ):
        """Run the pipeline on the uploaded image for every AOV.

        Args:
            photo: Value of the ``gr.File`` input (``None`` if nothing uploaded).
            seed: Seed for the CUDA random generator.
            inference_step: Number of inference steps per generation.
            num_samples: How many samples to generate per AOV.

        Returns:
            A list of ``(image, caption)`` tuples for the output gallery.

        Raises:
            gr.Error: If no file was uploaded or its extension is unsupported.
        """
        if photo is None:
            raise gr.Error("Please upload an image before pressing Run.")

        # Accept either an object exposing ``.name`` (what the original code
        # expected) or a plain path string.
        path = getattr(photo, "name", photo)
        extension = os.path.splitext(path)[1].lower()
        if extension == ".exr":
            photo = load_exr_image(path, tonemaping=True, clamp=True).to("cuda")
        elif extension in (".png", ".jpg", ".jpeg"):
            photo = load_ldr_image(path, from_srgb=True).to("cuda")
        else:
            raise gr.Error(
                f"Unsupported file type '{extension}'. Use .exr, .png, .jpg or .jpeg."
            )

        # Slider values may arrive as floats; the generator and range() need ints.
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
        inference_step = int(inference_step)
        num_samples = int(num_samples)

        # NOTE(review): assumes the loaders return a (C, H, W) tensor — confirm.
        old_height = photo.shape[1]
        old_width = photo.shape[2]

        # Resize (not crop) so the longer side is 1000 px and both sides are
        # multiples of 8 before running the pipeline.
        new_height, new_width = _fit_to_max_side(old_height, old_width)
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)

        return_list = []
        for i in range(num_samples):
            for aov_name in required_aovs:
                generated_image = pipe(
                    prompt=prompts[aov_name],
                    photo=photo,
                    num_inference_steps=inference_step,
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]
                # Bring the result back to the resolution of the uploaded image.
                generated_image = torchvision.transforms.Resize(
                    (old_height, old_width)
                )(generated_image)
                return_list.append((generated_image, f"Generated {aov_name} {i}"))
        return return_list

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                # ".jpeg" is listed so the picker matches what callback accepts.
                photo = gr.File(
                    label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"]
                )
                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    num_samples = gr.Slider(
                        label="Samples",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=1,
                    )
            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_gallery = gr.Gallery(
                    label="Output",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                )
        inputs = [
            photo,
            seed,
            inference_step,
            num_samples,
        ]
        run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)
    return block
if __name__ == "__main__":
    # Build the UI, cap the request queue at max_size=1, then start the server.
    demo = get_rgb2x_demo()
    demo.queue(max_size=1)
    demo.launch()