samu's picture
1st
7c7ef49
raw
history blame
2.93 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Literal, Optional, Dict, Any, Union
from backend.utils import async_generate_text_and_image, async_generate_with_image_input
from backend.category_config import CATEGORY_CONFIGS
from backend.logging_utils import log_category_usage, get_category_statistics
import backend.config as config # keep for reference if needed
import traceback
# Application instance. The CORS policy below is fully permissive: every
# origin, HTTP method and request header is accepted.
# NOTE(review): consider restricting allow_origins outside development.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class TextGenerateRequest(BaseModel):
    # JSON body accepted by POST /generate.
    prompt: str  # prompt text passed to async_generate_text_and_image
    category: Union[str, None] = None  # optional category key; None when omitted
class ImageTextGenerateRequest(BaseModel):
    # JSON body accepted by POST /generate_with_image.
    # Field order (text, image, category) is kept because it drives the
    # generated schema's property order.
    text: Union[str, None] = None  # the endpoint falls back to config.DEFAULT_TEXT when empty/None
    image: str  # image payload as a string (the endpoint docstring describes it as base64)
    category: Union[str, None] = None  # optional category key; None when omitted
class Part(BaseModel):
    # One element of a generation result: a payload tagged as text or image.
    type: Literal["text", "image"]  # tells the client how to interpret `data`
    # `data` is either a plain string or a str->str mapping.
    # NOTE(review): the original comment said str is used for images and dict
    # for text; that mapping is not visible in this file — confirm against the
    # parts yielded by backend.utils before relying on it.
    data: Union[str, Dict[str, str]]
class GenerationResponse(BaseModel):
    # Response body shared by /generate and /generate_with_image: the parts
    # collected from the generator, in the order they were yielded.
    results: List[Part]
@app.post("/generate", response_model=GenerationResponse)
async def generate(request: TextGenerateRequest):
    """
    Generate text and image from a text prompt with optional category.

    Every part yielded by ``async_generate_text_and_image`` is collected into
    a ``GenerationResponse``. The call is recorded with ``log_category_usage``
    whether or not it succeeds.

    Raises:
        HTTPException: an ``HTTPException`` raised during generation is
            propagated unchanged; any other exception becomes a 500.
    """
    success = False
    try:
        results = [
            part
            async for part in async_generate_text_and_image(
                request.prompt, request.category
            )
        ]
        response = GenerationResponse(results=results)
        # Mark success only after the response model validated, so a bad
        # part is not logged as a successful call.
        success = True
        return response
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        log_category_usage(request.category, "/generate", success)
@app.post("/generate_with_image", response_model=GenerationResponse)
async def generate_with_image(request: ImageTextGenerateRequest):
    """
    Generate text and image given a text and base64 image with optional category.

    If ``request.text`` is empty or missing, ``config.DEFAULT_TEXT`` is used
    instead. Every part yielded by ``async_generate_with_image_input`` is
    collected into a ``GenerationResponse``. The call is recorded with
    ``log_category_usage`` whether or not it succeeds.

    Raises:
        HTTPException: an ``HTTPException`` raised during generation is
            propagated unchanged; any other exception becomes a 500.
    """
    success = False
    try:
        text = request.text or config.DEFAULT_TEXT
        results = [
            part
            async for part in async_generate_with_image_input(
                text, request.image, request.category
            )
        ]
        response = GenerationResponse(results=results)
        # Mark success only after the response model validated, so a bad
        # part is not logged as a successful call.
        success = True
        return response
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of masking them as a 500.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {e}") from e
    finally:
        log_category_usage(request.category, "/generate_with_image", success)
@app.get("/categories")
async def get_categories():
    """
    Return the category configurations exposed by the API.

    The payload is ``CATEGORY_CONFIGS`` from ``backend.category_config``,
    returned unchanged and serialized by FastAPI; its structure is defined
    in that module.
    """
    configs = CATEGORY_CONFIGS
    return configs
@app.get("/category-stats")
async def get_usage_statistics():
    """
    Return usage statistics for all categories.

    Delegates to ``get_category_statistics`` from ``backend.logging_utils``;
    the shape of the returned data is defined there.
    """
    stats = get_category_statistics()
    return stats
@app.get("/")
async def read_root():
    """Return a fixed status message confirming the API is up."""
    payload = {"message": "Image generation API is up"}
    return payload