File size: 4,247 Bytes
6519940
4e51761
6519940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e51761
6519940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e51761
6519940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e51761
6519940
4e51761
6519940
 
4e51761
6519940
 
4e51761
6519940
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from gradio_client import Client, handle_file
from fastapi.responses import JSONResponse
import tempfile
import os
import uuid
import asyncio
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Target passed to gradio_client.Client. Despite the variable name, this value
# looks like a Hugging Face Space id ("owner/space-name") rather than a full
# URL -- Client accepts either form.
GRADIO_API_URL = "jallenjia/Change-Clothes-AI"

# Thread pool for the blocking Gradio client work, so the async event loop stays
# responsive. NOTE: max_workers=None does NOT remove the worker limit; since
# Python 3.8 it means min(32, os.cpu_count() + 4) workers. Calls beyond that
# number wait in the pool's queue until a worker becomes free.
executor = ThreadPoolExecutor(max_workers=None)

# Context manager for temporary files
@asynccontextmanager
async def temp_file_manager(file_content: bytes, suffix: str = ".png"):
    """Write *file_content* to a fresh temporary file and yield its path.

    The file is created with ``delete=False`` and closed before its path is
    yielded, so other code can reopen it by name. It is removed when the
    ``async with`` block exits, whether normally or through an exception.

    Args:
        file_content: Raw bytes to write to the file.
        suffix: File-name suffix (extension) for the temporary file.

    Yields:
        The filesystem path of the temporary file.
    """
    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=suffix, prefix=f"tryon_{uuid.uuid4()}_"
    )
    try:
        try:
            temp_file.write(file_content)
        finally:
            # Close even when the write fails, so the handle never leaks and the
            # unlink below also works on platforms that refuse to delete open files.
            temp_file.close()
        yield temp_file.name
    finally:
        try:
            os.unlink(temp_file.name)
        except FileNotFoundError:
            # Already removed (e.g. by the consumer); nothing left to clean up.
            pass
        except OSError as e:
            logger.error(f"Failed to delete temp file {temp_file.name}: {e}")

# Run Gradio API call in thread pool to avoid blocking
async def run_gradio_predict(
    background_path: str,
    garm_img_path: str,
    garment_des: str,
    is_checked: bool,
    is_checked_crop: bool,
    denoise_steps: int,
    seed: int,
    category: str
):
    """Call the Gradio ``/tryon`` endpoint without blocking the event loop.

    Both ``Client`` construction and ``Client.predict`` run inside ``executor``.
    Constructing a ``gradio_client.Client`` performs network requests (it
    fetches the remote app's configuration). Doing that on the event loop
    would stall every other in-flight request.

    Args:
        background_path: Path of the person/background image file.
        garm_img_path: Path of the garment image file.
        garment_des: Text description of the garment.
        is_checked: Forwarded unchanged as ``is_checked`` to ``/tryon``.
        is_checked_crop: Forwarded unchanged as ``is_checked_crop`` to ``/tryon``.
        denoise_steps: Forwarded unchanged as ``denoise_steps`` to ``/tryon``.
        seed: Forwarded unchanged as ``seed`` to ``/tryon``.
        category: Forwarded unchanged as ``category`` to ``/tryon``.

    Returns:
        Whatever ``Client.predict`` returns for ``/tryon`` (not inspected here).

    Raises:
        HTTPException: status 500 wrapping any error from the Gradio client.
    """
    def _predict_blocking():
        # Runs on a worker thread: everything here may block freely.
        client = Client(GRADIO_API_URL)
        return client.predict(
            # ``dict`` is the parameter name the remote /tryon endpoint expects.
            dict={
                "background": handle_file(background_path),
                "layers": [],
                "composite": None
            },
            garm_img=handle_file(garm_img_path),
            garment_des=garment_des,
            is_checked=is_checked,
            is_checked_crop=is_checked_crop,
            denoise_steps=denoise_steps,
            seed=seed,
            category=category,
            api_name="/tryon"
        )

    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(executor, _predict_blocking)
    except Exception as e:
        logger.error(f"Gradio API error: {e}")
        raise HTTPException(status_code=500, detail=f"Gradio API error: {str(e)}") from e

@app.post("/tryon")
async def tryon(
    background: UploadFile = File(...),
    garm_img: UploadFile = File(...),
    garment_des: str = Form("navy blue polo shirt"),
    is_checked: bool = Form(True),
    is_checked_crop: bool = Form(False),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    category: str = Form("upper_body")
):
    """Run a virtual try-on of the garment image on the background image.

    Both uploads must declare an ``image/*`` content type. The Gradio result
    is returned as ``{"result": str(result)}``.
    """
    try:
        # Validate file types. UploadFile.content_type can be None when a part
        # has no Content-Type header. Falling back to "" makes that a clean 400
        # instead of an AttributeError that would surface as a 500.
        background_type = background.content_type or ""
        garm_img_type = garm_img.content_type or ""
        if not background_type.startswith("image/") or not garm_img_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="Only image files are allowed")

        # Read file contents
        background_content = await background.read()
        garm_img_content = await garm_img.read()

        # Create temporary files with unique names. Both are deleted as soon as
        # the Gradio call below has finished with them.
        # NOTE(review): the ".png" suffix is used even for JPEG/WebP uploads;
        # confirm the remote app ignores the extension.
        async with temp_file_manager(background_content, ".png") as background_path:
            async with temp_file_manager(garm_img_content, ".png") as garm_img_path:
                # Call Gradio API in thread pool
                result = await run_gradio_predict(
                    background_path=background_path,
                    garm_img_path=garm_img_path,
                    garment_des=garment_des,
                    is_checked=is_checked,
                    is_checked_crop=is_checked_crop,
                    denoise_steps=denoise_steps,
                    seed=seed,
                    category=category
                )

        return JSONResponse(content={"result": str(result)})

    except HTTPException:
        # Already carries the intended status (400 above, 500 from
        # run_gradio_predict): re-raise unchanged, keeping the traceback.
        raise
    except Exception as e:
        logger.error(f"Request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

# Release the shared worker pool when the application stops
@app.on_event("shutdown")
def shutdown_event():
    """Wait for pending executor work to finish, then tear the pool down."""
    pool = executor
    pool.shutdown(wait=True)
    logger.info("Thread pool shut down")