File size: 4,563 Bytes
1e171d4 5038bf8 1e171d4 5038bf8 00c7d40 5038bf8 1e171d4 a42fda6 1e171d4 5038bf8 45aa957 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 f158f6d a42fda6 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 5038bf8 1e171d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import asyncio
import json
import os

import httpx
from fastapi import FastAPI, File, UploadFile, Request, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
# Application instance; the title and description appear in the generated
# OpenAPI / interactive docs.
app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)
# Configuration
ALLOWED_IMAGE_TYPES = {"image/png", "image/jpeg"}
BASE_URL = "https://kwai-kolors-kolors-virtual-try-on.hf.space"
class TryOnRequest(BaseModel):
    """Request body for ``POST /join``.

    Both image paths are embedded in the Gradio queue payload sent upstream.
    Presumably they are the server-side paths returned by ``POST /upload``
    (TODO confirm against the client flow).
    """

    # Path of the person image on the upstream Space. It is also used to build
    # the file's ``/file=`` URL and its ``orig_name`` in the payload.
    person_image_path: str
    # Path of the garment image on the upstream Space, used the same way.
    garment_image_path: str
    # Gradio session identifier forwarded to /queue/join. Presumably the same
    # value is later passed to GET /sse to read results — verify with caller.
    session_hash: str
@app.get("/")
async def root():
    """Landing endpoint for ``GET /``.

    NOTE(review): this returns a *set* literal, not a dict. FastAPI serializes
    it as the JSON array ``["Virtual Try-On API"]``. A dict such as
    ``{"message": ...}`` was probably intended; the shape is kept unchanged to
    avoid breaking existing clients.
    """
    service_name = "Virtual Try-On API"
    return {service_name}
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Forward an uploaded PNG/JPEG image to the upstream Space's ``/upload`` route.

    Args:
        file: The multipart file sent by the client.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException:
            * 400 if the content type is not in ``ALLOWED_IMAGE_TYPES``.
            * The upstream status code if the upstream replies with an
              error status.
            * 500 for transport errors or any other unexpected failure.
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    # NOTE(review): every request reuses the same upload_id. It looks like an
    # upstream progress-tracking key only; confirm this is safe when uploads
    # run concurrently.
    target_url = f"{BASE_URL}/upload?upload_id=abcde123456"
    headers = {"accept": "*/*"}
    # httpx drops the multipart ``filename`` parameter when it is empty or
    # None, which turns the part into a plain form field rather than a file.
    # Always send a non-empty name.
    filename = file.filename or "upload"

    try:
        file_content = await file.read()
        files = {"files": (filename, file_content, file.content_type)}
        async with httpx.AsyncClient() as client:
            response = await client.post(target_url, files=files, headers=headers)
            response.raise_for_status()
            return {"status": "success", "response": response.json()}
    except httpx.HTTPStatusError as e:
        # Relay the upstream status and body so the caller sees the real cause.
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e
def _gradio_file_data(path: str) -> dict:
    """Build the ``gradio.FileData`` dict the upstream queue payload uses for *path*."""
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.post("/join")
async def join(request: TryOnRequest):
    """Submit a virtual try-on job to the upstream Gradio queue.

    Args:
        request: Person/garment image paths on the upstream Space plus the
            Gradio session hash.

    Returns:
        The upstream ``/queue/join`` JSON response, unchanged.

    Raises:
        HTTPException:
            * The upstream status code if the upstream replies with an
              error status.
            * 500 for transport errors or any other unexpected failure.
    """
    url = f"{BASE_URL}/queue/join"
    headers = {
        "accept": "*/*",
        "content-type": "application/json",
    }
    payload = {
        "data": [
            _gradio_file_data(request.person_image_path),
            _gradio_file_data(request.garment_image_path),
            # NOTE(review): presumably the seed and the "randomize seed" flag
            # of the upstream function — confirm against the Space's API page.
            0,
            True,
        ],
        "event_data": None,
        # NOTE(review): fn_index / trigger_id presumably select the upstream
        # handler. They are hard-coded and may break if the Space's app changes.
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e
async def event_generator(session_hash: str):
    """Relay the upstream queue's Server-Sent Events stream for *session_hash*.

    For each upstream ``data:`` line, yield its payload with the prefix
    stripped. ``EventSourceResponse`` re-frames every yielded string as its
    own ``data:`` event.

    Any failure ends the stream with one JSON ``{"error": "..."}`` payload.
    This includes a non-2xx upstream status.

    Args:
        session_hash: Gradio session identifier to read queue events for.
    """
    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    try:
        # timeout=None disables httpx's default timeouts, so a long-running
        # job's stream is not cut off mid-way.
        async with httpx.AsyncClient(timeout=None) as client:
            # Pass session_hash via ``params`` so it is URL-encoded instead of
            # being spliced raw into the query string.
            async with client.stream(
                "GET", sse_url, params={"session_hash": session_hash}, headers=headers
            ) as response:
                # Without this check, an upstream error page would contain no
                # "data:" lines and the stream would end silently.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
                        await asyncio.sleep(0.01)
    except Exception as e:
        # Yield only the payload: EventSourceResponse adds the "data:" framing
        # itself (a literal prefix here produced "data: data: ..."). json.dumps
        # guarantees clients receive valid JSON.
        yield json.dumps({"error": f"SSE stream error: {e}"})
@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Stream the upstream queue events for *session_hash* to the client as SSE.

    ``request`` is injected by FastAPI but is not read here.
    """
    events = event_generator(session_hash)
    return EventSourceResponse(events)