"""TTI Frame: an OpenAI-compatible text-to-image API served with FastAPI.

The module bootstraps its own third-party dependencies with pip, defines the
request/response schemas, and exposes the HTTP endpoints.  Image generation
itself is delegated to ``image_generator.ImageGenerator``.
"""

import importlib.util
import logging
import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional

# Maps each pip requirement to the top-level module it provides.  The two
# names differ for some distributions (Pillow installs as ``PIL``), so the
# requirement string cannot be used to probe for the module.
_REQUIRED_PACKAGES = {
    "fastapi": "fastapi",
    "uvicorn[standard]": "uvicorn",
    "pillow": "PIL",
    "huggingface_hub": "huggingface_hub",
    "pydantic": "pydantic",
}


def install_packages():
    """Install any missing required packages using pip.

    Each requirement is probed via its importable module name; only the
    missing ones are installed, with the interpreter running this script.

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` invocation fails.
    """
    installed_any = False
    for package, module_name in _REQUIRED_PACKAGES.items():
        if importlib.util.find_spec(module_name) is not None:
            print(f"{package} already installed")
            continue
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        installed_any = True

    if installed_any:
        # Let the import system notice modules that appeared on disk after
        # the interpreter started.
        importlib.invalidate_caches()


# Install packages before importing anything third-party (pydantic included,
# otherwise the bootstrap could never install it).
install_packages()

import uvicorn  # noqa: E402
from fastapi import FastAPI, HTTPException  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import JSONResponse  # noqa: E402
from fastapi.staticfiles import StaticFiles  # noqa: E402
from pydantic import BaseModel  # noqa: E402


# Define models directly in the file
class ResponseFormat(str, Enum):
    """Ways a generated image can be returned to the client."""

    URL = "url"
    B64_JSON = "b64_json"


class ImageGenerationRequest(BaseModel):
    """Request body for ``POST /v1/images/generations``."""

    prompt: str
    model: str = "dall-e-3"
    n: int = 1
    size: str = "1024x1024"
    quality: str = "standard"
    response_format: ResponseFormat = ResponseFormat.URL


class ImageData(BaseModel):
    """One generated image; every field is optional and defaults to None."""

    # Optional[...] is required for pydantic v2 to accept an explicit None.
    url: Optional[str] = None
    b64_json: Optional[str] = None
    revised_prompt: Optional[str] = None


class ImageGenerationResponse(BaseModel):
    """Response body for ``POST /v1/images/generations``."""

    created: int
    data: List[ImageData]


class ErrorResponse(BaseModel):
    """Envelope used for error payloads."""

    error: dict


class ModelInfo(BaseModel):
    """Description of a single model entry."""

    id: str
    created: int
    owned_by: str


class ModelsResponse(BaseModel):
    """Response body for ``GET /v1/models``."""

    data: List[ModelInfo]


# Import the modified image generator
from image_generator import ImageGenerator  # noqa: E402

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global image generator instance; populated by ``lifespan`` at startup.
image_generator: Optional[ImageGenerator] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the image generator on startup and clean it up on shutdown.

    Reads ``HF_TOKEN`` and ``BASE_URL`` (default ``http://localhost:8000``)
    from the environment, and serves the generator's output directory under
    ``/images``.
    """
    global image_generator

    logger.info("Starting TTI Frame API...")

    # Initialize image generator
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.warning(
            "HF_TOKEN environment variable not set. Image generation may fail."
        )

    image_generator = ImageGenerator(hf_token=hf_token)

    # Set base URL for serving images
    base_url = os.getenv("BASE_URL", "http://localhost:8000")
    image_generator.set_config(base_url=base_url)

    # Mount the generator's output directory for static files
    app.mount(
        "/images",
        StaticFiles(directory=image_generator.output_dir),
        name="images",
    )

    logger.info(
        "Image generator initialized with output directory: %s",
        image_generator.output_dir,
    )

    yield

    logger.info("Shutting down TTI Frame API...")
    if image_generator:
        image_generator.cleanup()


# Create FastAPI app
app = FastAPI(
    title="TTI Frame - OpenAI Compatible Text-to-Image API",
    description=(
        "A FastAPI wrapper providing OpenAI-compatible endpoints "
        "for text-to-image generation"
    ),
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site make credentialed requests; restrict allow_origins before exposing
# this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure as needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """Return basic service information."""
    return {
        "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
        "version": "1.0.0",
        "docs": "/docs",
        "output_dir": (
            image_generator.output_dir if image_generator else "Not initialized"
        ),
    }


@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List available models (OpenAI compatible)."""
    models = [
        ModelInfo(id="dall-e-3", created=1677649963, owned_by="tti-frame"),
        ModelInfo(id="dall-e-2", created=1677649963, owned_by="tti-frame"),
        ModelInfo(
            id="black-forest-labs/flux-schnell",
            created=1677649963,
            owned_by="tti-frame",
        ),
    ]
    return ModelsResponse(data=models)


@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_image(request: ImageGenerationRequest):
    """Generate images from a text prompt (OpenAI compatible).

    The prompt must be non-blank and at most 4000 characters.  The OpenAI
    model names ``dall-e-3`` and ``dall-e-2`` are rewritten to
    ``black-forest-labs/flux-schnell`` before the request is handed to the
    image generator.

    Raises:
        HTTPException: 400 for an empty or over-long prompt; 500 when the
            generator is not initialized or generation fails.
    """
    if not image_generator:
        raise HTTPException(
            status_code=500,
            detail=(
                "Image generator not initialized. "
                "Check HF_TOKEN environment variable."
            ),
        )

    try:
        logger.info(
            "Received image generation request: %s...", request.prompt[:50]
        )

        # Validate request
        if not request.prompt or not request.prompt.strip():
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")

        if len(request.prompt) > 4000:
            raise HTTPException(
                status_code=400,
                detail="Prompt too long. Maximum 4000 characters allowed.",
            )

        # Map OpenAI model names to HuggingFace models
        model_mapping = {
            "dall-e-3": "black-forest-labs/flux-schnell",
            "dall-e-2": "black-forest-labs/flux-schnell",
        }

        # Update request model if needed
        if request.model in model_mapping:
            request.model = model_mapping[request.model]

        # Generate images
        image_data = await image_generator.generate_images(request)

        response = ImageGenerationResponse(
            created=int(time.time()),
            data=image_data,
        )

        logger.info("Successfully generated %d images", len(image_data))
        return response

    except HTTPException:
        # Validation errors raised above already carry the right status.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Image generation failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Image generation failed: {str(e)}",
        ) from e


@app.get("/health")
async def health_check():
    """Report liveness and whether the image generator is initialized."""
    return {
        "status": "healthy",
        "timestamp": int(time.time()),
        "generator_initialized": image_generator is not None,
        "output_dir": image_generator.output_dir if image_generator else None,
    }


@app.get("/config")
async def get_config():
    """Return the current configuration (the token itself is not exposed)."""
    if not image_generator:
        return {"error": "Image generator not initialized"}

    return {
        "output_dir": image_generator.output_dir,
        "base_url": image_generator.base_url,
        "default_model": image_generator.default_model,
        "hf_token_set": bool(image_generator.hf_token),
    }


@app.post("/config")
async def update_config(
    hf_token: Optional[str] = None,
    base_url: Optional[str] = None,
    default_model: Optional[str] = None,
):
    """Update the image generator configuration.

    All three values are forwarded to ``image_generator.set_config``,
    including those left as ``None``.
    NOTE(review): presumably ``set_config`` ignores ``None`` values; confirm
    against ``image_generator``.

    Raises:
        HTTPException: 500 when the generator is not initialized.
    """
    if not image_generator:
        raise HTTPException(
            status_code=500, detail="Image generator not initialized"
        )

    image_generator.set_config(
        hf_token=hf_token,
        base_url=base_url,
        default_model=default_model,
    )

    return {"message": "Configuration updated successfully"}


@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Log an unhandled exception and return a generic 500 error payload."""
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)

    payload = ErrorResponse(
        error={
            "message": "Internal server error",
            "type": "server_error",
            "code": "internal_error",
        }
    )
    # pydantic v2 renamed .dict() to .model_dump(); support both versions.
    if hasattr(payload, "model_dump"):
        content = payload.model_dump()
    else:
        content = payload.dict()

    return JSONResponse(status_code=500, content=content)


if __name__ == "__main__":
    # Warn early when the token is missing; startup still proceeds.
    if not os.getenv("HF_TOKEN"):
        print("Warning: HF_TOKEN environment variable not set.")
        print("Please set it with: export HF_TOKEN=your_huggingface_token")

    # "main:app" is resolved by module name when uvicorn loads the app.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )