|
from fastapi import FastAPI, File, UploadFile, Request, HTTPException |
|
from fastapi.responses import JSONResponse |
|
from pydantic import BaseModel |
|
import httpx |
|
import os |
|
from sse_starlette.sse import EventSourceResponse |
|
import asyncio |
|
|
|
# The ASGI application; title and description are shown in the generated
# OpenAPI docs.
app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)
|
|
|
|
|
# MIME types accepted by the /upload endpoint; any other declared content type
# is rejected with HTTP 400. A frozenset so the allow-list cannot be mutated
# at runtime.
ALLOWED_IMAGE_TYPES = frozenset({"image/png", "image/jpeg"})

# Root URL of the upstream Gradio Space that every request is forwarded to.
# Can be overridden with the TRYON_BASE_URL environment variable (an empty
# value falls back to the default). A trailing slash is stripped because
# endpoint paths are appended with a leading "/".
_DEFAULT_BASE_URL = "https://kwai-kolors-kolors-virtual-try-on.hf.space"
BASE_URL = (os.environ.get("TRYON_BASE_URL") or _DEFAULT_BASE_URL).rstrip("/")
|
|
|
class TryOnRequest(BaseModel):
    """JSON body accepted by ``POST /join``.

    Both image paths are forwarded verbatim to the upstream Space: each is
    used as the ``path`` of a Gradio file descriptor and embedded in its
    ``/file=`` URL. They are presumably paths returned by ``/upload``
    (TODO confirm).
    """

    # Upstream path of the person image.
    person_image_path: str
    # Upstream path of the garment image.
    garment_image_path: str
    # Session identifier supplied by the client. /join sends it upstream, and
    # /sse subscribes to the upstream event stream with the same value.
    session_hash: str
|
|
|
@app.get("/")
async def root():
    """Return a static banner identifying the service.

    The response body is the JSON array ``["Virtual Try-On API"]``. It is
    built as an explicit list rather than a ``{...}`` set literal, so the
    output no longer depends on FastAPI coercing a set into an array.
    """
    # NOTE(review): a dict such as {"message": "Virtual Try-On API"} was
    # probably the original intent; switch only if no client depends on the
    # array form.
    return ["Virtual Try-On API"]
|
|
|
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Forward an uploaded PNG/JPEG image to the upstream Space's ``/upload``.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException: 400 if the file's declared content type is not in
            ``ALLOWED_IMAGE_TYPES``. The upstream status code, with the
            upstream body as detail, if upstream answers with an error status.
            500 for transport errors or any other failure, e.g. a non-JSON
            upstream body.
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    target_url = f"{BASE_URL}/upload"
    # Generate a fresh upload_id per request. A single hard-coded id would be
    # shared by every concurrent upload sent upstream. Passing it via params
    # lets httpx build and encode the query string.
    params = {"upload_id": os.urandom(6).hex()}
    headers = {"accept": "*/*"}

    try:
        file_content = await file.read()
        files = {"files": (file.filename, file_content, file.content_type)}

        async with httpx.AsyncClient() as client:
            response = await client.post(target_url, params=params, files=files, headers=headers)
            response.raise_for_status()
            return {"status": "success", "response": response.json()}

    except httpx.HTTPStatusError as e:
        # Relay the upstream status and body to the caller.
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e
|
|
|
def _join_file_data(path: str) -> dict:
    """Build the Gradio ``FileData`` descriptor that /join sends for *path*."""
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.post("/join")
async def join(request: TryOnRequest):
    """Submit a virtual try-on job to the upstream Space's ``/queue/join``.

    The person and garment images are sent as Gradio file descriptors under
    the caller's ``session_hash``. Output is presumably delivered on the
    upstream event stream that ``/sse`` relays for the same session hash.

    Returns:
        The upstream JSON response, unchanged.

    Raises:
        HTTPException: The upstream status code, with the upstream body as
            detail, if upstream answers with an error status. 500 for
            transport errors or any other failure, e.g. a non-JSON upstream
            body.
    """
    url = f"{BASE_URL}/queue/join"
    headers = {
        "accept": "*/*",
        "content-type": "application/json",
    }

    payload = {
        "data": [
            _join_file_data(request.person_image_path),
            _join_file_data(request.garment_image_path),
            # NOTE(review): these two trailing inputs look like a seed value
            # and a "randomize seed" flag of the upstream function. TODO confirm.
            0,
            True,
        ],
        "event_data": None,
        # fn_index / trigger_id select the upstream Gradio handler. They are
        # tied to the Space's current app layout and presumably need updating
        # if that layout changes.
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        # Relay the upstream status and body to the caller.
        raise HTTPException(status_code=e.response.status_code, detail=e.response.text) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e
|
|
|
async def event_generator(session_hash: str):
    """Relay the upstream queue's Server-Sent Events for *session_hash*.

    Yields the payload of each upstream ``data:`` line, with the prefix and
    surrounding whitespace removed. Other lines are dropped. The consumer
    (``EventSourceResponse``) re-adds the ``data:`` framing.

    If the upstream request fails (transport error or non-2xx status), a
    single JSON payload ``{"error": "SSE stream error: ..."}`` is yielded and
    the generator ends.
    """
    import json  # local import keeps this block self-contained

    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # session_hash comes straight from the client's query string. Passing it
    # via params lets httpx URL-encode it, so characters such as "&" or "#"
    # cannot alter the upstream query.
    params = {"session_hash": session_hash}

    try:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("GET", sse_url, params=params, headers=headers) as response:
                # Report upstream error statuses instead of silently ending
                # the stream with no events.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
                        await asyncio.sleep(0.01)
    except Exception as e:
        # Do not prepend "data:" here, since EventSourceResponse adds it.
        # json.dumps keeps the payload valid JSON even when the message
        # contains quotes.
        yield json.dumps({"error": f"SSE stream error: {e}"})
|
|
|
@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Expose the upstream event stream for *session_hash* as Server-Sent Events.

    ``session_hash`` is read from the query string. ``request`` is injected
    by FastAPI but not read here.
    """
    upstream_events = event_generator(session_hash)
    return EventSourceResponse(upstream_events)