# app.py for the Saad0KH/IDM-VTON Hugging Face Space
from flask import Flask, request, jsonify
import torch
import numpy as np
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
# UNet2DConditionModel, AutoencoderKL, and DDPMScheduler come from diffusers,
# not transformers
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from PIL import Image
import base64
from io import BytesIO
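# NOTE (assumption): UNet2DConditionModel_ref, used for UNet_Encoder below, is
# not a transformers/diffusers class but IDM-VTON's customized reference UNet.
# The import path below is a sketch of where it typically lives in the
# IDM-VTON source tree; adjust it to match this Space's layout.
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref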
app = Flask(__name__)
# Global variables for models to load them once at startup
unet = None
tokenizer_one = None
tokenizer_two = None
noise_scheduler = None
text_encoder_one = None
text_encoder_two = None
image_encoder = None
vae = None
UNet_Encoder = None
# Load models once at startup
def load_models():
    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
    base_path = "your_base_path_here"  # set this to the IDM-VTON checkpoint repo or directory
    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
    # Inference only: freeze the networks and move them to the GPU, since the
    # endpoint below builds CUDA tensors
    for model in (unet, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder):
        model.requires_grad_(False)
        model.eval()
        model.to('cuda')
# Call the function to load models at startup
load_models()
# Helper function to free up GPU memory after processing
def clear_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# Helper function to convert base64 to image
def base64_to_image(base64_str):
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data)).convert("RGB")
    return image
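# Added convenience helper: the inverse of base64_to_image, used by the
# endpoint below to encode the result image as a base64 JPEG string.
def image_to_base64(image, fmt="JPEG"):
    buffered = BytesIO()
    image.save(buffered, format=fmt)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")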
# Helper function to resize images for faster processing
def resize_image(image, size=(512, 768)):
    # PIL's size convention is (width, height)
    return image.resize(size)
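# Added helper (sketch): convert a PIL image to a batched float16 CHW tensor
# on the GPU, scaled to [-1, 1], the pixel convention used by Stable
# Diffusion-family VAEs. The original code called torch.tensor() directly on
# a PIL image, which raises, since torch cannot infer a dtype from PIL objects.
def pil_to_tensor(image, device='cuda'):
    arr = np.asarray(image).astype('float32') / 127.5 - 1.0
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
    return tensor.to(device, dtype=torch.float16)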
# Example try-on endpoint
@app.route('/start_tryon', methods=['POST'])
def start_tryon():
    data = request.get_json()
    if not data or 'garm_img' not in data or 'human_img' not in data:
        return jsonify({"error": "garm_img and human_img are required"}), 400
    # Decode and resize images
    garm_img = resize_image(base64_to_image(data['garm_img']))
    human_img = resize_image(base64_to_image(data['human_img']))
    # Convert images to tensors and move to GPU
    garm_img_tensor = pil_to_tensor(garm_img)
    human_img_tensor = pil_to_tensor(human_img)
    try:
        # Processing steps (dummy example, replace with your logic)
        with torch.inference_mode():
            # Placeholder from the original file: a bare UNet2DConditionModel
            # forward takes (sample, timestep, encoder_hidden_states), so this
            # call must be replaced with the actual IDM-VTON pipeline logic
            result_tensor = unet(garm_img_tensor, human_img_tensor)
        # Free GPU memory after inference
        clear_gpu_memory()
        # Map the output back from [-1, 1] (assumed model convention) to
        # 8-bit HWC RGB, then return it as base64
        result_arr = (result_tensor.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1.0) * 127.5
        result_img = Image.fromarray(result_arr.clip(0, 255).astype(np.uint8))
        return jsonify({"result": image_to_base64(result_img)})
    except Exception as e:
        clear_gpu_memory()
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
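# Example client call (sketch): assumes the server is reachable on
# localhost:7860 and that garment.jpg / person.jpg exist locally.
#
#   import base64, requests
#
#   def b64(path):
#       with open(path, 'rb') as f:
#           return base64.b64encode(f.read()).decode('utf-8')
#
#   resp = requests.post(
#       'http://localhost:7860/start_tryon',
#       json={'garm_img': b64('garment.jpg'), 'human_img': b64('person.jpg')},
#   )
#   print(resp.json())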