import asyncio
import io
import json
import os
import secrets
import urllib.request

import httpx
from fastapi import FastAPI, File, HTTPException, Request, Response, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)

# Configuration
ALLOWED_IMAGE_TYPES = {"image/png", "image/jpeg"}
BASE_URL = "https://kwai-kolors-kolors-virtual-try-on.hf.space"


class TryOnRequest(BaseModel):
    """Body of ``POST /join``.

    The two image paths refer to files on the upstream Space (presumably the
    paths returned by ``/upload`` — TODO confirm against callers).
    """

    person_image_path: str
    garment_image_path: str
    session_hash: str


def _upstream_error(exc: Exception) -> HTTPException:
    """Map an exception raised while talking to the upstream Space to an HTTPException.

    - ``httpx.HTTPStatusError``: the upstream status code is passed through,
      with the upstream response body as ``detail``.
    - any other ``httpx.HTTPError`` (transport/timeout errors): 500.
    - anything else: 500.
    """
    # HTTPStatusError subclasses HTTPError, so it must be checked first.
    if isinstance(exc, httpx.HTTPStatusError):
        return HTTPException(status_code=exc.response.status_code, detail=exc.response.text)
    if isinstance(exc, httpx.HTTPError):
        return HTTPException(status_code=500, detail=f"HTTP error occurred: {exc}")
    return HTTPException(status_code=500, detail=f"An unexpected error occurred: {exc}")


def _gradio_file_data(path: str) -> dict:
    """Build one ``gradio.FileData`` entry that references *path* on the upstream Space."""
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.get("/")
async def root():
    """Identify the service."""
    # NOTE(review): the original returned the set literal {"Virtual Try-On API"},
    # which FastAPI serialises as a one-element JSON array. This list keeps that
    # exact response; a {"message": ...} object was probably what was intended.
    return ["Virtual Try-On API"]


@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Forward an uploaded PNG/JPEG image to the upstream Space's ``/upload`` endpoint.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException: 400 if the content type is not PNG/JPEG. Upstream failures
            are mapped by ``_upstream_error``: the upstream status is passed
            through, and anything else becomes a 500.
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    # Generate a fresh id for each request. The original hard-coded
    # "abcde123456", so every concurrent upload shared one upload id upstream.
    # (The upstream presumably uses it to track upload progress — TODO confirm.)
    upload_id = secrets.token_hex(6)
    target_url = f"{BASE_URL}/upload"
    headers = {"accept": "*/*"}
    try:
        file_content = await file.read()
        files = {"files": (file.filename, file_content, file.content_type)}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                target_url,
                params={"upload_id": upload_id},
                files=files,
                headers=headers,
            )
            response.raise_for_status()
            return {"status": "success", "response": response.json()}
    except Exception as exc:
        raise _upstream_error(exc) from exc


@app.post("/join")
async def join(request: TryOnRequest):
    """Enqueue a virtual try-on job on the upstream Gradio queue.

    Returns the upstream ``/queue/join`` JSON response unchanged. Progress and
    results are read afterwards from ``/sse`` using the same ``session_hash``.

    Raises:
        HTTPException: upstream failures, mapped by ``_upstream_error``.
    """
    url = f"{BASE_URL}/queue/join"
    headers = {
        "accept": "*/*",
        "content-type": "application/json",
    }
    payload = {
        "data": [
            _gradio_file_data(request.person_image_path),
            _gradio_file_data(request.garment_image_path),
            # Extra positional inputs of the upstream function; their meaning is
            # not visible here (presumably a seed and a randomize flag — TODO confirm).
            0,
            True,
        ],
        "event_data": None,
        # Hard-coded upstream function/trigger ids. These must match the
        # deployed Space.
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        raise _upstream_error(exc) from exc


async def event_generator(session_hash: str):
    """Relay ``data:`` payloads from the upstream queue's SSE stream for *session_hash*.

    Each upstream line starting with ``data:`` is yielded with that prefix
    stripped, because ``EventSourceResponse`` adds the ``data:`` framing itself.
    Other lines are dropped. If the connection fails or the upstream returns a
    non-2xx status, one JSON payload ``{"error": "..."}`` is yielded and the
    stream ends.
    """
    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    try:
        # timeout=None: the event stream is long-lived while the job runs.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "GET",
                sse_url,
                params={"session_hash": session_hash},  # URL-encoded by httpx
                headers=headers,
            ) as response:
                # Report an upstream error status as an error event instead of
                # silently producing an empty stream.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
    except Exception as exc:
        # Yield a bare payload. The original yielded "data: {...}" with single
        # quotes, which was double-prefixed by EventSourceResponse and was not
        # valid JSON.
        yield json.dumps({"error": f"SSE stream error: {exc}"})


@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Proxy the upstream Server-Sent Events stream for *session_hash*."""
    # `request` is unused; it is kept so the endpoint signature stays unchanged.
    return EventSourceResponse(event_generator(session_hash))


@app.get("/download-image")
async def download_image(path: str):
    """Fetch ``/file=<path>`` from the upstream Space and return it inline as ``image.webp``.

    The response media type is always ``image/webp``, regardless of the
    upstream Content-Type.

    Raises:
        HTTPException: upstream failures, mapped by ``_upstream_error``.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{BASE_URL}/file={path}")
            response.raise_for_status()
    except Exception as exc:
        raise _upstream_error(exc) from exc

    # The body is already fully buffered in memory. A plain Response therefore
    # replaces the BytesIO-backed StreamingResponse, and it also sets
    # Content-Length.
    return Response(
        content=response.content,
        media_type="image/webp",
        headers={"Content-Disposition": "inline; filename=image.webp"},
    )