from flask import Flask, request, jsonify
import torch
import numpy as np
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
from PIL import Image
import base64
from io import BytesIO

app = Flask(__name__)

unet = None
tokenizer_one = None
tokenizer_two = None
noise_scheduler = None
text_encoder_one = None
text_encoder_two = None
image_encoder = None
vae = None
UNet_Encoder = None


def load_models():
    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
    base_path = "your_base_path_here"
    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16)
    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
    # NOTE: UNet2DConditionModel_ref is not a transformers/diffusers class; it
    # must be imported from this project's custom garment-encoder UNet module.
    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
    # start_tryon() builds float16 CUDA tensors, so the UNet that consumes
    # them has to live on the GPU as well.
    unet.to("cuda")


load_models()

def clear_gpu_memory():
    # Release cached CUDA allocations and wait for pending kernels to finish.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def base64_to_image(base64_str):
    # Decode a base64 payload into an RGB PIL image.
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data)).convert("RGB")
    return image


def resize_image(image, size=(512, 768)):
    # Naive resize to the model's working resolution; aspect ratio is not preserved.
    return image.resize(size)

@app.route('/start_tryon', methods=['POST'])
def start_tryon():
    data = request.get_json()
    if not data or 'garm_img' not in data or 'human_img' not in data:
        return jsonify({"error": "garm_img and human_img are required"}), 400
    garm_img_base64 = data['garm_img']
    human_img_base64 = data['human_img']

    garm_img = resize_image(base64_to_image(garm_img_base64))
    human_img = resize_image(base64_to_image(human_img_base64))

    # torch.tensor() cannot consume a PIL image directly: go through NumPy,
    # reorder HWC -> NCHW, and rescale [0, 255] to [-1, 1] (the range
    # diffusion UNets conventionally expect; adjust if your model differs).
    garm_img_tensor = torch.from_numpy(np.asarray(garm_img)).permute(2, 0, 1).unsqueeze(0).to('cuda', torch.float16) / 127.5 - 1.0
    human_img_tensor = torch.from_numpy(np.asarray(human_img)).permute(2, 0, 1).unsqueeze(0).to('cuda', torch.float16) / 127.5 - 1.0

    try:
        with torch.inference_mode():
            # Placeholder forward pass. A real try-on pipeline would encode
            # the garment with UNet_Encoder, run a denoising loop driven by
            # noise_scheduler with text/image conditioning, and decode the
            # latents with the VAE; this direct call stands in for all of that.
            result_tensor = unet(garm_img_tensor, human_img_tensor)

        clear_gpu_memory()

        # Map the output back from [-1, 1] to 8-bit RGB (NCHW -> HWC) so PIL
        # can serialize it as a JPEG.
        result = result_tensor.squeeze(0).permute(1, 2, 0).float().cpu()
        result = ((result.clamp(-1.0, 1.0) + 1.0) * 127.5).round().numpy().astype(np.uint8)
        result_img = Image.fromarray(result)
        buffered = BytesIO()
        result_img.save(buffered, format="JPEG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return jsonify({"result": result_base64})

    except Exception as e:
        clear_gpu_memory()
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
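
# Example client call (a minimal sketch; assumes the server is reachable on
# localhost:7860 and that garment.jpg / person.jpg exist on disk):
#
#   import base64, requests
#
#   def b64(path):
#       with open(path, "rb") as f:
#           return base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://localhost:7860/start_tryon",
#       json={"garm_img": b64("garment.jpg"), "human_img": b64("person.jpg")},
#   )
#   print(resp.json())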