File size: 3,947 Bytes
af7056a 848b0e8 004975c af7056a 004975c af7056a 004975c af7056a 848b0e8 af7056a 227771d af7056a 227771d af7056a 227771d 9c9e9a9 af7056a 227771d af7056a df466fb af7056a df466fb af7056a df466fb af7056a df466fb af7056a df466fb af7056a bbabb76 af7056a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import base64
import os
from io import BytesIO

import numpy as np
import torch
from flask import Flask, request, jsonify
from PIL import Image
# NOTE(review): UNet2DConditionModel, AutoencoderKL and DDPMScheduler are
# normally provided by `diffusers`, not `transformers` -- verify the import
# source against the installed package versions.
from transformers import (
    UNet2DConditionModel,
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
    AutoencoderKL,
    DDPMScheduler
)
app = Flask(__name__)

# Global variables for models to load them once at startup.
# All start as None and are populated by load_models(); request handlers
# read them directly, so they must be loaded before serving traffic.
unet = None
tokenizer_one = None
tokenizer_two = None
noise_scheduler = None
text_encoder_one = None
text_encoder_two = None
image_encoder = None
vae = None
UNet_Encoder = None
# Load models once at startup
def load_models():
    """Load all pipeline components into the module-level globals.

    Called once at import time so every request reuses the same weights
    instead of reloading them per call.

    The checkpoint directory can be overridden with the TRYON_BASE_PATH
    environment variable; otherwise the original hard-coded placeholder
    is used.

    NOTE(review): `UNet2DConditionModel_ref` is never imported or defined
    in this file, so the final assignment raises NameError at startup --
    presumably a project-local UNet subclass; confirm and add its import.
    """
    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
    # Allow configuration without editing source; falls back to the
    # original placeholder for backward compatibility.
    base_path = os.environ.get("TRYON_BASE_PATH", "your_base_path_here")
    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
# Call the function to load models at startup.
# Runs at import time, so starting the server blocks until every
# checkpoint is loaded; this will fail while `base_path` is still the
# "your_base_path_here" placeholder.
load_models()
# Helper function to free up GPU memory after processing
def clear_gpu_memory():
    """Release cached CUDA memory and wait for pending kernels to finish.

    Guarded by ``torch.cuda.is_available()`` because
    ``torch.cuda.synchronize()`` raises on CPU-only hosts; on such
    machines this is a safe no-op.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
# Helper function to convert base64 to image
def base64_to_image(base64_str):
    """Decode a base64-encoded image payload into an RGB PIL image.

    Args:
        base64_str: Base64 string of an image in any format PIL can open.

    Returns:
        PIL.Image.Image in RGB mode (alpha/palette data is converted).
    """
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data)).convert("RGB")
    return image
# Helper function to resize images for faster processing
def resize_image(image, size=(512, 768)):
    """Resize *image* to *size* (width, height); defaults to 512x768.

    Delegates to the image's own ``resize`` method and returns the new,
    resized copy without touching the original.
    """
    resized = image.resize(size)
    return resized
# Example try-on function
@app.route('/start_tryon', methods=['POST'])
def start_tryon():
    """Run the (placeholder) virtual try-on inference on two images.

    Expects a JSON body with base64-encoded ``garm_img`` and ``human_img``
    keys. Returns ``{"result": <base64 JPEG>}`` on success, a 400 JSON
    error for missing/undecodable input, or a 500 JSON error if inference
    fails.
    """
    data = request.get_json()
    # Validate up front: previously a missing key raised KeyError outside
    # the try block and produced an unhandled 500 instead of a clear 400.
    if not data or 'garm_img' not in data or 'human_img' not in data:
        return jsonify({"error": "both 'garm_img' and 'human_img' are required"}), 400
    try:
        # Decode and resize images
        garm_img = resize_image(base64_to_image(data['garm_img']))
        human_img = resize_image(base64_to_image(data['human_img']))
    except Exception as e:
        # Bad base64 / not an image is a client error, not a server error.
        return jsonify({"error": f"invalid image data: {e}"}), 400
    # Fall back to CPU so the endpoint doesn't crash on CUDA-less hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # torch.tensor() does not accept PIL.Image objects directly -- the
    # original call raised TypeError; go through a numpy array instead.
    garm_img_tensor = torch.tensor(np.asarray(garm_img), dtype=torch.float16).unsqueeze(0).to(device)
    human_img_tensor = torch.tensor(np.asarray(human_img), dtype=torch.float16).unsqueeze(0).to(device)
    try:
        # Processing steps (dummy example, replace with your logic)
        with torch.inference_mode():
            # Run the inference for both images
            result_tensor = unet(garm_img_tensor, human_img_tensor)  # Replace with your actual logic
        # Convert result back to base64 for return.
        # NOTE(review): Image.fromarray rejects float16 arrays, hence the
        # uint8 cast; this assumes the model output is already scaled to
        # 0-255 -- confirm against the real pipeline's output range.
        result_array = result_tensor.squeeze(0).float().clamp(0, 255).cpu().numpy().astype(np.uint8)
        result_img = Image.fromarray(result_array)
        buffered = BytesIO()
        result_img.save(buffered, format="JPEG")
        result_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return jsonify({"result": result_base64})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Free GPU memory on success AND failure (the original leaked the
        # input tensors' cached memory on the error path before returning).
        clear_gpu_memory()
if __name__ == '__main__':
    # Development entry point: serve on port 7860 with Flask's built-in
    # server (debug off by default).
    # NOTE(review): host='0.0.0.0' binds to all interfaces and exposes the
    # service network-wide -- confirm this is intended for deployment.
    app.run(host='0.0.0.0', port=7860)
|