File size: 5,760 Bytes
1e171d4
78e936b
5038bf8
 
 
 
 
78e936b
 
5038bf8
1e171d4
 
 
 
 
 
 
 
 
 
5038bf8
 
 
00c7d40
5038bf8
1e171d4
a42fda6
1e171d4
 
 
5038bf8
45aa957
1e171d4
5038bf8
 
 
 
1e171d4
5038bf8
1e171d4
 
 
 
 
 
5038bf8
1e171d4
5038bf8
1e171d4
5038bf8
f158f6d
a42fda6
1e171d4
 
5038bf8
 
1e171d4
5038bf8
 
 
 
 
 
1e171d4
 
5038bf8
 
 
 
 
 
 
1e171d4
 
5038bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e171d4
5038bf8
1e171d4
 
 
 
 
 
 
5038bf8
1e171d4
 
 
5038bf8
 
1e171d4
5038bf8
 
1e171d4
 
 
 
 
 
 
 
 
5038bf8
1e171d4
5038bf8
1e171d4
78e936b
 
48d8c2c
78e936b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import asyncio
import io
import json
import os
import secrets
import urllib.request

import httpx
from fastapi import FastAPI, File, UploadFile, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)

# Configuration
# MIME types accepted by POST /upload (compared against UploadFile.content_type).
ALLOWED_IMAGE_TYPES = {"image/png", "image/jpeg"}
# Root URL of the upstream Gradio Space. Set TRYON_BASE_URL to point the proxy
# at another deployment. Trailing slashes are stripped because every route is
# appended as "/...".
BASE_URL = os.environ.get(
    "TRYON_BASE_URL", "https://kwai-kolors-kolors-virtual-try-on.hf.space"
).rstrip("/")

class TryOnRequest(BaseModel):
    """Request body for ``POST /join``.

    Both paths are placed verbatim into the Gradio ``FileData`` payload that
    ``join`` sends upstream. Presumably they are the server-side paths returned
    by ``POST /upload`` (TODO confirm).
    """

    # Upstream path of the person image.
    person_image_path: str
    # Upstream path of the garment image.
    garment_image_path: str
    # Gradio queue session identifier. Presumably the same value is later
    # passed to GET /sse to read this job's events (verify against the client).
    session_hash: str

@app.get("/")
async def root():
    """Return a small JSON object identifying the service."""
    # Use a dict, not a bare {"..."} literal: a bare literal is a set and would
    # be serialised as a JSON list.
    return {"message": "Virtual Try-On API"}

@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Forward an uploaded PNG/JPEG image to the upstream ``/upload`` route.

    Args:
        file: The multipart upload. Its content type must be one of
            ``ALLOWED_IMAGE_TYPES``.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException: 400 for a disallowed content type. The upstream status
            code if the upstream replies with an HTTP error. 500 for transport
            errors or any other failure (e.g. a non-JSON upstream body).
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    # NOTE(review): upload_id appears to key Gradio's per-upload progress
    # tracking. A single hard-coded id would be shared by every concurrent
    # upload, so mint a fresh random one per request (confirm against Gradio).
    params = {"upload_id": secrets.token_hex(6)}
    headers = {"accept": "*/*"}

    try:
        file_content = await file.read()
        files = {"files": (file.filename, file_content, file.content_type)}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{BASE_URL}/upload", params=params, files=files, headers=headers
            )
            response.raise_for_status()
            return {"status": "success", "response": response.json()}

    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e

def _gradio_file_data(path: str) -> dict:
    """Build the Gradio ``FileData`` descriptor for an already-uploaded file."""
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.post("/join")
async def join(request: TryOnRequest):
    """Enqueue a try-on job on the upstream Gradio queue.

    Posts the person and garment file descriptors to ``/queue/join`` and returns
    the upstream JSON reply unchanged.

    Raises:
        HTTPException: The upstream status code on an upstream HTTP error, and
            500 for transport errors or any other failure.
    """
    person = _gradio_file_data(request.person_image_path)
    garment = _gradio_file_data(request.garment_image_path)

    # NOTE(review): the trailing 0 / True inputs and fn_index / trigger_id are
    # tied to the Space's current app layout (presumably seed and a
    # randomize-seed flag). Confirm them if the Space changes.
    payload = {
        "data": [person, garment, 0, True],
        "event_data": None,
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }
    join_headers = {"accept": "*/*", "content-type": "application/json"}

    try:
        async with httpx.AsyncClient() as client:
            reply = await client.post(f"{BASE_URL}/queue/join", headers=join_headers, json=payload)
            reply.raise_for_status()
            return reply.json()
    except httpx.HTTPStatusError as err:
        raise HTTPException(status_code=err.response.status_code, detail=str(err.response.text))
    except httpx.HTTPError as err:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(err)}")
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(err)}")

async def event_generator(session_hash: str):
    """Relay the upstream queue's SSE ``data:`` payloads for one session.

    Yields the text after each ``data:`` prefix, stripped of surrounding
    whitespace. ``EventSourceResponse`` adds its own ``data:`` framing.
    If the connection or the upstream response fails, yields a single JSON
    object ``{"error": "SSE stream error: ..."}`` and stops.
    """
    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }

    try:
        # No timeout, presumably because the stream stays open for the whole job.
        async with httpx.AsyncClient(timeout=None) as client:
            # Pass session_hash via params so httpx URL-encodes it, rather than
            # splicing caller input into the query string verbatim.
            async with client.stream(
                "GET", sse_url, headers=headers, params={"session_hash": session_hash}
            ) as response:
                # Report an upstream HTTP error instead of silently relaying nothing.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
                    await asyncio.sleep(0.01)
    except Exception as e:
        # Yield bare JSON, like the success path: EventSourceResponse adds the
        # "data:" prefix itself, and json.dumps escapes quotes in the message.
        yield json.dumps({"error": f"SSE stream error: {e}"})

@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Stream the upstream queue events for ``session_hash`` to the client as SSE.

    ``request`` is accepted by the route but not used here.
    """
    events = event_generator(session_hash)
    return EventSourceResponse(events)

@app.get("/download-image")
async def download_image(path: str):
    """Fetch a file from the upstream ``/file=`` route and return it as an image.

    Args:
        path: Upstream file path, appended verbatim after ``/file=``.

    Returns:
        A ``StreamingResponse`` over the downloaded bytes, labelled
        ``image/webp`` with an inline ``image.webp`` filename.
        NOTE(review): the type is fixed, not taken from the upstream response.
        Confirm the Space only serves WebP results.

    Raises:
        HTTPException: The upstream status code on an upstream HTTP error, and
            500 for transport errors or any other failure.
    """
    try:
        async with httpx.AsyncClient() as client:
            # Use the shared BASE_URL so this route follows the configured upstream.
            response = await client.get(f"{BASE_URL}/file={path}")
            response.raise_for_status()

            # The body is already fully downloaded. Wrap it so StreamingResponse
            # can iterate over it.
            image_stream = io.BytesIO(response.content)

            return StreamingResponse(
                image_stream,
                media_type="image/webp",
                headers={
                    "Content-Disposition": "inline; filename=image.webp"
                }
            )
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e