# Spaces: Running on Zero
import gradio as gr
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler,AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import spaces
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image
def pil_to_binary_mask(pil_image, threshold=0):
    """Convert an image into a black/white mask image.

    The input is converted to 8-bit grayscale; every pixel whose gray value is
    strictly greater than ``threshold`` becomes 255 and every other pixel
    becomes 0.

    Args:
        pil_image: Image to threshold (anything ``np.array`` can read and
            ``Image.fromarray`` can rebuild, normally a ``PIL.Image``).
        threshold: Gray level a pixel must exceed to be part of the mask.

    Returns:
        A ``PIL.Image`` built from a uint8 array containing only 0 and 255.
    """
    grayscale_image = Image.fromarray(np.array(pil_image)).convert("L")
    # One vectorised comparison instead of a Python loop over every pixel.
    binary_mask = np.array(grayscale_image) > threshold
    mask = binary_mask.astype(np.uint8) * 255
    return Image.fromarray(mask)
# Model repository id; every sub-model below is loaded from one of its subfolders.
base_path = 'yisol/IDM-VTON'
# Directory (next to this file) that holds the example images shown in the UI.
example_path = os.path.join(os.path.dirname(__file__), 'example')

# Try-on UNet (project-specific UNet2DConditionModel variant), loaded in half precision.
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)

# The pipeline uses two tokenizer / text-encoder pairs.
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)

# Second UNet, attached to the pipeline as ``pipe.unet_encoder`` below.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

# Pre-processing models used for the auto-generated mask.
# NOTE(review): the ``0`` argument is presumably a GPU index - confirm.
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# The demo only runs inference, so no module needs gradients. Each module is
# frozen exactly once here (the original froze ``unet`` twice).
for _frozen_module in (
    UNet_Encoder,
    image_encoder,
    vae,
    unet,
    text_encoder_one,
    text_encoder_two,
):
    _frozen_module.requires_grad_(False)

# PIL image -> tensor normalised with mean 0.5 / std 0.5.
# The misspelt name is kept because other code in this file refers to it.
tensor_transfrom = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder
def _resolve_category(category):
    """Translate the UI category code into the name passed to ``get_mask_location``.

    ``0`` maps to ``'upper_body'``, ``1`` to ``'lower_body'`` and every other
    integer to ``'dresses'``.

    Raises:
        gr.Error: If ``category`` cannot be parsed as an integer (for example
            when the textbox was left empty).
    """
    try:
        code = int(category)
    except (TypeError, ValueError) as err:
        raise gr.Error(
            "Category must be 0 (upper body), 1 (lower body) or 2 (full body)."
        ) from err
    if code == 0:
        return 'upper_body'
    if code == 1:
        return 'lower_body'
    return 'dresses'


def _center_crop_box(width, height):
    """Return the integer ``(left, top, right, bottom)`` box of the largest centred 3:4 crop."""
    target_width = int(min(width, height * (3 / 4)))
    target_height = int(min(height, width * (4 / 3)))
    # Integer coordinates keep the crop box, the crop size and the later paste
    # position in agreement; float boxes are rounded by PIL while the paste
    # offset was truncated, which could shift or shrink the result by a pixel.
    left = (width - target_width) // 2
    top = (height - target_height) // 2
    return left, top, left + target_width, top + target_height


def start_tryon(dict,garm_img,garment_des,is_checked,is_checked_crop,denoise_steps,seed, category):
    """Run one virtual try-on and return the result plus a preview of the masked input.

    Args:
        dict: Image-editor value; ``dict["background"]`` is the human photo and
            ``dict["layers"][0]`` the hand-drawn mask. (The name shadows the
            builtin but is kept so the function's interface is unchanged.)
        garm_img: PIL image of the garment.
        garment_des: Text description of the garment, appended to the prompts.
        is_checked: When true, build the mask with the OpenPose and
            human-parsing models; otherwise threshold the first editor layer.
        is_checked_crop: When true, centre-crop the photo to 3:4 before
            processing and paste the result back into the original photo.
        denoise_steps: Number of inference steps passed to the pipeline.
        seed: Seed for the random generator, or ``None`` for no generator.
        category: ``0`` (upper body), ``1`` (lower body), anything else (dresses).

    Returns:
        ``(result_image, mask_gray)`` where ``mask_gray`` is the 768x1024 human
        image with the masked region blanked to mid-grey.

    Raises:
        gr.Error: If an image is missing, the category is not an integer, or
            manual masking is selected but no mask layer was drawn.

    NOTE(review): ``spaces`` is imported by this file but no ``@spaces.GPU``
    decorator is visible on this function - confirm whether one is needed for
    the target hardware.
    """
    device = "cuda"
    category = _resolve_category(category)

    # Fail early with a readable message instead of an AttributeError/TypeError.
    if garm_img is None:
        raise gr.Error("Please upload a garment image.")
    if dict is None or dict["background"] is None:
        raise gr.Error("Please upload a human image.")
    garment_des = garment_des or ""

    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        crop_box = _center_crop_box(*human_img_orig.size)
        cropped_img = human_img_orig.crop(crop_box)
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, _ = get_mask_location('hd', category, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        layers = dict['layers']
        if not layers:
            raise gr.Error(
                "Draw a mask on the image or enable the auto-generated mask."
            )
        mask = pil_to_binary_mask(layers[0].convert("RGB").resize((768, 1024)))

    # Preview: zero the masked pixels of the normalised image, then map the
    # [-1, 1] range back to [0, 1] so the masked area shows as mid-grey.
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose input: orientation-corrected BGR array at 384x512.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    args = apply_net.create_argument_parser().parse_args((
        'show',
        './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
        './ckpt/densepose/model_final_162be9.pkl',
        'dp_segm',
        '-v',
        '--opts',
        'MODEL.DEVICE',
        device,
    ))
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # reverse the channel order (BGR input above)
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, contortionist, amputee, polydactyly, deformed, distorted, misshapen, malformed, abnormal, mutant, defaced, shapeless, missing arms, three feet, bad face, extra fingers, fused crus, ugly fingers, three hands, fused face, extra eyes, bad hands, worst thigh, fused thigh, extra thigh, long fingers, three legs, missing fingers, cloned face, disconnected limbs, bad anatomy, huge eyes, three crus, worst face, amputation, 2girl, missing legs, poorly drawn face, horn, realistic photo, fused feet, worst feet, three thigh, extra crus"

    # inference_mode already disables gradient tracking, so the nested
    # no_grad() contexts of the original are not needed.
    with torch.inference_mode(), torch.cuda.amp.autocast():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            "model is wearing " + garment_des,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )

        # Garment-only prompt; passed as one-element lists, as before.
        prompt_embeds_c, _, _, _ = pipe.encode_prompt(
            ["a photo of " + garment_des],
            num_images_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt=[negative_prompt],
        )

        pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
        # UI number fields may deliver floats; the generator needs an int seed.
        generator = (
            torch.Generator(device).manual_seed(int(seed)) if seed is not None else None
        )

        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
            num_inference_steps=int(denoise_steps),
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
            cloth=garm_tensor,
            mask_image=mask,
            image=human_img,
            height=1024,
            width=768,
            ip_adapter_image=garm_img,  # already 768x1024
            guidance_scale=2.0,
        )[0]

    if is_checked_crop:
        # Scale the result back to the crop size and paste it where it was cut.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (crop_box[0], crop_box[1]))
        return human_img_orig, mask_gray
    return images[0], mask_gray
# Example images offered in the UI. os.listdir() returns entries in arbitrary
# order, so the names are sorted to keep the galleries stable between runs.
garm_list = sorted(os.listdir(os.path.join(example_path, "cloth")))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]

human_list = sorted(os.listdir(os.path.join(example_path, "human")))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

# Default human examples: one dict per image with the file as "background"
# and no drawn layers or composite.
human_ex_list = [
    {'background': ex_human, 'layers': None, 'composite': None}
    for ex_human in human_list_path
]
# Gradio UI: human image editor, garment upload, the two output images and the
# "Try-on" button wired to start_tryon.
image_blocks = gr.Blocks().queue()
with image_blocks as demo:
    # Title emoji restored; the original literal was mis-encoded.
    gr.Markdown("## IDM-VTON 👕👔👚")
    gr.Markdown("Virtual Try-on with your image and garment image. Check out the [source codes](https://github.com/yisol/IDM-VTON) and the [model](https://huggingface.co/yisol/IDM-VTON)")
    with gr.Row():
        with gr.Column():
            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
            with gr.Row():
                is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
            with gr.Row():
                is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
            with gr.Row():
                # Defaults to "0" so a click without typing still sends a valid
                # category; elem_id is unique (it used to duplicate "prompt").
                category = gr.Textbox(value="0", placeholder="0 = upper body, 1 = lower body, 2 = full body", show_label=False, elem_id="category")
            gr.Examples(
                inputs=imgs,
                examples_per_page=10,
                examples=human_ex_list,
            )

        with gr.Column():
            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
            with gr.Row(elem_id="prompt-container"):
                with gr.Row():
                    prompt = gr.Textbox(placeholder="Description of garment ex) Short Sleeve Round Neck T-shirts", show_label=False, elem_id="prompt")
            gr.Examples(
                inputs=garm_img,
                examples_per_page=8,
                examples=garm_list_path,
            )

        with gr.Column():
            masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)

    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    # The input order must match start_tryon's positional parameters; the two
    # outputs receive its (result image, masked preview) return values.
    try_button.click(fn=start_tryon, inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed, category], outputs=[image_out, masked_img], api_name='tryon')

image_blocks.launch()