Spaces:
Saad0KH
/
Running on Zero

File size: 3,947 Bytes
af7056a
848b0e8
004975c
af7056a
 
004975c
 
af7056a
 
 
004975c
af7056a
 
 
848b0e8
 
 
af7056a
 
 
 
 
 
 
 
 
 
 
 
227771d
 
af7056a
227771d
 
 
 
 
 
 
 
 
 
af7056a
227771d
9c9e9a9
af7056a
227771d
 
 
 
af7056a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df466fb
af7056a
 
 
df466fb
af7056a
 
 
 
df466fb
af7056a
 
 
 
 
 
 
 
 
 
 
 
 
df466fb
af7056a
 
df466fb
af7056a
 
bbabb76
af7056a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from flask import Flask, request, jsonify
import torch
from transformers import (
    UNet2DConditionModel,
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
    AutoencoderKL,
    DDPMScheduler
)
from PIL import Image
import base64
from io import BytesIO

# The Flask application object; the route decorators further down register on it.
app = Flask(__name__)

# Module-level model handles shared by every request. They all start out as
# None and are populated by load_models().
unet = UNet_Encoder = None
tokenizer_one = tokenizer_two = None
text_encoder_one = text_encoder_two = None
image_encoder = vae = noise_scheduler = None

# Load models once at startup
def load_models(base_path="your_base_path_here", dtype=torch.float16):
    """Load every model component into the module-level globals.

    Each component is read from its own subfolder of ``base_path``: ``unet``,
    ``tokenizer``, ``tokenizer_2``, ``scheduler``, ``text_encoder``,
    ``text_encoder_2``, ``image_encoder``, ``vae`` and ``unet_encoder``.

    Args:
        base_path: Root passed to every ``from_pretrained`` call. The default
            is a placeholder and must be replaced or overridden by the caller.
        dtype: ``torch_dtype`` for the weight-bearing components. The
            tokenizers and the scheduler do not receive it.

    NOTE(review): ``UNet2DConditionModel_ref`` is not defined or imported in
    this file, so the final load raises ``NameError`` until the module that
    provides it is imported.
    NOTE(review): ``UNet2DConditionModel``, ``AutoencoderKL`` and
    ``DDPMScheduler`` are ``diffusers`` classes, not ``transformers`` ones.
    The file-level import of them from ``transformers`` needs fixing.
    NOTE(review): nothing here moves the models to CUDA, while the request
    handler places its inputs on ``'cuda'``. Confirm where device placement
    is meant to happen.
    """
    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder

    # Keyword arguments shared by all weight-bearing components.
    model_kwargs = {"torch_dtype": dtype, "force_download": False}
    # use_fast=False selects the slow (pure-Python) tokenizer implementation.
    tokenizer_kwargs = {"use_fast": False, "force_download": False}

    unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", **model_kwargs)
    tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", **tokenizer_kwargs)
    tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", **tokenizer_kwargs)
    noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
    text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", **model_kwargs)
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", **model_kwargs)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", **model_kwargs)
    vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", **model_kwargs)
    UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", **model_kwargs)

# Load every model eagerly when this module is imported, so requests reuse the
# already-populated globals instead of loading weights themselves.
# NOTE(review): because this sits at module top level, it also runs when the
# module is imported by a WSGI server or a test, not only under __main__.
load_models()

# Helper function to free up GPU memory after processing
def clear_gpu_memory():
    """Release cached, currently unused CUDA memory and wait for the device.

    This is a no-op when CUDA is unavailable. ``torch.cuda.synchronize()``
    raises on CPU-only hosts, and this helper is called from the request
    handler's error path, where such a crash would replace the real error.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # Blocks until all queued CUDA work on the current device has finished.
    torch.cuda.synchronize()

# Helper function to convert base64 to image
def base64_to_image(base64_str):
    """Decode a base64-encoded image into an RGB ``PIL.Image``.

    ``base64_str`` may be plain base64 (``str`` or ``bytes``) or a ``data:`` URL
    such as ``"data:image/png;base64,iVBOR..."``. For a data URL, everything
    up to the first comma is dropped before decoding.

    Raises:
        binascii.Error: if the payload cannot be decoded as base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not a readable image.
    """
    if isinstance(base64_str, str) and base64_str.startswith("data:"):
        # Non-validating b64decode would silently drop the header's
        # non-alphabet characters and decode the rest into garbage bytes.
        _, _, base64_str = base64_str.partition(",")
    image_data = base64.b64decode(base64_str)
    with Image.open(BytesIO(image_data)) as image:
        # convert() loads the pixels and returns an independent image, so the
        # source can be closed when the with-block exits.
        return image.convert("RGB")

# Helper function to resize images for faster processing
def resize_image(image, size=(512, 768)):
    """Return the result of ``image.resize(size)``.

    ``size`` is forwarded unchanged. For a PIL image it is ``(width, height)``,
    so the default produces a 512-wide, 768-tall image.
    """
    resized = image.resize(size)
    return resized

def _image_to_tensor(image):
    """Convert a PIL image to a float16 CUDA tensor with a leading batch axis.

    The pixel array is taken as-is from ``numpy.asarray``. For an RGB image
    that is (height, width, 3) with 0-255 values. No channel permutation or
    normalisation is applied.
    """
    import numpy as np

    # torch.tensor() cannot consume a PIL.Image directly, so go through NumPy.
    pixels = np.asarray(image)
    return torch.tensor(pixels, dtype=torch.float16).unsqueeze(0).to('cuda')


def _tensor_to_base64_jpeg(tensor):
    """Encode a batched image tensor as a base64 JPEG string.

    NOTE(review): this assumes ``tensor`` is (1, H, W, 3) with pixel values in
    the 0-255 range, mirroring the input encoding. Confirm against the real
    model output.
    """
    # PIL cannot build an image from float16 data, so clamp and cast to uint8.
    pixels = tensor.squeeze(0).float().clamp(0, 255).to(torch.uint8).cpu().numpy()
    buffered = BytesIO()
    Image.fromarray(pixels).save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _run_tryon(garm_img, human_img):
    """Run the model on two PIL images and return the output as a base64 JPEG."""
    garm_img_tensor = _image_to_tensor(garm_img)
    human_img_tensor = _image_to_tensor(human_img)
    with torch.inference_mode():
        # Dummy example carried over as-is. Replace with your actual logic.
        # NOTE(review): a diffusers UNet2DConditionModel takes
        # (sample, timestep, encoder_hidden_states, ...) and returns an output
        # object rather than a tensor, so this placeholder will not run as-is.
        result_tensor = unet(garm_img_tensor, human_img_tensor)
    return _tensor_to_base64_jpeg(result_tensor)


# Example try-on endpoint
@app.route('/start_tryon', methods=['POST'])
def start_tryon():
    """Handle POST /start_tryon and return the model output for two images.

    Expects a JSON object with base64-encoded images under ``garm_img`` and
    ``human_img``. Each image is decoded, resized, run through the model, and
    the output is returned as a base64 JPEG.

    Returns:
        200 ``{"result": <base64 JPEG>}`` on success.
        400 ``{"error": ...}`` if the body is not a JSON object, a field is
        missing, or an image cannot be decoded.
        500 ``{"error": ...}`` if inference or output encoding fails.
    """
    # silent=True returns None for a missing or malformed JSON body instead of
    # raising, so the client gets a JSON error rather than an HTML page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    missing = [key for key in ("garm_img", "human_img") if key not in data]
    if missing:
        return jsonify({"error": f"Missing field(s): {', '.join(missing)}"}), 400

    # Decode and resize images
    try:
        garm_img = resize_image(base64_to_image(data['garm_img']))
        human_img = resize_image(base64_to_image(data['human_img']))
    except (ValueError, TypeError, OSError) as e:
        # binascii.Error is a ValueError and PIL's UnidentifiedImageError is an
        # OSError. A TypeError means the field was not a string/bytes value.
        return jsonify({"error": f"Invalid image data: {e}"}), 400

    try:
        result_base64 = _run_tryon(garm_img, human_img)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # The helper's GPU tensors are out of scope by now, so their cached
        # blocks are actually released. This runs on success and on failure.
        clear_gpu_memory()

    return jsonify({"result": result_base64})

if __name__ == '__main__':
    # Development server: bind all IPv4 interfaces on port 7860.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port)