|
from fastapi import FastAPI, HTTPException, UploadFile, File |
|
from pydantic import BaseModel |
|
import requests |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import os |
|
import uuid |
|
from pathlib import Path |
|
import logging |
|
|
|
|
|
# Configure the root logger at INFO level (a no-op if the root logger already
# has handlers, per logging.basicConfig semantics).
logging.basicConfig(level="INFO")

# Module-scoped logger used by the helpers and endpoints in this file.
logger = logging.getLogger(__name__)
|
|
|
app = FastAPI()

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): the Fetch spec forbids a wildcard origin together with
# credentials; recent Starlette versions work around this by echoing the
# request's Origin header, so any website can make credentialed requests.
# No cookie/auth handling is visible in this file -- consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
class TryOnRequest(BaseModel):
    """Pydantic model mirroring the non-image fields of a try-on request.

    NOTE(review): no endpoint in this file references this model; ``/try-on``
    takes ``garment_desc`` and ``category`` as plain parameters instead.
    Confirm whether it is still needed.
    """

    # Same key name ("garmentDesc") that try_on sends to the upstream API.
    garmentDesc: str
    # Garment category; try_on defaults its equivalent parameter to "upper_body".
    category: str
|
|
|
|
|
# Directory that receives uploaded images. It is created eagerly at import time
# so request handlers can write into it without checking first.
UPLOAD_DIR = Path("/tmp") / "gradio"
UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
def _safe_extension(filename: "str | None") -> str:
    """Return a sanitized ``.ext`` suffix for *filename*, or ``""`` if it has none.

    Only ASCII letters and digits are kept, so a client-supplied name cannot
    smuggle path separators or other odd characters into the stored filename.
    A missing (``None``/empty) filename yields no extension.
    """
    suffix = Path(filename or "").suffix.lstrip(".")
    cleaned = "".join(ch for ch in suffix if ch.isascii() and ch.isalnum())
    return f".{cleaned}" if cleaned else ""


async def save_file_and_get_url(file: UploadFile) -> str:
    """Persist an uploaded file under ``UPLOAD_DIR`` and return a public URL for it.

    The file is stored as ``<uuid4><ext>``, where ``<ext>`` is a sanitized
    version of the client's extension, so stored names never collide and never
    depend on arbitrary client input. The returned URL has the form
    ``https://tejani-tryapi.hf.space/file=<absolute path>``. A best-effort HEAD
    request checks that it is reachable; failures there are only logged.

    Args:
        file: The uploaded file. Its full content is read into memory.

    Returns:
        The public URL for the stored file.

    Raises:
        HTTPException: 500 if the upload cannot be read or written to disk.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        unique_filename = f"{uuid.uuid4()}{_safe_extension(file.filename)}"
        file_path = UPLOAD_DIR / unique_filename

        logger.info(f"Saving file to {file_path}")
        content = await file.read()
        # Disk I/O is blocking, so run it in a worker thread instead of on the
        # event loop. write_bytes raises on failure, so no separate existence
        # check is needed afterwards.
        await run_in_threadpool(file_path.write_bytes, content)
    except Exception as e:
        logger.exception(f"Error in save_file_and_get_url: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e

    # NOTE(review): ``/file=`` is Gradio's file-serving route pattern. No such
    # route is defined in this FastAPI app, so confirm something at this host
    # actually serves files from UPLOAD_DIR.
    public_url = f"https://tejani-tryapi.hf.space/file={str(file_path)}"
    logger.info(f"Generated public URL: {public_url}")

    # Best-effort reachability check. ``requests`` is synchronous, so it runs
    # in a worker thread. Calling it directly would block the event loop for up
    # to the timeout, and if this URL is served by this same process, the loop
    # could not even answer the HEAD until it timed out.
    try:
        response = await run_in_threadpool(requests.head, public_url, timeout=5)
        if response.status_code != 200:
            logger.warning(f"Public URL {public_url} returned status {response.status_code}")
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to access public URL {public_url}: {str(e)}")

    return public_url
|
|
|
|
|
# (connect, read) timeout for the upstream try-on call. The read timeout is
# generous on the assumption that generation can take a while -- tune as needed.
_UPSTREAM_TIMEOUT = (10, 300)


def _response_payload(response: requests.Response):
    """Decode an upstream response body: parsed JSON for JSON responses, else raw text.

    Only the media type is compared, ignoring parameters, so a header such as
    ``application/json; charset=utf-8`` is still treated as JSON.
    """
    content_type = response.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type == "application/json":
        return response.json()
    return response.text


@app.post("/try-on")
async def try_on(
    human_img: UploadFile = File(...),
    garment: UploadFile = File(...),
    garment_desc: str = "",
    category: str = "upper_body",
):
    """Store both uploads, then forward their public URLs to the upstream try-on API.

    NOTE(review): the images are multipart ``File`` parameters, so FastAPI
    reads ``garment_desc`` and ``category`` from the query string, not from the
    form body. Switching them to ``Form(...)`` would change the public API, so
    they are left as-is.

    Returns:
        A dict with the upstream ``status_code``, the decoded upstream
        ``response`` (JSON object or text), and the two stored-file URLs.

    Raises:
        HTTPException: 500 if saving an upload fails, if the upstream call
            fails or returns an error status, or on any other error.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        human_img_url = await save_file_and_get_url(human_img)
        garment_url = await save_file_and_get_url(garment)

        url = "https://changeclothesai.online/api/try-on/edge"

        headers = {
            "accept": "*/*",
            # NOTE(review): this looks like a credential hard-coded in source.
            # It can be overridden via the TRYON_API_KEY environment variable;
            # rotate it and drop the literal fallback once deployments set it.
            "f": os.environ.get("TRYON_API_KEY", "sdfdsfsKaVgUoxa5j1jzcFtziPx"),
        }

        data = {
            "humanImg": human_img_url,
            "garment": garment_url,
            "garmentDesc": garment_desc,
            "category": category,
        }

        logger.info(f"Forwarding request to {url} with data: {data}")

        # ``requests`` is synchronous: run it in a worker thread so the event
        # loop keeps serving other requests, and bound the wait so a stalled
        # upstream cannot hang this request forever.
        response = await run_in_threadpool(
            requests.post, url, headers=headers, data=data, timeout=_UPSTREAM_TIMEOUT
        )
        response.raise_for_status()

        return {
            "status_code": response.status_code,
            "response": _response_payload(response),
            "human_img_url": human_img_url,
            "garment_url": garment_url,
        }
    except HTTPException:
        # Already carries a status and message (e.g. from save_file_and_get_url).
        # Re-wrapping it would yield details like "Error processing request: 500: ...".
        raise
    except requests.exceptions.RequestException as e:
        logger.exception(f"Error forwarding request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error forwarding request: {str(e)}") from e
    except Exception as e:
        logger.exception(f"Error in try_on endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
|
|
|
|
|
@app.get("/")
async def root():
    """Report that the proxy service is up."""
    status_message = "FastAPI proxy for try-on API with file upload is running"
    return {"message": status_message}
|
|
|
|
|
@app.get("/list-files")
async def list_files():
    """Return the paths of all regular files directly inside ``UPLOAD_DIR``.

    NOTE(review): this endpoint is unauthenticated. Combined with the public
    ``/file=`` URLs built by save_file_and_get_url, it would let anyone
    enumerate every uploaded image (assuming those URLs are served). Consider
    removing or protecting it.

    Raises:
        HTTPException: 500 if the directory cannot be listed.
    """
    try:
        stored = list(map(str, filter(Path.is_file, UPLOAD_DIR.glob("*"))))
        logger.info(f"Files in {UPLOAD_DIR}: {stored}")
        return {"files": stored}
    except Exception as e:
        logger.error(f"Error listing files: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error listing files: {str(e)}")