import asyncio
import logging
import os
import tempfile
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from gradio_client import Client, handle_file

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Gradio API URL (Space identifier passed to gradio_client.Client)
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"

# Thread pool for the blocking gradio_client calls, so they do not stall the
# async event loop. NOTE: max_workers=None does NOT remove the limit; it makes
# ThreadPoolExecutor use its default of min(32, os.cpu_count() + 4) workers.
executor = ThreadPoolExecutor(max_workers=None)


@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write *file_content* to a uniquely named temp file and yield its path.

    The file handle is closed before the path is yielded, so the file can be
    reopened by name. The file is deleted on exit, even if the body raised.

    Args:
        file_content: Raw bytes to write to the file.
        suffix: File-name suffix for the temp file.

    Yields:
        Filesystem path of the temporary file.
    """
    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_"
    )
    try:
        try:
            temp_file.write(file_content)
        finally:
            # Release the handle even if write() failed: the fd must not leak,
            # and an open file cannot be unlinked on Windows.
            temp_file.close()
        yield temp_file.name
    finally:
        try:
            os.unlink(temp_file.name)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
        except OSError as e:
            logger.error("Failed to delete temp file %s: %s", temp_file.name, e)


def _predict_sync(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str,
):
    """Build a Gradio client and call the ``/tryon`` endpoint (blocking)."""
    # Client() talks to the remote Space while initialising, so it is created
    # here, inside the worker thread, instead of on the event loop.
    client = Client(GRADIO_API_URL)
    return client.predict(
        dict={
            "background": handle_file(background_path),
            "layers": [],
            "composite": None,
        },
        garm_img=handle_file(garm_img_path),
        garment_des=garment_des,
        is_checked=is_checked,
        is_checked_crop=is_checked_crop,
        denoise_steps=denoise_steps,
        seed=seed,
        category=category,
        api_name="/tryon",
    )


async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str,
):
    """Run the Gradio ``/tryon`` prediction in the thread pool.

    Returns:
        Whatever ``Client.predict`` returns for the ``/tryon`` endpoint.

    Raises:
        HTTPException: status 500 if creating the client or the prediction fails.
    """
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(
            executor,
            lambda: _predict_sync(
                background_path,
                garm_img_path,
                garment_des,
                is_checked,
                is_checked_crop,
                denoise_steps,
                seed,
                category,
            ),
        )
    except Exception as e:
        logger.exception("Gradio API error: %s", e)
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e


def _is_image(upload: UploadFile) -> bool:
    """Return True if *upload* declares an ``image/*`` content type."""
    # content_type can be None when the client omits the part's Content-Type.
    return (upload.content_type or "").startswith("image/")


@app.post("/tryon")
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body"),
):
    """Forward two uploaded images and the try-on options to the Gradio Space.

    The uploads are written to temp files, which are deleted afterwards. The
    form fields are passed through unchanged to the ``/tryon`` endpoint.

    Returns:
        JSON ``{"result": str(result)}`` with the stringified Gradio output.
        NOTE(review): the Gradio result may contain file paths that are local
        to this server. Confirm that clients can actually use this value.

    Raises:
        HTTPException: 400 if either upload is not declared as ``image/*``;
            500 on any other failure.
    """
    try:
        if not _is_image(background) or not _is_image(garm_img):
            raise HTTPException(status_code=400, detail="Only image files are allowed")

        background_content = await background.read()
        garm_img_content = await garm_img.read()

        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category,
                )
                return JSONResponse(content={"result": str(result)})
    except HTTPException:
        raise  # already carries the intended status code
    except Exception as e:
        logger.exception("Request failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e


# Shutdown thread pool gracefully
@app.on_event("shutdown")
def shutdown_event():
    """Wait for in-flight Gradio calls to finish, then stop the worker threads."""
    executor.shutdown(wait=True)
    logger.info("Thread pool shut down")