Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import base64
|
|
7 |
import uuid
|
8 |
import random
|
9 |
import logging
|
|
|
10 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
11 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
12 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
@@ -154,6 +155,7 @@ def clear_gpu_memory():
|
|
154 |
|
155 |
# Main try-on function
|
156 |
@torch.no_grad()
|
|
|
157 |
def start_tryon(human_dict, garment_img, garment_des, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie='upper_body'):
|
158 |
try:
|
159 |
device = torch.device("cuda")
|
@@ -208,7 +210,7 @@ def start_tryon(human_dict, garment_img, garment_des, use_auto_mask, use_auto_cr
|
|
208 |
logging.error(f"Error during try-on: {e}")
|
209 |
raise
|
210 |
finally:
|
211 |
-
clear_gpu_memory()
|
212 |
|
213 |
# API endpoints
|
214 |
@app.route('/tryon', methods=['POST'])
|
|
|
7 |
import uuid
|
8 |
import random
|
9 |
import logging
|
10 |
+
import spaces
|
11 |
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
|
12 |
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
|
13 |
from src.unet_hacked_tryon import UNet2DConditionModel
|
|
|
155 |
|
156 |
# Main try-on function
|
157 |
@torch.no_grad()
|
158 |
+
@spaces.GPU
|
159 |
def start_tryon(human_dict, garment_img, garment_des, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie='upper_body'):
|
160 |
try:
|
161 |
device = torch.device("cuda")
|
|
|
210 |
logging.error(f"Error during try-on: {e}")
|
211 |
raise
|
212 |
finally:
|
213 |
+
#clear_gpu_memory()
|
214 |
|
215 |
# API endpoints
|
216 |
@app.route('/tryon', methods=['POST'])
|