Spaces:
Running
Running
File size: 4,247 Bytes
6519940 4e51761 6519940 4e51761 6519940 4e51761 6519940 4e51761 6519940 4e51761 6519940 4e51761 6519940 4e51761 6519940 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from gradio_client import Client, handle_file
from fastapi.responses import JSONResponse
import tempfile
import os
import uuid
import asyncio
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Gradio target passed to gradio_client.Client. Despite the "_URL" name this
# is a Hugging Face Space id ("owner/space"), which Client also accepts.
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"
# Thread pool for blocking Gradio client work, so the async event loop stays
# responsive while predictions run.
# NOTE: max_workers=None does NOT remove the limit. ThreadPoolExecutor then
# uses its default of min(32, os.cpu_count() + 4) workers; further submitted
# calls queue until a worker is free.
executor = ThreadPoolExecutor(max_workers=None)
@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write *file_content* to a uniquely named temp file and yield its path.

    The file is closed before the path is yielded, so other code can reopen
    it by name. It is deleted when the ``async with`` block exits, whether
    normally or via an exception.

    Args:
        file_content: Raw bytes to write.
        suffix: File-name suffix (extension) for the temp file.

    Yields:
        The absolute path of the temporary file.

    Note:
        The write is synchronous and runs on the event loop. That is
        acceptable for upload-sized payloads.
    """
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_")
    path = temp_file.name
    try:
        # `with` closes the handle even when write() raises. Otherwise the
        # descriptor leaks, and on Windows the unlink below would fail on
        # the still-open file.
        with temp_file:
            temp_file.write(file_content)
        yield path
    finally:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Already removed (e.g. by the consumer); nothing left to clean.
            pass
        except OSError as e:
            # Best-effort cleanup: log instead of masking the caller's result.
            logger.error(f"Failed to delete temp file {path}: {e}")
async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str
):
    """Call the Gradio Space's ``/tryon`` endpoint without blocking the loop.

    Constructing ``Client`` performs network requests to the Space, and
    ``Client.predict`` blocks until the prediction finishes. Both therefore
    run inside ``executor`` rather than on the event loop.

    Args:
        background_path: Local path of the image sent as the editor
            ``background`` (presumably the person photo; confirm with the
            Space's API).
        garm_img_path: Local path of the garment image.
        garment_des: Text description of the garment.
        is_checked: Forwarded unchanged to ``/tryon``.
        is_checked_crop: Forwarded unchanged to ``/tryon``.
        denoise_steps: Forwarded unchanged to ``/tryon``.
        seed: Forwarded unchanged to ``/tryon``.
        category: Forwarded unchanged to ``/tryon``.

    Returns:
        Whatever ``Client.predict`` returns for ``/tryon``.

    Raises:
        HTTPException: status 500 wrapping any error raised while
            connecting to or calling the Space.
    """

    def _call():
        # Runs in a worker thread. Client() itself does blocking network I/O,
        # so creating it on the event loop would stall every other request.
        client = Client(GRADIO_API_URL)
        return client.predict(
            # `dict` is the endpoint's own parameter name, so it has to stay
            # even though it shadows the builtin.
            dict={
                "background": handle_file(background_path),
                "layers": [],
                "composite": None
            },
            garm_img=handle_file(garm_img_path),
            garment_des=garment_des,
            is_checked=is_checked,
            is_checked_crop=is_checked_crop,
            denoise_steps=denoise_steps,
            seed=seed,
            category=category,
            api_name="/tryon"
        )

    # get_running_loop() is the supported way to reach the loop from a coroutine.
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(executor, _call)
    except Exception as e:
        logger.exception("Gradio API error: %s", e)
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e
@app.post("/tryon")
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body")
):
    """Run a virtual try-on from an uploaded person image and garment image.

    Both uploads are written to temporary ``.png`` files and forwarded to
    the Gradio Space via ``run_gradio_predict``. The prediction is returned
    stringified as ``{"result": str(result)}``.

    Raises:
        HTTPException: 400 if either upload is not declared as ``image/*``;
            500 for Gradio failures or any other unexpected error.
    """
    try:
        # UploadFile.content_type is None when the part has no Content-Type
        # header. Treat that as "not an image" (400) instead of letting
        # .startswith raise AttributeError and surface as a 500.
        for upload in (background, garm_img):
            if not (upload.content_type or "").startswith("image/"):
                raise HTTPException(status_code=400, detail="Only image files are allowed")

        background_content = await background.read()
        garm_img_content = await garm_img.read()

        # The temp files exist only for the duration of the Gradio call and
        # are removed on exit, even if the prediction fails.
        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category
                )
        return JSONResponse(content={"result": str(result)})
    except HTTPException:
        # Already carries the intended status code; re-raise unchanged.
        raise
    except Exception as e:
        logger.exception("Request failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
@app.on_event("shutdown")
def shutdown_event():
    """Release the Gradio worker threads when the server stops.

    ``wait=True`` makes this block until every call already submitted to
    ``executor`` has finished, and only then logs the confirmation.
    """
    # NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
    # a ``lifespan`` handler passed to ``FastAPI(...)``. Migrating also means
    # changing how ``app`` is constructed.
    pool = executor
    pool.shutdown(wait=True)
    logger.info("Thread pool shut down")