Spaces:
Running
Running
from fastapi import FastAPI, File, UploadFile, Form, HTTPException | |
from gradio_client import Client, handle_file | |
from fastapi.responses import JSONResponse | |
import tempfile | |
import os | |
import uuid | |
import asyncio | |
from contextlib import asynccontextmanager | |
from concurrent.futures import ThreadPoolExecutor | |
import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Gradio Space used for the try-on call (passed to gradio_client.Client).
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"

# Thread pool for blocking Gradio client calls, so they do not stall the async loop.
# NOTE: max_workers=None does NOT mean "unlimited". ThreadPoolExecutor then uses
# min(32, os.cpu_count() + 4) worker threads (Python 3.8+); further submissions
# wait in the pool's queue. Pass an explicit integer to change the concurrency cap.
executor = ThreadPoolExecutor(max_workers=None)
# Context manager for temporary files
@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write *file_content* to a uniquely named temp file and yield its path.

    Use as ``async with temp_file_manager(data) as path: ...``. The file is
    closed before its path is yielded, so other code can reopen it by name.
    It is deleted when the ``async with`` block exits, including on error.

    Args:
        file_content: Raw bytes to write into the file.
        suffix: File-name suffix (extension) for the temp file.

    Yields:
        Path of the written temporary file.
    """
    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_"
    )
    try:
        try:
            temp_file.write(file_content)
        finally:
            # Close even if write() fails: this avoids leaking the descriptor and
            # lets the unlink below succeed where open files cannot be removed.
            temp_file.close()
        yield temp_file.name
    finally:
        try:
            os.unlink(temp_file.name)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
        except Exception as e:
            # Best-effort cleanup: log, but never mask the caller's own result/error.
            logger.error(f"Failed to delete temp file {temp_file.name}: {e}")
# Run Gradio API call in thread pool to avoid blocking
async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str
):
    """Call the Space's ``/tryon`` endpoint without blocking the event loop.

    Both the ``Client`` construction and ``predict`` run on ``executor``.
    Constructing a gradio ``Client`` connects to the remote Space (network
    I/O). Doing that on the event-loop thread would stall every other
    in-flight request.

    Args:
        background_path: Local path of the background (input) image file.
        garm_img_path: Local path of the garment image file.
        garment_des: Text description of the garment.
        is_checked: Forwarded unchanged to ``/tryon``.
        is_checked_crop: Forwarded unchanged to ``/tryon``.
        denoise_steps: Forwarded unchanged to ``/tryon``.
        seed: Forwarded unchanged to ``/tryon``.
        category: Forwarded unchanged to ``/tryon``.

    Returns:
        Whatever ``Client.predict`` returns for ``api_name="/tryon"``.

    Raises:
        HTTPException: status 500 wrapping any error from the client or Space.
    """
    def _predict():
        # Runs on a worker thread: connect and predict are both blocking calls.
        client = Client(GRADIO_API_URL)
        return client.predict(
            dict={
                "background": handle_file(background_path),
                "layers": [],
                "composite": None
            },
            garm_img=handle_file(garm_img_path),
            garment_des=garment_des,
            is_checked=is_checked,
            is_checked_crop=is_checked_crop,
            denoise_steps=denoise_steps,
            seed=seed,
            category=category,
            api_name="/tryon"
        )

    # get_running_loop() is the supported way to get the loop inside a coroutine.
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(executor, _predict)
    except Exception as e:
        logger.exception(f"Gradio API error: {e}")
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body")
):
    """Run a virtual try-on of *garm_img* onto *background* via the Gradio Space.

    Both uploads must declare an ``image/*`` content type. Their bytes are
    written to temporary files (removed afterwards) and forwarded, together
    with the form fields, to ``run_gradio_predict``.

    NOTE(review): no route decorator is visible on this function in this
    file. Unless it is registered elsewhere (e.g. ``app.post(...)(tryon)``),
    FastAPI will not expose it; confirm.

    Returns:
        ``JSONResponse`` of the form ``{"result": str(<gradio result>)}``.

    Raises:
        HTTPException: 400 if either upload is not declared as an image;
            500 for Gradio or any other unexpected failure.
    """
    try:
        # content_type is None when the multipart part has no Content-Type header;
        # treat that as "not an image" (400) instead of crashing into a 500.
        for upload in (background, garm_img):
            if not (upload.content_type or "").startswith("image/"):
                raise HTTPException(status_code=400, detail="Only image files are allowed")

        # Read file contents
        background_content = await background.read()
        garm_img_content = await garm_img.read()

        # Temp files are deleted as soon as the Gradio call returns or fails.
        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category
                )
        return JSONResponse(content={"result": str(result)})
    except HTTPException:
        # Already carries the intended status code; re-raise untouched.
        raise
    except Exception as e:
        logger.exception(f"Request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# Shutdown thread pool gracefully
def shutdown_event():
    """Stop the shared worker pool, blocking until submitted calls have finished.

    NOTE(review): nothing visible in this file registers this function as a
    FastAPI shutdown handler (no decorator or lifespan hook). It may never
    run unless it is wired up elsewhere; confirm.
    """
    pool = executor
    pool.shutdown(wait=True)
    logger.info("Thread pool shut down")