import os

# Must be set before OpenCV is imported so that .exr inputs can be decoded.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler

from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

current_directory = os.path.dirname(os.path.abspath(__file__))
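

# Builds the Gradio demo for rgb2x: given a single photograph, the diffusion
# pipeline below estimates per-pixel intrinsic channels (AOVs): albedo,
# normal, roughness, metallic, and irradiance.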
def get_rgb2x_demo():
    # Load the pretrained RGB -> X material-estimation pipeline in fp16.
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda")
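    # DDIM with zero-terminal-SNR rescaling and "trailing" timestep spacing,
    # the sampling fixes from "Common Diffusion Noise Schedules and Sample
    # Steps are Flawed" (Lin et al., 2023).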
    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)

    # Inference callback: runs the diffusion pipeline once per requested AOV
    # and per sample, and returns (image, caption) pairs for the gallery.
    def callback(
        photo,
        seed,
        inference_step,
        num_samples,
    ):
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Accept HDR (.exr) or LDR (.png/.jpg/.jpeg) inputs.
        if photo.name.endswith(".exr"):
            # NB: `tonemaping` (sic) matches the loader's keyword argument.
            photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
        elif photo.name.endswith((".png", ".jpg", ".jpeg")):
            photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
        else:
            raise gr.Error("Unsupported file type; use .exr, .png, .jpg, or .jpeg.")
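
        # Both loaders return CHW tensors here (shape[1] = height, shape[2] = width).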
        # Resize so the longest side is max_side, then round both sides down
        # to multiples of 8, as required by the pipeline's latent-space backbone.
        old_height = photo.shape[1]
        old_width = photo.shape[2]
        ratio = old_height / old_width
        max_side = 1000
        if old_height > old_width:
            new_height = max_side
            new_width = int(new_height / ratio)
        else:
            new_width = max_side
            new_height = int(new_width * ratio)
        if new_width % 8 != 0 or new_height % 8 != 0:
            new_width = new_width // 8 * 8
            new_height = new_height // 8 * 8
        photo = torchvision.transforms.Resize((new_height, new_width))(photo)
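
        # Note: the branch above always maps the longest side to max_side, so
        # inputs smaller than 1000 px get upscaled as well.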
        required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
        prompts = {
            "albedo": "Albedo (diffuse basecolor)",
            "normal": "Camera-space Normal",
            "roughness": "Roughness",
            "metallic": "Metallicness",
            "irradiance": "Irradiance (diffuse lighting)",
        }
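
        # Each AOV is selected via its text prompt; these strings presumably
        # match the phrasing the model was conditioned on during training.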
        return_list = []
        for i in range(num_samples):
            for aov_name in required_aovs:
                prompt = prompts[aov_name]
                generated_image = pipe(
                    prompt=prompt,
                    photo=photo,
                    num_inference_steps=inference_step,
                    height=new_height,
                    width=new_width,
                    generator=generator,
                    required_aovs=[aov_name],
                ).images[0][0]
                # Resize back to the input resolution and pair each image
                # with a caption for the gallery.
                generated_image = torchvision.transforms.Resize(
                    (old_height, old_width)
                )(generated_image)
                generated_image = (generated_image, f"Generated {aov_name} {i}")
                return_list.append(generated_image)
        return return_list
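
    # For reference, a minimal sketch of one direct pipeline call without the
    # Gradio wrapper; the file name and 512x512 size are illustrative
    # assumptions, mirroring the callback above:
    #
    #   photo = load_ldr_image("input.png", from_srgb=True).to("cuda")
    #   albedo = pipe(
    #       prompt="Albedo (diffuse basecolor)",
    #       photo=photo,
    #       num_inference_steps=50,
    #       height=512,
    #       width=512,  # both sides must be multiples of 8
    #       generator=torch.Generator(device="cuda").manual_seed(0),
    #       required_aovs=["albedo"],
    #   ).images[0][0]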

    block = gr.Blocks()
    with block:
        with gr.Row():
            gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
        with gr.Row():
            # Input side
            with gr.Column():
                gr.Markdown("### Given Image")
                photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg", ".jpeg"])
                gr.Markdown("### Parameters")
                run_button = gr.Button(value="Run")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                    inference_step = gr.Slider(
                        label="Inference Step",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=50,
                    )
                    num_samples = gr.Slider(
                        label="Samples",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=1,
                    )

            # Output side
            with gr.Column():
                gr.Markdown("### Output Gallery")
                result_gallery = gr.Gallery(
                    label="Output",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                )

        inputs = [
            photo,
            seed,
            inference_step,
            num_samples,
        ]
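        # queue=True routes the event through Gradio's queue; together with
        # demo.queue(max_size=1) below, it keeps GPU-bound requests serialized.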
        run_button.click(fn=callback, inputs=inputs, outputs=result_gallery, queue=True)

    return block


if __name__ == "__main__":
    demo = get_rgb2x_demo()
    demo.queue(max_size=1)
    demo.launch()