# Hook-old / app.py
# Author: tejani
# Last commit: "Update app.py" (45aa957, verified)
import asyncio
import json
import os
import secrets

import httpx
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
app = FastAPI(
    title="Virtual Try-On API",
    description="API to forward images and handle virtual try-on requests",
)

# Configuration
# Content types accepted by POST /upload (checked against UploadFile.content_type).
ALLOWED_IMAGE_TYPES = {"image/png", "image/jpeg"}
# Root URL of the upstream Hugging Face Space that every endpoint proxies to.
BASE_URL = "https://kwai-kolors-kolors-virtual-try-on.hf.space"
class TryOnRequest(BaseModel):
    """Request body for ``POST /join``.

    The two image paths are sent upstream as Gradio ``FileData`` entries.
    NOTE(review): they are presumably the paths returned by the upstream
    ``/upload`` endpoint (see ``POST /upload``) — confirm with callers.
    """

    # Upstream path of the person image.
    person_image_path: str
    # Upstream path of the garment image.
    garment_image_path: str
    # Session identifier forwarded to the upstream queue; GET /sse takes the
    # same value to stream this job's events.
    session_hash: str
@app.get("/")
async def root():
    """Return a short banner identifying this service.

    The response body is the JSON array ``["Virtual Try-On API"]``.
    """
    # An explicit list rather than a set literal: the JSON body is the same,
    # but it no longer depends on FastAPI's implicit set -> list encoding.
    # NOTE(review): a dict such as {"message": ...} may have been intended;
    # the array shape is kept so existing clients see no change.
    return ["Virtual Try-On API"]
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    """Upload an image file to the external API.

    Forwards the client's PNG/JPEG file, as multipart field ``files``, to the
    upstream Space's ``/upload`` endpoint.

    Args:
        file: Multipart file from the client. Its declared ``content_type``
            must be in ``ALLOWED_IMAGE_TYPES``.

    Returns:
        ``{"status": "success", "response": <upstream JSON body>}``.

    Raises:
        HTTPException: 400 for a disallowed content type. The upstream status
            code when upstream answers with an error status. 500 for transport
            errors or any other unexpected failure (e.g. a non-JSON reply).
    """
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only PNG or JPEG allowed.")

    target_url = f"{BASE_URL}/upload"
    # A fresh id per request. Upstream (Gradio) keys upload-progress tracking
    # by upload_id, and one hard-coded id would be shared by every concurrent
    # upload passing through this proxy.
    params = {"upload_id": secrets.token_hex(6)}
    headers = {"accept": "*/*"}
    try:
        file_content = await file.read()
        files = {"files": (file.filename, file_content, file.content_type)}
        async with httpx.AsyncClient() as client:
            response = await client.post(target_url, params=params, files=files, headers=headers)
            response.raise_for_status()
            return {"status": "success", "response": response.json()}
    except httpx.HTTPStatusError as e:
        # Relay upstream's own status code and body to the caller.
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e
def _gradio_file_data(path: str) -> dict:
    """Build the Gradio ``FileData`` entry referencing ``path`` on the upstream Space."""
    return {
        "path": path,
        "url": f"{BASE_URL}/file={path}",
        "orig_name": os.path.basename(path),
        "size": None,
        "mime_type": None,
        "is_stream": False,
        "meta": {"_type": "gradio.FileData"},
    }


@app.post("/join")
async def join(request: TryOnRequest):
    """Process virtual try-on request with person and garment images.

    Submits a job to the upstream Space's ``/queue/join`` endpoint under
    ``request.session_hash``. Progress and results are then streamed via
    ``GET /sse`` with the same session hash.

    Args:
        request: Upstream paths of the person and garment images plus the
            session hash.

    Returns:
        The upstream JSON response, unchanged.

    Raises:
        HTTPException: The upstream status code when upstream answers with an
            error status. 500 for transport errors or any other unexpected
            failure.
    """
    url = f"{BASE_URL}/queue/join"
    headers = {
        "accept": "*/*",
        "content-type": "application/json",
    }
    payload = {
        "data": [
            _gradio_file_data(request.person_image_path),
            _gradio_file_data(request.garment_image_path),
            # NOTE(review): the meaning of the two remaining inputs is not
            # visible here (possibly a seed and a "randomize seed" flag) —
            # confirm against the upstream Space's function signature.
            0,
            True,
        ],
        "event_data": None,
        # NOTE(review): fn_index / trigger_id select the upstream Gradio
        # function and are tied to the Space's current layout; they break if it changes.
        "fn_index": 2,
        "trigger_id": 26,
        "session_hash": request.session_hash,
    }
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=e.response.status_code, detail=str(e.response.text)) from e
    except httpx.HTTPError as e:
        raise HTTPException(status_code=500, detail=f"HTTP error occurred: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e
async def event_generator(session_hash: str):
    """Generate Server-Sent Events for the given session hash.

    Opens the upstream ``/queue/data`` stream for ``session_hash``. For every
    line starting with ``data:``, it yields that line's payload with the
    ``data:`` prefix and surrounding whitespace removed. All other SSE lines
    (blank separators, comments, other fields) are dropped.
    ``EventSourceResponse`` re-frames each yielded string as a ``data:`` event.

    On any failure (connection error, non-2xx upstream status, ...) it yields
    one final payload, the JSON object ``{"error": "SSE stream error: ..."}``,
    and stops.
    """
    sse_url = f"{BASE_URL}/queue/data"
    headers = {
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    try:
        async with httpx.AsyncClient(timeout=None) as client:
            # session_hash comes straight from the caller's query string;
            # params= URL-encodes it instead of splicing it into the URL raw.
            async with client.stream(
                "GET", sse_url, headers=headers, params={"session_hash": session_hash}
            ) as response:
                # Without this, an upstream error body would be read line by
                # line, match no "data:" prefix, and the stream would end silently.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield line.removeprefix("data:").strip()
                        # NOTE(review): short pause after each forwarded event;
                        # its purpose is not evident here — confirm it is needed.
                        await asyncio.sleep(0.01)
    except Exception as e:
        # Yield only the payload: EventSourceResponse adds the "data:" framing
        # itself, so a literal "data: " prefix would be sent twice. Emit real
        # JSON so clients can parse it like the other events.
        yield json.dumps({"error": f"SSE stream error: {e}"})
@app.get("/sse")
async def sse_proxy(session_hash: str, request: Request):
    """Stream the upstream try-on events for ``session_hash`` to the caller as SSE.

    ``request`` is injected by FastAPI but is not used by this handler.
    """
    upstream_events = event_generator(session_hash)
    return EventSourceResponse(upstream_events)