Sam3838 commited on
Commit
dbf9089
·
verified ·
1 Parent(s): f06b6e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -21
app.py CHANGED
@@ -1,34 +1,118 @@
1
  import time
2
  import logging
3
- import uvicorn
 
 
4
  from contextlib import asynccontextmanager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.staticfiles import StaticFiles
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
 
10
- from models import (
11
- ImageGenerationRequest,
12
- ImageGenerationResponse,
13
- ErrorResponse,
14
- ModelsResponse,
15
- ModelInfo
16
- )
17
- from image_generator import image_generator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Setup logging
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
 
23
 
24
  @asynccontextmanager
25
  async def lifespan(app: FastAPI):
26
  """Application lifespan management"""
 
 
27
  logger.info("Starting TTI Frame API...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  yield
 
29
  logger.info("Shutting down TTI Frame API...")
30
- image_generator.cleanup()
31
-
32
 
33
  # Create FastAPI app
34
  app = FastAPI(
@@ -47,20 +131,16 @@ app.add_middleware(
47
  allow_headers=["*"],
48
  )
49
 
50
- # Mount static files for serving images
51
- app.mount("/images", StaticFiles(directory="images"), name="images")
52
-
53
-
54
  @app.get("/")
55
  async def root():
56
  """Root endpoint"""
57
  return {
58
  "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
59
  "version": "1.0.0",
60
- "docs": "/docs"
 
61
  }
62
 
63
-
64
  @app.get("/v1/models", response_model=ModelsResponse)
65
  async def list_models():
66
  """List available models (OpenAI compatible)"""
@@ -74,12 +154,16 @@ async def list_models():
74
  id="dall-e-2",
75
  created=1677649963,
76
  owned_by="tti-frame"
 
 
 
 
 
77
  )
78
  ]
79
 
80
  return ModelsResponse(data=models)
81
 
82
-
83
  @app.post("/v1/images/generations", response_model=ImageGenerationResponse)
84
  async def create_image(request: ImageGenerationRequest):
85
  """
@@ -88,6 +172,12 @@ async def create_image(request: ImageGenerationRequest):
88
  Creates images based on a text prompt using advanced diffusion models.
89
  Supports various sizes, qualities, and response formats.
90
  """
 
 
 
 
 
 
91
  try:
92
  logger.info(f"Received image generation request: {request.prompt[:50]}...")
93
 
@@ -104,6 +194,16 @@ async def create_image(request: ImageGenerationRequest):
104
  detail="Prompt too long. Maximum 4000 characters allowed."
105
  )
106
 
 
 
 
 
 
 
 
 
 
 
107
  # Generate images
108
  image_data = await image_generator.generate_images(request)
109
 
@@ -124,12 +224,42 @@ async def create_image(request: ImageGenerationRequest):
124
  detail=f"Image generation failed: {str(e)}"
125
  )
126
 
127
-
128
  @app.get("/health")
129
  async def health_check():
130
  """Health check endpoint"""
131
- return {"status": "healthy", "timestamp": int(time.time())}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  @app.exception_handler(Exception)
135
  async def global_exception_handler(request, exc):
@@ -146,8 +276,12 @@ async def global_exception_handler(request, exc):
146
  ).dict()
147
  )
148
 
149
-
150
  if __name__ == "__main__":
 
 
 
 
 
151
  uvicorn.run(
152
  "main:app",
153
  host="0.0.0.0",
 
1
  import time
2
  import logging
3
+ import os
4
+ import sys
5
+ import subprocess
6
  from contextlib import asynccontextmanager
7
+ from typing import List
8
+ from enum import Enum
9
+ from pydantic import BaseModel
10
+
11
+ # Install required packages
12
+ def install_packages():
13
+ """Install required packages using pip"""
14
+ packages = [
15
+ "fastapi",
16
+ "uvicorn[standard]",
17
+ "pillow",
18
+ "huggingface_hub",
19
+ "pydantic"
20
+ ]
21
+
22
+ for package in packages:
23
+ try:
24
+ # Check if package is already installed
25
+ if package == "uvicorn[standard]":
26
+ __import__("uvicorn")
27
+ elif package == "huggingface_hub":
28
+ __import__("huggingface_hub")
29
+ else:
30
+ __import__(package.replace("-", "_"))
31
+ print(f"{package} already installed")
32
+ except ImportError:
33
+ print(f"Installing {package}...")
34
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
35
+
36
+ # Install packages before importing
37
+ install_packages()
38
+
39
+ import uvicorn
40
  from fastapi import FastAPI, HTTPException
41
  from fastapi.staticfiles import StaticFiles
42
  from fastapi.middleware.cors import CORSMiddleware
43
  from fastapi.responses import JSONResponse
44
 
45
+ # Define models directly in the file
46
+ class ResponseFormat(str, Enum):
47
+ URL = "url"
48
+ B64_JSON = "b64_json"
49
+
50
+ class ImageGenerationRequest(BaseModel):
51
+ prompt: str
52
+ model: str = "dall-e-3"
53
+ n: int = 1
54
+ size: str = "1024x1024"
55
+ quality: str = "standard"
56
+ response_format: ResponseFormat = ResponseFormat.URL
57
+
58
+ class ImageData(BaseModel):
59
+ url: str = None
60
+ b64_json: str = None
61
+ revised_prompt: str = None
62
+
63
+ class ImageGenerationResponse(BaseModel):
64
+ created: int
65
+ data: List[ImageData]
66
+
67
+ class ErrorResponse(BaseModel):
68
+ error: dict
69
+
70
+ class ModelInfo(BaseModel):
71
+ id: str
72
+ created: int
73
+ owned_by: str
74
+
75
+ class ModelsResponse(BaseModel):
76
+ data: List[ModelInfo]
77
+
78
+ # Import the modified image generator
79
+ from image_generator import ImageGenerator
80
 
81
  # Setup logging
82
  logging.basicConfig(level=logging.INFO)
83
  logger = logging.getLogger(__name__)
84
 
85
+ # Global image generator instance
86
+ image_generator = None
87
 
88
  @asynccontextmanager
89
  async def lifespan(app: FastAPI):
90
  """Application lifespan management"""
91
+ global image_generator
92
+
93
  logger.info("Starting TTI Frame API...")
94
+
95
+ # Initialize image generator
96
+ hf_token = os.getenv("HF_TOKEN")
97
+ if not hf_token:
98
+ logger.warning("HF_TOKEN environment variable not set. Image generation may fail.")
99
+
100
+ image_generator = ImageGenerator(hf_token=hf_token)
101
+
102
+ # Set base URL for serving images
103
+ base_url = os.getenv("BASE_URL", "http://localhost:8000")
104
+ image_generator.set_config(base_url=base_url)
105
+
106
+ # Mount the temporary directory for static files
107
+ app.mount("/images", StaticFiles(directory=image_generator.output_dir), name="images")
108
+
109
+ logger.info(f"Image generator initialized with output directory: {image_generator.output_dir}")
110
+
111
  yield
112
+
113
  logger.info("Shutting down TTI Frame API...")
114
+ if image_generator:
115
+ image_generator.cleanup()
116
 
117
  # Create FastAPI app
118
  app = FastAPI(
 
131
  allow_headers=["*"],
132
  )
133
 
 
 
 
 
134
  @app.get("/")
135
  async def root():
136
  """Root endpoint"""
137
  return {
138
  "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
139
  "version": "1.0.0",
140
+ "docs": "/docs",
141
+ "output_dir": image_generator.output_dir if image_generator else "Not initialized"
142
  }
143
 
 
144
  @app.get("/v1/models", response_model=ModelsResponse)
145
  async def list_models():
146
  """List available models (OpenAI compatible)"""
 
154
  id="dall-e-2",
155
  created=1677649963,
156
  owned_by="tti-frame"
157
+ ),
158
+ ModelInfo(
159
+ id="black-forest-labs/flux-schnell",
160
+ created=1677649963,
161
+ owned_by="tti-frame"
162
  )
163
  ]
164
 
165
  return ModelsResponse(data=models)
166
 
 
167
  @app.post("/v1/images/generations", response_model=ImageGenerationResponse)
168
  async def create_image(request: ImageGenerationRequest):
169
  """
 
172
  Creates images based on a text prompt using advanced diffusion models.
173
  Supports various sizes, qualities, and response formats.
174
  """
175
+ if not image_generator:
176
+ raise HTTPException(
177
+ status_code=500,
178
+ detail="Image generator not initialized. Check HF_TOKEN environment variable."
179
+ )
180
+
181
  try:
182
  logger.info(f"Received image generation request: {request.prompt[:50]}...")
183
 
 
194
  detail="Prompt too long. Maximum 4000 characters allowed."
195
  )
196
 
197
+ # Map OpenAI model names to HuggingFace models
198
+ model_mapping = {
199
+ "dall-e-3": "black-forest-labs/flux-schnell",
200
+ "dall-e-2": "black-forest-labs/flux-schnell",
201
+ }
202
+
203
+ # Update request model if needed
204
+ if request.model in model_mapping:
205
+ request.model = model_mapping[request.model]
206
+
207
  # Generate images
208
  image_data = await image_generator.generate_images(request)
209
 
 
224
  detail=f"Image generation failed: {str(e)}"
225
  )
226
 
 
227
  @app.get("/health")
228
  async def health_check():
229
  """Health check endpoint"""
230
+ return {
231
+ "status": "healthy",
232
+ "timestamp": int(time.time()),
233
+ "generator_initialized": image_generator is not None,
234
+ "output_dir": image_generator.output_dir if image_generator else None
235
+ }
236
+
237
+ @app.get("/config")
238
+ async def get_config():
239
+ """Get current configuration"""
240
+ if not image_generator:
241
+ return {"error": "Image generator not initialized"}
242
+
243
+ return {
244
+ "output_dir": image_generator.output_dir,
245
+ "base_url": image_generator.base_url,
246
+ "default_model": image_generator.default_model,
247
+ "hf_token_set": bool(image_generator.hf_token)
248
+ }
249
 
250
+ @app.post("/config")
251
+ async def update_config(hf_token: str = None, base_url: str = None, default_model: str = None):
252
+ """Update configuration"""
253
+ if not image_generator:
254
+ raise HTTPException(status_code=500, detail="Image generator not initialized")
255
+
256
+ image_generator.set_config(
257
+ hf_token=hf_token,
258
+ base_url=base_url,
259
+ default_model=default_model
260
+ )
261
+
262
+ return {"message": "Configuration updated successfully"}
263
 
264
  @app.exception_handler(Exception)
265
  async def global_exception_handler(request, exc):
 
276
  ).dict()
277
  )
278
 
 
279
  if __name__ == "__main__":
280
+ # Set environment variables if not already set
281
+ if not os.getenv("HF_TOKEN"):
282
+ print("Warning: HF_TOKEN environment variable not set.")
283
+ print("Please set it with: export HF_TOKEN=your_huggingface_token")
284
+
285
  uvicorn.run(
286
  "main:app",
287
  host="0.0.0.0",