Spaces: Running on Zero
import spaces
import huggingface_hub

huggingface_hub.snapshot_download(
    repo_id='h94/IP-Adapter',
    allow_patterns=[
        'models/**',
        'sdxl_models/**',
    ],
    local_dir='./'
)
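# The snapshot above pulls the IP-Adapter image encoder ('models/image_encoder')
# and the SDXL IP-Adapter checkpoint ('sdxl_models/ip-adapter_sdxl_vit-h.bin')
# into the working directory, matching image_encoder_path and ip_ckpt below.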
import gradio as gr
from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel
from rembg import remove
from PIL import Image, ImageChops, ImageEnhance
import torch
from ip_adapter import IPAdapterXL
from ip_adapter.utils import register_cross_attention_hook, get_net_attn_map, attnmaps2images
import numpy as np
import os
import glob
import cv2
import argparse

import DPT.util.io
from torchvision.transforms import Compose
from DPT.dpt.models import DPTDepthModel
from DPT.dpt.midas_net import MidasNet_large
from DPT.dpt.transforms import Resize, NormalizeImage, PrepareForNet
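# Assumption: the DPT imports above come from a local `DPT/` directory bundled
# with the Space, with pretrained weights at the `model_path` used below.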
| """ | |
| Get ZeST Ready | |
| """ | |
| base_model_path = "stabilityai/stable-diffusion-xl-base-1.0" | |
| image_encoder_path = "models/image_encoder" | |
| ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin" | |
| controlnet_path = "diffusers/controlnet-depth-sdxl-1.0" | |
| device = "cuda" | |
| torch.cuda.empty_cache() | |
| # load SDXL pipeline | |
| controlnet = ControlNetModel.from_pretrained(controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=torch.float16).to(device) | |
| pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained( | |
| base_model_path, | |
| controlnet=controlnet, | |
| use_safetensors=True, | |
| torch_dtype=torch.float16, | |
| add_watermarker=False, | |
| ).to(device) | |
| pipe.unet = register_cross_attention_hook(pipe.unet) | |
| ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device) | |
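# IPAdapterXL wraps the SDXL ControlNet inpaint pipeline so the material exemplar
# can be injected as an image prompt alongside the depth ControlNet conditioning.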
| """ | |
| Get Depth Model Ready | |
| """ | |
| model_path = "DPT/weights/dpt_hybrid-midas-501f0c75.pt" | |
| net_w = net_h = 384 | |
| model = DPTDepthModel( | |
| path=model_path, | |
| backbone="vitb_rn50_384", | |
| non_negative=True, | |
| enable_attention_hooks=False, | |
| ) | |
| normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | |
| transform = Compose( | |
| [ | |
| Resize( | |
| net_w, | |
| net_h, | |
| resize_target=None, | |
| keep_aspect_ratio=True, | |
| ensure_multiple_of=32, | |
| resize_method="minimal", | |
| image_interpolation_method=cv2.INTER_CUBIC, | |
| ), | |
| normalization, | |
| PrepareForNet(), | |
| ] | |
| ) | |
| model.eval() | |
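# Note: the DPT model is left on the CPU here (it is never moved to `device`),
# so depth estimation runs on CPU while SDXL generation runs on the GPU.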
# ZeroGPU Spaces allocate a GPU per call; `import spaces` above suggests the
# handler is meant to be decorated with @spaces.GPU.
@spaces.GPU
def infer(input_image, material_exemplar, progress=gr.Progress(track_tqdm=True)):
    """
    Perform zero-shot material transfer from a single input image and a material exemplar image.

    This function uses a combination of a depth estimation model (DPT), foreground/background separation,
    grayscale stylization, and IP-Adapter + ControlNet with Stable Diffusion XL to generate an output image
    in which the material style from the exemplar image is applied to the input image's object.

    Args:
        input_image (PIL.Image): The original image containing the object to which the new material will be applied.
        material_exemplar (PIL.Image): A reference image whose material (texture, reflectance, etc.) is to be transferred to the object in the input image.
        progress (gradio.Progress, optional): Tracks the progress bar in the Gradio UI. Defaults to tqdm tracking.

    Returns:
        PIL.Image: The output image showing the object from `input_image` rendered with the material of `material_exemplar`.

    Steps:
        1. Compute a depth map from `input_image` using a DPT-based model.
        2. Remove the background from the input image to isolate the object and convert it into a grayscale version.
        3. Combine and align the input image, depth map, and mask for use with the IP-Adapter + ControlNet SDXL pipeline.
        4. Use `IPAdapterXL.generate()` to synthesize a new image, guiding generation with:
           - `material_exemplar` for style/material guidance
           - `input_image`'s structure/content in grayscale
           - the estimated depth map for spatial layout
           - the mask for region-specific conditioning (object only)
        5. Return the first image in the generated list as the final material transfer result.
    """
| """ | |
| Compute depth map from input_image | |
| """ | |
| img = np.array(input_image) | |
| img_input = transform({"image": img})["image"] | |
| # compute | |
| with torch.no_grad(): | |
| sample = torch.from_numpy(img_input).unsqueeze(0) | |
| # if optimize == True and device == torch.device("cuda"): | |
| # sample = sample.to(memory_format=torch.channels_last) | |
| # sample = sample.half() | |
| prediction = model.forward(sample) | |
| prediction = ( | |
| torch.nn.functional.interpolate( | |
| prediction.unsqueeze(1), | |
| size=img.shape[:2], | |
| mode="bicubic", | |
| align_corners=False, | |
| ) | |
| .squeeze() | |
| .cpu() | |
| .numpy() | |
| ) | |
| depth_min = prediction.min() | |
| depth_max = prediction.max() | |
| bits = 2 | |
| max_val = (2 ** (8 * bits)) - 1 | |
| if depth_max - depth_min > np.finfo("float").eps: | |
| out = max_val * (prediction - depth_min) / (depth_max - depth_min) | |
| else: | |
| out = np.zeros(prediction.shape, dtype=depth.dtype) | |
| out = (out / 256).astype('uint8') | |
| depth_map = Image.fromarray(out).resize((1024, 1024)) | |
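    # depth_map is later passed to ip_model.generate() as the ControlNet depth
    # conditioning image, so it is resized to the 1024x1024 generation size here.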
| """ | |
| Process foreground decolored image | |
| """ | |
| rm_bg = remove(input_image) | |
| target_mask = rm_bg.convert("RGB").point(lambda x: 0 if x < 1 else 255).convert('L').convert('RGB') | |
| mask_target_img = ImageChops.lighter(input_image, target_mask) | |
| invert_target_mask = ImageChops.invert(target_mask) | |
| gray_target_image = input_image.convert('L').convert('RGB') | |
| gray_target_image = ImageEnhance.Brightness(gray_target_image) | |
| factor = 1.0 # Try adjusting this to get the desired brightness | |
| gray_target_image = gray_target_image.enhance(factor) | |
| grayscale_img = ImageChops.darker(gray_target_image, target_mask) | |
| img_black_mask = ImageChops.darker(input_image, invert_target_mask) | |
| grayscale_init_img = ImageChops.lighter(img_black_mask, grayscale_img) | |
| init_img = grayscale_init_img | |
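    # The composites above produce an init image whose foreground (per the rembg
    # mask) is grayscale while the original background is kept, so generation
    # retains the object's shading cues without its original color/material.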
| """ | |
| Process material exemplar and resize all images | |
| """ | |
| ip_image = material_exemplar.resize((1024, 1024)) | |
| init_img = init_img.resize((1024,1024)) | |
| mask = target_mask.resize((1024, 1024)) | |
| num_samples = 1 | |
| images = ip_model.generate(pil_image=ip_image, image=init_img, control_image=depth_map, mask_image=mask, controlnet_conditioning_scale=0.9, num_samples=num_samples, num_inference_steps=30, seed=42) | |
| return images[0] | |
css = """
#col-container{
    margin: 0 auto;
    max-width: 960px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # ZeST: Zero-Shot Material Transfer from a Single Image
        <p>Upload two images: an input image and a material exemplar (both 1024x1024 for best results).<br />
        ZeST extracts the material from the exemplar and casts it onto the input image, following the original lighting cues.</p>
        """)
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/ttchengab/zest_code">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
            <a href="https://ttchengab.github.io/zest/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/abs/2404.06425">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/ZeST?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(type="pil", label="input image")
                    input_image2 = gr.Image(type="pil", label="material exemplar")
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[["demo_assets/input_imgs/pumpkin.png", "demo_assets/material_exemplars/cup_glaze.png"]],
                    inputs=[input_image, input_image2]
                )
            with gr.Column():
                output_image = gr.Image(label="transfer result")

    submit_btn.click(fn=infer, inputs=[input_image, input_image2], outputs=[output_image])

demo.queue().launch(show_error=True, ssr_mode=False, mcp_server=True)