TryOnOther / app.py
tejani's picture
Update app.py
6519940 verified
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from gradio_client import Client, handle_file
from fastapi.responses import JSONResponse
import tempfile
import os
import uuid
import asyncio
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
import logging
# Configure logging once at import time; `logger` is the module-level logger
# used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Identifier of the remote Gradio app handed to Client(). Despite the constant's
# name, the value is in "owner/name" form rather than a full URL.
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"
# Thread pool for Gradio API calls (to avoid blocking async loop).
# NOTE: max_workers=None does NOT remove the limit. ThreadPoolExecutor then
# falls back to its built-in default (min(32, cpu_count + 4) on Python 3.8+).
# Pass an explicit integer here if a different concurrency cap is wanted.
executor = ThreadPoolExecutor(max_workers=None)
# Context manager for temporary files
@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write ``file_content`` to a uniquely named temp file and yield its path.

    The file is created with the prefix ``tryon_<uuid4>_`` and the given
    ``suffix``. It is closed before the path is yielded, so the caller (or
    another library) can reopen it by name. The file is removed when the
    ``async with`` block exits, whether normally or via an exception.

    Args:
        file_content: Raw bytes to store in the file.
        suffix: File name suffix, including the leading dot.

    Yields:
        The absolute path of the temporary file as a string.
    """
    # delete=False: the file has to outlive its handle so that it can be
    # reopened by path; removal is done explicitly in the finally clause.
    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_"
    )
    path = temp_file.name
    try:
        # The with-block closes the handle even when write() raises, so the
        # cleanup below never has to unlink a file that is still open.
        with temp_file:
            temp_file.write(file_content)
        yield path
    finally:
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Already gone (e.g. removed by the consumer); nothing to clean up.
            pass
        except OSError as e:
            # Best effort: a failed cleanup must not mask the request outcome.
            logger.error(f"Failed to delete temp file {path}: {e}")
# Run Gradio API call in thread pool to avoid blocking
async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str
):
    """Call the remote ``/tryon`` Gradio endpoint without blocking the event loop.

    Both the client construction and the ``predict`` call are executed in the
    module-level thread pool (``executor``); the coroutine only awaits the
    resulting future.

    Args:
        background_path: Path of the image sent as the ``background`` layer.
        garm_img_path: Path of the garment image.
        garment_des: Text description of the garment.
        is_checked: Forwarded unchanged as the endpoint's ``is_checked`` flag.
        is_checked_crop: Forwarded unchanged as ``is_checked_crop``.
        denoise_steps: Forwarded unchanged as ``denoise_steps``.
        seed: Forwarded unchanged as ``seed``.
        category: Forwarded unchanged as ``category``.

    Returns:
        Whatever ``Client.predict`` returns for the ``/tryon`` endpoint.

    Raises:
        HTTPException: With status 500 if creating the client or the
            prediction call fails for any reason.
    """

    def _predict():
        # Building the client is synchronous and presumably contacts the remote
        # app, so it is done here in the worker thread, not on the event loop.
        client = Client(GRADIO_API_URL)
        return client.predict(
            dict={
                "background": handle_file(background_path),
                "layers": [],
                "composite": None
            },
            garm_img=handle_file(garm_img_path),
            garment_des=garment_des,
            is_checked=is_checked,
            is_checked_crop=is_checked_crop,
            denoise_steps=denoise_steps,
            seed=seed,
            category=category,
            api_name="/tryon"
        )

    # Inside a coroutine the running loop is always available.
    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(executor, _predict)
    except Exception as e:
        logger.error(f"Gradio API error: {e}")
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e
@app.post("/tryon")
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body")
):
    """Handle a try-on request: validate uploads, stage them on disk, call Gradio.

    Both uploads are written to temporary files (removed again before the
    response is sent), their paths are passed to ``run_gradio_predict``, and
    the stringified result is returned as ``{"result": "<str(result)>"}``.

    Args:
        background: Uploaded image used as the ``background`` input.
        garm_img: Uploaded garment image.
        garment_des: Garment description form field.
        is_checked: Form flag forwarded to the Gradio endpoint.
        is_checked_crop: Form flag forwarded to the Gradio endpoint.
        denoise_steps: Form value forwarded to the Gradio endpoint.
        seed: Form value forwarded to the Gradio endpoint.
        category: Form value forwarded to the Gradio endpoint.

    Raises:
        HTTPException: 400 if either upload is not declared as ``image/*``;
            500 if the Gradio call or any other step fails.
    """
    try:
        # Validate file types. content_type is None when the client sends no
        # Content-Type for the part; treat that as "not an image" (400), not
        # as an AttributeError that would surface as a 500.
        uploads = (background, garm_img)
        if not all((upload.content_type or "").startswith("image/") for upload in uploads):
            raise HTTPException(status_code=400, detail="Only image files are allowed")
        # Read file contents
        background_content = await background.read()
        garm_img_content = await garm_img.read()
        # Create temporary files with unique names
        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                # Call Gradio API in thread pool
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category
                )
                return JSONResponse(content={"result": str(result)})
    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# Shutdown thread pool gracefully
@app.on_event("shutdown")
def shutdown_event() -> None:
    """Stop the shared thread pool when the application shuts down.

    Waits for jobs already submitted to ``executor`` to finish, then logs
    that the pool has been shut down.
    """
    # wait=True: block until outstanding Gradio calls have completed.
    executor.shutdown(wait=True)
    logger.info("Thread pool shut down")