Sam3838 commited on
Commit
bdb7403
·
verified ·
1 Parent(s): 3b9f744

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import logging
3
+ import uvicorn
4
+ from contextlib import asynccontextmanager
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import JSONResponse
9
+
10
+ from models import (
11
+ ImageGenerationRequest,
12
+ ImageGenerationResponse,
13
+ ErrorResponse,
14
+ ModelsResponse,
15
+ ModelInfo
16
+ )
17
+ from image_generator import image_generator
18
+ from config import config
19
+
20
+ # Setup logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @asynccontextmanager
26
+ async def lifespan(app: FastAPI):
27
+ """Application lifespan management"""
28
+ logger.info("Starting TTI Frame API...")
29
+ yield
30
+ logger.info("Shutting down TTI Frame API...")
31
+ image_generator.cleanup()
32
+
33
+
34
+ # Create FastAPI app
35
+ app = FastAPI(
36
+ title="TTI Frame - OpenAI Compatible Text-to-Image API",
37
+ description="A FastAPI wrapper providing OpenAI-compatible endpoints for text-to-image generation",
38
+ version="1.0.0",
39
+ lifespan=lifespan
40
+ )
41
+
42
+ # Add CORS middleware
43
+ app.add_middleware(
44
+ CORSMiddleware,
45
+ allow_origins=["*"], # Configure as needed
46
+ allow_credentials=True,
47
+ allow_methods=["*"],
48
+ allow_headers=["*"],
49
+ )
50
+
51
+ # Mount static files for serving images
52
+ app.mount("/images", StaticFiles(directory=config.OUTPUT_DIR), name="images")
53
+
54
+
55
+ @app.get("/")
56
+ async def root():
57
+ """Root endpoint"""
58
+ return {
59
+ "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
60
+ "version": "1.0.0",
61
+ "docs": "/docs"
62
+ }
63
+
64
+
65
+ @app.get("/v1/models", response_model=ModelsResponse)
66
+ async def list_models():
67
+ """List available models (OpenAI compatible)"""
68
+ models = [
69
+ ModelInfo(
70
+ id="dall-e-3",
71
+ created=1677649963,
72
+ owned_by="tti-frame"
73
+ ),
74
+ ModelInfo(
75
+ id="dall-e-2",
76
+ created=1677649963,
77
+ owned_by="tti-frame"
78
+ )
79
+ ]
80
+
81
+ return ModelsResponse(data=models)
82
+
83
+
84
+ @app.post("/v1/images/generations", response_model=ImageGenerationResponse)
85
+ async def create_image(request: ImageGenerationRequest):
86
+ """
87
+ Generate images from text prompts (OpenAI compatible)
88
+
89
+ Creates images based on a text prompt using advanced diffusion models.
90
+ Supports various sizes, qualities, and response formats.
91
+ """
92
+ try:
93
+ logger.info(f"Received image generation request: {request.prompt[:50]}...")
94
+
95
+ # Validate request
96
+ if not request.prompt or not request.prompt.strip():
97
+ raise HTTPException(
98
+ status_code=400,
99
+ detail="Prompt cannot be empty"
100
+ )
101
+
102
+ if len(request.prompt) > 4000:
103
+ raise HTTPException(
104
+ status_code=400,
105
+ detail="Prompt too long. Maximum 4000 characters allowed."
106
+ )
107
+
108
+ # Generate images
109
+ image_data = await image_generator.generate_images(request)
110
+
111
+ response = ImageGenerationResponse(
112
+ created=int(time.time()),
113
+ data=image_data
114
+ )
115
+
116
+ logger.info(f"Successfully generated {len(image_data)} images")
117
+ return response
118
+
119
+ except HTTPException:
120
+ raise
121
+ except Exception as e:
122
+ logger.error(f"Image generation failed: {e}")
123
+ raise HTTPException(
124
+ status_code=500,
125
+ detail=f"Image generation failed: {str(e)}"
126
+ )
127
+
128
+
129
+ @app.get("/health")
130
+ async def health_check():
131
+ """Health check endpoint"""
132
+ return {"status": "healthy", "timestamp": int(time.time())}
133
+
134
+
135
+ @app.exception_handler(Exception)
136
+ async def global_exception_handler(request, exc):
137
+ """Global exception handler"""
138
+ logger.error(f"Unhandled exception: {exc}")
139
+ return JSONResponse(
140
+ status_code=500,
141
+ content=ErrorResponse(
142
+ error={
143
+ "message": "Internal server error",
144
+ "type": "server_error",
145
+ "code": "internal_error"
146
+ }
147
+ ).dict()
148
+ )
149
+
150
+
151
+ if __name__ == "__main__":
152
+ uvicorn.run(
153
+ "main:app",
154
+ host=config.HOST,
155
+ port=config.PORT,
156
+ reload=True,
157
+ log_level="info"
158
+ )