File size: 4,563 Bytes
1e171d4
5038bf8
 
 
 
 
 
 
1e171d4
 
 
 
 
 
 
 
 
 
5038bf8
 
 
00c7d40
5038bf8
1e171d4
a42fda6
1e171d4
 
 
5038bf8
45aa957
1e171d4
5038bf8
 
 
 
1e171d4
5038bf8
1e171d4
 
 
 
 
 
5038bf8
1e171d4
5038bf8
1e171d4
5038bf8
f158f6d
a42fda6
1e171d4
 
5038bf8
 
1e171d4
5038bf8
 
 
 
 
 
1e171d4
 
5038bf8
 
 
 
 
 
 
1e171d4
 
5038bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e171d4
5038bf8
1e171d4
 
 
 
 
 
 
5038bf8
1e171d4
 
 
5038bf8
 
1e171d4
5038bf8
 
1e171d4
 
 
 
 
 
 
 
 
5038bf8
1e171d4
5038bf8
1e171d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import asyncio
import json
import os
import secrets

import httpx
from fastapi import FastAPI, File, UploadFile, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

# The ASGI application; every route in this module is registered on it.
app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)

# Configuration
# Content types accepted by /upload. A frozenset, because this is a constant
# and must not be mutated at runtime.
ALLOWED_IMAGE_TYPES = frozenset({"image/png", "image/jpeg"})
# Root URL of the upstream Gradio Space that requests are forwarded to.
# It can be overridden with the TRYON_BASE_URL environment variable. Any
# trailing slash is stripped so f"{BASE_URL}/..." joins stay well-formed.
BASE_URL = os.environ.get(
    "TRYON_BASE_URL", "https://kwai-kolors-kolors-virtual-try-on.hf.space"
).rstrip("/")

class TryOnRequest(BaseModel):
    """JSON body accepted by ``POST /join``.

    The two image paths are placed into the upstream ``FileData`` entries as-is.
    They are presumably paths returned by the upstream ``/upload`` endpoint
    (relayed through this API's ``/upload``); confirm against callers.
    """

    # Upstream path of the person image.
    person_image_path: str
    # Upstream path of the garment image.
    garment_image_path: str
    # Session identifier sent to the upstream queue. /sse takes a session_hash
    # to read results; presumably it should be this same value.
    session_hash: str

@app.get("/")
async def root():
    """Identify the service.

    Returns a one-element list naming the API. An explicit list is returned,
    rather than a set literal, because a list is a native JSON type. The
    JSON response is ``["Virtual Try-On API"]``.
    """
    return ["Virtual Try-On API"]

@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Forward an uploaded image to the upstream ``/upload`` endpoint.

    Args:
        file: Multipart file from the client. Only content types listed in
            ``ALLOWED_IMAGE_TYPES`` (PNG or JPEG) are accepted.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException: 400 for a disallowed content type; the upstream status
            code when the upstream answers with an error status; 500 for
            transport failures or any other unexpected error.
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    target_url = f"{BASE_URL}/upload"
    # Generate a fresh upload id per request so concurrent uploads never share one.
    # NOTE(review): presumably the upstream uses it for upload-progress tracking.
    params = {"upload_id": secrets.token_hex(6)}
    headers = {"accept": "*/*"}

    try:
        file_content = await file.read()
        files = {"files": (file.filename, file_content, file.content_type)}

        async with httpx.AsyncClient() as client:
            response = await client.post(target_url, params=params, files=files, headers=headers)
            response.raise_for_status()
            return {"status": "success", "response": response.json()}

    except httpx.HTTPStatusError as e:
        # Mirror the upstream status code and body back to the caller.
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e

def _gradio_file_data(path: str) -> dict:
    """Return the Gradio ``FileData`` descriptor for a file on the upstream Space.

    Args:
        path: Upstream file path; presumably one returned by ``/upload``.
    """
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.post("/join")
async def join(request: TryOnRequest):
    """Enqueue a virtual try-on job on the upstream Gradio queue.

    Args:
        request: Upstream paths of the person and garment images, plus the
            session hash the job is queued under.

    Returns:
        The JSON body of the upstream ``/queue/join`` response, unchanged.

    Raises:
        HTTPException: the upstream status code when the upstream answers
            with an error status; 500 for transport failures or any other
            unexpected error.
    """
    url = f"{BASE_URL}/queue/join"
    headers = {
        "accept": "*/*",
        "content-type": "application/json",
    }

    payload = {
        "data": [
            _gradio_file_data(request.person_image_path),
            _gradio_file_data(request.garment_image_path),
            # NOTE(review): presumably the upstream "seed" and "randomize seed"
            # inputs; confirm against the Space's interface.
            0,
            True,
        ],
        "event_data": None,
        # fn_index / trigger_id select the upstream event handler. They are
        # presumably tied to the Space's current layout; re-check if it changes.
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        # Mirror the upstream status code and body back to the caller.
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e

async def event_generator(session_hash: str):
    """Relay the upstream Gradio queue event stream for ``session_hash``.

    Opens ``{BASE_URL}/queue/data`` as a stream. For each ``data:`` line, it
    yields the payload with the prefix and surrounding whitespace removed;
    all other lines are dropped. The caller (``EventSourceResponse``) adds the
    SSE ``data:`` framing to each yielded string.

    If the request fails or the upstream answers with an error status, a
    single JSON string ``{"error": "SSE stream error: ..."}`` is yielded and
    the generator ends.
    """
    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }

    try:
        # timeout=None disables httpx's default timeouts so a long-lived
        # stream is not cut off.
        async with httpx.AsyncClient(timeout=None) as client:
            # Pass session_hash via params so httpx URL-encodes it. Otherwise
            # characters like '&' or '#' could rewrite the upstream query.
            async with client.stream(
                "GET", sse_url, params={"session_hash": session_hash}, headers=headers
            ) as response:
                # Surface upstream error statuses instead of silently
                # streaming nothing.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
                    # Brief pause that yields control to the event loop between lines.
                    await asyncio.sleep(0.01)
    except Exception as e:
        # Yield bare JSON. EventSourceResponse adds the "data: " framing itself,
        # so adding the prefix here would duplicate it on the wire.
        yield json.dumps({"error": f"SSE stream error: {e}"})

@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Stream upstream try-on progress to the client as Server-Sent Events.

    ``session_hash`` is presumably the value previously sent to ``/join``.
    ``request`` is accepted but not used by this handler.
    """
    upstream_events = event_generator(session_hash)
    return EventSourceResponse(upstream_events)