# SmartLuga / app.py
import sys
sys.path.append('./')

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
from src.unet_hacked_tryon import UNet2DConditionModel
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
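# src.unet_hacked_tryon provides IDM-VTON's modified try-on UNet, src.unet_hacked_garmnet the
# reference (garment) UNet, and src.tryon_pipeline the SDXL inpainting pipeline adapted for try-on.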
# Define a class to hold configuration arguments
class Args:
def __init__(self):
self.pretrained_model_name_or_path = "yisol/IDM-VTON"
self.width = 768
self.height = 1024
self.num_inference_steps = 10
self.seed = 42
self.guidance_scale = 2.0
self.mixed_precision = None
# Determine the device to be used for computations (CUDA if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = get_logger(__name__, log_level="INFO")
def pil_to_tensor(images):
images = np.array(images).astype(np.float32) / 255.0
images = torch.from_numpy(images.transpose(2, 0, 1))
return images
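# Note: pil_to_tensor is kept as a helper but is not used below; preprocessing relies on
# torchvision transforms instead.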
args = Args()
# Define the data type for model weights
weight_dtype = torch.float16
if args.seed is not None:
set_seed(args.seed)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=torch.float16,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="image_encoder",
torch_dtype=torch.float16,
)
unet_encoder = UNet2DConditionModel_ref.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet_encoder",
torch_dtype=torch.float16,
)
text_encoder_one = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=None,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=None,
use_fast=False,
)
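# SDXL-based models carry two text encoders and two tokenizers; both are required to build
# the prompt embeddings used by the pipeline.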
# Freeze all components; this script runs inference only, so no parameters are trainable.
unet.requires_grad_(False)
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet_encoder.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
unet_encoder.to(device, weight_dtype)
unet.eval()
unet_encoder.eval()
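# eval() disables training-time behaviour such as dropout; together with requires_grad_(False)
# above, every component is a frozen, inference-only module.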
pipe = TryonPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    unet_encoder=unet_encoder,
    torch_dtype=torch.float16,
).to(device)
# pipe.enable_sequential_cpu_offload()
# pipe.enable_model_cpu_offload()
# pipe.enable_vae_slicing()
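# The commented-out offloading/slicing options above can lower GPU memory use at the cost of speed.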
# Function to generate the image based on inputs
def generate_virtual_try_on(person_image, cloth_image, mask_image, pose_image, cloth_des):
# Prepare the input images as tensors
person_image = person_image.resize((args.width, args.height))
cloth_image = cloth_image.resize((args.width, args.height))
mask_image = mask_image.resize((args.width, args.height))
pose_image = pose_image.resize((args.width, args.height))
# Define transformations
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
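    # Normalize([0.5], [0.5]) rescales tensors from [0, 1] to [-1, 1], the range the VAE expects.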
    guidance_scale = args.guidance_scale
    seed = args.seed
to_tensor = transforms.ToTensor()
person_tensor = transform(person_image).unsqueeze(0).to(device) # Add batch dimension
cloth_pure = transform(cloth_image).unsqueeze(0).to(device)
mask_tensor = to_tensor(mask_image)[:1].unsqueeze(0).to(device) # Keep only one channel
pose_tensor = transform(pose_image).unsqueeze(0).to(device)
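    # The mask stays single-channel in [0, 1]; person, cloth, and pose images are normalized to [-1, 1].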
# Prepare text prompts
prompt = ["A person wearing the cloth"+cloth_des] # Example prompt
negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality"]
# Encode prompts
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
prompt,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
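    # Positive and negative embeddings are produced together so classifier-free guidance can be applied.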
    prompt_cloth = ["a photo of " + cloth_des]
with torch.inference_mode():
(
prompt_embeds_c,
_,
_,
_,
) = pipe.encode_prompt(
prompt_cloth,
num_images_per_prompt=1,
do_classifier_free_guidance=False,
negative_prompt=negative_prompt,
)
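    # Only the positive embedding is kept; it conditions the garment (reference) UNet via text_embeds_cloth.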
# Encode garment using IP-Adapter
clip_processor = CLIPImageProcessor()
image_embeds = clip_processor(images=cloth_image, return_tensors="pt").pixel_values.to(device)
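    # CLIPImageProcessor resizes and normalizes the cloth image into the pixel values expected by the
    # CLIP image encoder on the IP-Adapter path.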
# Generate the image
generator = torch.Generator(pipe.device).manual_seed(seed) if seed is not None else None
with torch.no_grad():
images = pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=args.num_inference_steps,
generator=generator,
strength=1.0,
pose_img=pose_tensor,
text_embeds_cloth=prompt_embeds_c,
cloth=cloth_pure,
mask_image=mask_tensor,
image=(person_tensor + 1.0) / 2.0,
height=args.height,
width=args.width,
guidance_scale=guidance_scale,
ip_adapter_image=image_embeds,
)[0]
    # Convert the output to PIL for display; the pipeline may already return PIL images,
    # so only convert when a tensor or array comes back.
    result = images[0]
    if not isinstance(result, Image.Image):
        result = transforms.ToPILImage()(result)
    return result
# Create Gradio interface
iface = gr.Interface(
fn=generate_virtual_try_on,
inputs=[
gr.Image(type="pil", label="Person Image"),
gr.Image(type="pil", label="Cloth Image"),
gr.Image(type="pil", label="Mask Image"),
gr.Image(type="pil", label="Pose Image"),
        gr.Textbox(label="Cloth Description"),  # Free-text description of the garment
],
outputs=gr.Image(type="pil", label="Generated Image"),
)
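# The order of the inputs above must match the parameter order of generate_virtual_try_on.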
# Launch the interface
iface.launch()