|
import logging
import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel
|
|
|
|
|
# pip requirement -> top-level import name used to detect whether it is
# already installed. The two differ for some distributions: Pillow is imported
# as ``PIL`` and the ``uvicorn[standard]`` extra is imported as ``uvicorn``.
REQUIRED_PACKAGES = {
    "fastapi": "fastapi",
    "uvicorn[standard]": "uvicorn",
    "pillow": "PIL",
    "huggingface_hub": "huggingface_hub",
    "pydantic": "pydantic",
}


def install_packages(packages=None):
    """Install required packages using pip

    A requirement is installed with ``python -m pip install`` only when its
    import name cannot be imported by the current interpreter.

    Args:
        packages: optional mapping of pip requirement -> import name to check.
            Defaults to ``REQUIRED_PACKAGES`` (the dependencies this app needs).

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` invocation fails.
    """
    import importlib

    if packages is None:
        packages = REQUIRED_PACKAGES

    for package, import_name in packages.items():
        try:
            importlib.import_module(import_name)
            print(f"{package} already installed")
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            # Packages installed after interpreter start may be invisible to the
            # import system's cached directory listings until this is called.
            importlib.invalidate_caches()
|
|
|
|
|
# Bootstrap third-party dependencies at import time. This must run before the
# uvicorn/fastapi imports that follow, since they may not be installed yet.
# NOTE(review): ``pydantic`` is already imported at the top of the file, before
# this call runs, so a missing pydantic fails before it can be auto-installed.
install_packages()
|
|
|
import uvicorn |
|
from fastapi import FastAPI, HTTPException |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
class ResponseFormat(str, Enum):
    """Delivery format for generated images.

    Mirrors the values of OpenAI's ``response_format`` request field. Members
    subclass ``str``, so they compare equal to their raw string values.
    """

    URL = "url"  # return a link to the image
    B64_JSON = "b64_json"  # return the image inline as base64 JSON
|
|
|
class ImageGenerationRequest(BaseModel):
    """Request body for ``POST /v1/images/generations`` (OpenAI-style fields)."""

    # The only required field: the text description to render.
    prompt: str
    # OpenAI alias ("dall-e-3"/"dall-e-2") or a backend model id; the endpoint
    # rewrites the aliases before calling the generator.
    model: str = "dall-e-3"
    # Number of images requested.
    n: int = 1
    # Presumably "WIDTHxHEIGHT"; not validated by this model.
    size: str = "1024x1024"
    quality: str = "standard"
    response_format: ResponseFormat = ResponseFormat.URL
|
|
|
class ImageData(BaseModel):
    """One generated image entry in an OpenAI-style generation response.

    Every field is optional; presumably the generator fills ``url`` or
    ``b64_json`` according to the request's ``response_format`` (confirm in
    ImageGenerator).
    """

    # Explicit Optional: pydantic v2 no longer treats a ``None`` default as
    # implying Optional, so ``ImageData(url=None)`` would fail validation with
    # a bare ``str`` annotation.
    url: Optional[str] = None
    b64_json: Optional[str] = None
    revised_prompt: Optional[str] = None
|
|
|
class ImageGenerationResponse(BaseModel):
    """Response body for image generation, shaped like OpenAI's images API."""

    # Unix timestamp (seconds); the endpoint fills it from time.time().
    created: int
    # One entry per generated image.
    data: List[ImageData]
|
|
|
class ErrorResponse(BaseModel):
    """Envelope for error payloads: everything lives under a single ``error`` key."""

    # Free-form details, e.g. {"message": ..., "type": ..., "code": ...}.
    error: dict
|
|
|
class ModelInfo(BaseModel):
    """A single entry of the ``/v1/models`` listing."""

    # Model identifier clients pass in the ``model`` request field.
    id: str
    # Unix timestamp reported for the model.
    created: int
    # Owner label shown to clients.
    owned_by: str
|
|
|
class ModelsResponse(BaseModel):
    """Response body for ``GET /v1/models``: the list of advertised models."""

    data: List[ModelInfo]
|
|
|
|
|
from image_generator import ImageGenerator |
|
|
|
|
|
# Configure root logging at INFO so the service's startup and request
# messages are emitted, then grab this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# Shared ImageGenerator instance. ``lifespan`` sets it at startup and resets it
# at shutdown; the request handlers treat a falsy value as "not initialized".
image_generator = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management

    On startup: build the module-level ``image_generator`` from the
    ``HF_TOKEN`` and ``BASE_URL`` environment variables, then mount its output
    directory at ``/images`` as static files. On shutdown: call its
    ``cleanup()`` and reset the global to ``None``.
    """
    global image_generator

    logger.info("Starting TTI Frame API...")

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN environment variable not set. Image generation may fail.")

    image_generator = ImageGenerator(hf_token=hf_token)

    base_url = os.getenv("BASE_URL", "http://localhost:8000")
    image_generator.set_config(base_url=base_url)

    # The output directory is only known once the generator exists, so the
    # static mount happens here rather than at module import.
    # NOTE(review): assumes ImageGenerator creates output_dir; StaticFiles
    # rejects a missing directory by default — confirm.
    app.mount("/images", StaticFiles(directory=image_generator.output_dir), name="images")

    logger.info(f"Image generator initialized with output directory: {image_generator.output_dir}")

    try:
        yield
    finally:
        # ``finally`` makes cleanup run even if the lifespan context is
        # cancelled or an exception is thrown in at the yield point.
        logger.info("Shutting down TTI Frame API...")
        generator, image_generator = image_generator, None
        if generator:
            generator.cleanup()
|
|
|
|
|
# The ASGI application. ``lifespan`` handles generator setup and teardown;
# title, description, and version feed the generated OpenAPI docs.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="TTI Frame - OpenAI Compatible Text-to-Image API",
    description="A FastAPI wrapper providing OpenAI-compatible endpoints for text-to-image generation",
)
|
|
|
|
|
def _cors_origins():
    """Return allowed CORS origins from ``CORS_ALLOW_ORIGINS`` (comma-separated).

    Falls back to ``["*"]`` (any origin) when the variable is unset or empty.
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "*")
    origins = [origin.strip() for origin in raw.split(",") if origin.strip()]
    return origins or ["*"]


_cors_allow_origins = _cors_origins()

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    # A wildcard origin must not be combined with credentials (see the CORS
    # spec and FastAPI's CORS docs). Allow credentials only when an explicit
    # origin list has been configured.
    allow_credentials="*" not in _cors_allow_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
@app.get("/")
async def root():
    """Root endpoint"""
    # Report a placeholder until lifespan startup has created the generator.
    output_dir = "Not initialized"
    if image_generator:
        output_dir = image_generator.output_dir

    return {
        "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
        "version": "1.0.0",
        "docs": "/docs",
        "output_dir": output_dir,
    }
|
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List available models (OpenAI compatible)"""
    # The OpenAI aliases are served by the FLUX backend (see create_image);
    # every entry shares the same fixed timestamp and owner.
    advertised_ids = (
        "dall-e-3",
        "dall-e-2",
        "black-forest-labs/flux-schnell",
    )
    entries = [
        ModelInfo(id=model_id, created=1677649963, owned_by="tti-frame")
        for model_id in advertised_ids
    ]
    return ModelsResponse(data=entries)
|
|
|
# Upper bound on images per request, matching the OpenAI images API range for
# ``n`` (1-10). Without it a single request could ask for unbounded work.
_MAX_IMAGES_PER_REQUEST = 10

# OpenAI model aliases -> backend model actually used for generation.
_MODEL_ALIASES = {
    "dall-e-3": "black-forest-labs/flux-schnell",
    "dall-e-2": "black-forest-labs/flux-schnell",
}


@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_image(request: ImageGenerationRequest):
    """
    Generate images from text prompts (OpenAI compatible)

    Creates images based on a text prompt using advanced diffusion models.
    Supports various sizes, qualities, and response formats.
    """
    # Error contract: 500 if the generator was never initialized; 400 for an
    # empty prompt, a prompt over 4000 chars, or ``n`` outside
    # [1, _MAX_IMAGES_PER_REQUEST]; 500 with the error text if generation fails.
    if not image_generator:
        raise HTTPException(
            status_code=500,
            detail="Image generator not initialized. Check HF_TOKEN environment variable.",
        )

    try:
        logger.info(f"Received image generation request: {request.prompt[:50]}...")

        if not request.prompt or not request.prompt.strip():
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")

        if len(request.prompt) > 4000:
            raise HTTPException(
                status_code=400,
                detail="Prompt too long. Maximum 4000 characters allowed.",
            )

        if not 1 <= request.n <= _MAX_IMAGES_PER_REQUEST:
            raise HTTPException(
                status_code=400,
                detail=f"n must be between 1 and {_MAX_IMAGES_PER_REQUEST}.",
            )

        # Translate OpenAI aliases; other model ids are passed through as-is.
        if request.model in _MODEL_ALIASES:
            request.model = _MODEL_ALIASES[request.model]

        image_data = await image_generator.generate_images(request)

        response = ImageGenerationResponse(created=int(time.time()), data=image_data)

        logger.info(f"Successfully generated {len(image_data)} images")
        return response

    except HTTPException:
        # Validation errors raised above keep their own status code.
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Image generation failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Image generation failed: {str(e)}",
        ) from e
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Always "healthy" while the process serves requests; the generator
    # fields tell the caller whether startup finished.
    report = {"status": "healthy", "timestamp": int(time.time())}
    report["generator_initialized"] = image_generator is not None
    report["output_dir"] = image_generator.output_dir if image_generator else None
    return report
|
|
|
@app.get("/config")
async def get_config():
    """Get current configuration"""
    generator = image_generator
    if generator:
        # Only report whether a token is set; never echo the token itself.
        return {
            "output_dir": generator.output_dir,
            "base_url": generator.base_url,
            "default_model": generator.default_model,
            "hf_token_set": bool(generator.hf_token),
        }
    # Returned with HTTP 200, unlike the 500 raised by update_config.
    return {"error": "Image generator not initialized"}
|
|
|
@app.post("/config")
async def update_config(hf_token: str = None, base_url: str = None, default_model: str = None):
    """Update configuration"""
    # NOTE(review): this endpoint has no authentication, and the HF token
    # arrives as a query parameter, where it is likely to end up in access
    # logs. Consider protecting it or moving secrets into the request body.
    generator = image_generator
    if not generator:
        raise HTTPException(status_code=500, detail="Image generator not initialized")

    # Every value, including None for omitted parameters, is forwarded to
    # set_config unchanged.
    overrides = {
        "hf_token": hf_token,
        "base_url": base_url,
        "default_model": default_model,
    }
    generator.set_config(**overrides)

    return {"message": "Configuration updated successfully"}
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Turn any unhandled exception into a generic OpenAI-style 500 response.

    The exception and its traceback are logged server-side; the client gets
    only a fixed message, type, and code, so internals are not exposed.
    """
    # exc_info=exc attaches the exception's traceback to the log record;
    # logging only f"{exc}" would lose where the error came from.
    logger.error("Unhandled exception: %s", exc, exc_info=exc)

    error_body = ErrorResponse(
        error={
            "message": "Internal server error",
            "type": "server_error",
            "code": "internal_error",
        }
    )
    # pydantic v2 replaces the deprecated .dict() with .model_dump();
    # fall back to .dict() on pydantic v1.
    dump = getattr(error_body, "model_dump", None) or error_body.dict

    return JSONResponse(status_code=500, content=dump())
|
|
|
if __name__ == "__main__":
    # Warn early, but still start: lifespan also logs a warning and the
    # token can be supplied later through POST /config.
    if not os.getenv("HF_TOKEN"):
        print("Warning: HF_TOKEN environment variable not set.")
        print("Please set it with: export HF_TOKEN=your_huggingface_token")

    # With reload=True, uvicorn must be given an import string rather than
    # the app object. Derive the module name from this file instead of
    # hard-coding "main", so the script keeps working if it is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )