t3k45h1 committed
Commit f580797 Β· verified Β· 1 Parent(s): 262ba70

Update main.py

Files changed (1)
  1. main.py +109 -22
main.py CHANGED
@@ -9,11 +9,10 @@ import io
 from PIL import Image
 import os
 
-
-# Set custom Hugging Face cache directory
-os.environ["HF_HOME"] = "/app/cache"
-os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
-os.environ["HF_HUB_CACHE"] = "/app/cache"
+# Set cache directory to a writable location
+os.environ["HF_HOME"] = "/tmp/huggingface_cache"
+os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
+os.environ["HF_HUB_CACHE"] = "/tmp/huggingface_cache"
 
 # Initialize FastAPI app
 app = FastAPI(title="PromptAgro Image Generator API")
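Note: these environment variables only take effect if they are exported before huggingface_hub, transformers, or diffusers is first imported, because those libraries resolve their cache paths at import time. A minimal ordering sketch, assuming the diffusers import currently sits above this block in main.py:

    # Ordering sketch (assumption, not part of this commit): set the cache
    # location before any Hugging Face library import reads it.
    import os

    os.environ["HF_HOME"] = "/tmp/huggingface_cache"             # umbrella cache root
    os.environ["HF_HUB_CACHE"] = "/tmp/huggingface_cache"        # hub download cache
    os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"  # legacy alias

    from diffusers import StableDiffusionPipeline  # import only after the env is set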
@@ -27,21 +26,83 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Load Stable Diffusion LCM model (your original approach)
-print("πŸš€ Loading Stable Diffusion Model...")
-model_id = "rupeshs/LCM-runwayml-stable-diffusion-v1-5"
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
-pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
-print("βœ… Model Loaded.")
+# Global variable for the pipeline
+pipe = None
+model_loading = False
+
+def load_model_if_needed():
+    """Load model lazily when first request arrives"""
+    global pipe, model_loading
+
+    if pipe is not None:
+        return True
+
+    if model_loading:
+        return False  # Already loading, wait
+
+    model_loading = True
+    success = load_model()
+    model_loading = False
+    return success
+
+def load_model():
+    """Load the Stable Diffusion model with proper error handling"""
+    global pipe
+
+    print("πŸš€ Loading Stable Diffusion Model...")
+    model_id = "rupeshs/LCM-runwayml-stable-diffusion-v1-5"
+
+    try:
+        # Create cache directory if it doesn't exist
+        os.makedirs("/tmp/huggingface_cache", exist_ok=True)
+
+        # Use appropriate dtype based on device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+        print(f"πŸ“± Device: {device}, dtype: {torch_dtype}")
+
+        # Load the model with cache directory specified
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            cache_dir="/tmp/huggingface_cache",
+            local_files_only=False
+        )
+
+        pipe = pipe.to(device)
+
+        # Enable memory efficient attention if available
+        if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
+            try:
+                pipe.enable_xformers_memory_efficient_attention()
+                print("βœ… XFormers memory efficient attention enabled")
+            except Exception:
+                print("⚠️ XFormers not available, using default attention")
+
+        print(f"βœ… Model Loaded successfully on {device}")
+        return True
+
+    except Exception as e:
+        print(f"❌ Failed to load model: {e}")
+        pipe = None
+        return False
+
+# Don't load model on startup - do it lazily
+# model_loaded = load_model()
 
 @app.get("/")
 async def root():
-    """Health check endpoint"""
+    """Health check endpoint with enhanced status"""
     return {
         "status": "alive",
         "service": "PromptAgro Image Generator",
-        "model_loaded": True,
-        "device": "cuda" if torch.cuda.is_available() else "cpu"
+        "model_loaded": pipe is not None,
+        "model_loading": model_loading,
+        "device": "cuda" if torch.cuda.is_available() else "cpu",
+        "model_status": "loaded" if pipe is not None else ("loading" if model_loading else "not_loaded"),
+        "torch_dtype": "float16" if torch.cuda.is_available() else "float32",
+        "ready_for_requests": pipe is not None
     }
 
 @app.post("/generate/")
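Note: model_loading is a plain module-level boolean, so two overlapping first requests (for example, across worker threads) can both pass the check before either sets it. One possible hardening, sketched here with a threading.Lock; the lock and the re-check are assumptions, not part of this commit:

    import threading

    _load_lock = threading.Lock()  # hypothetical guard, not in this commit

    def load_model_if_needed():
        """Thread-safe variant of the lazy loader above."""
        global pipe
        if pipe is not None:
            return True
        # Only one caller performs the load; the others block here until it
        # finishes instead of being bounced with a 503.
        with _load_lock:
            if pipe is None:  # re-check after acquiring the lock
                return load_model()
        return pipe is not None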
@@ -50,19 +111,31 @@ async def generate_image(prompt: str = Form(...)):
     Generate product packaging image from input prompt.
     Returns image file directly (your original approach).
     """
+    # Lazy load model on first request
+    if not load_model_if_needed():
+        if model_loading:
+            raise HTTPException(status_code=503, detail="Model is loading, please wait...")
+        else:
+            raise HTTPException(status_code=503, detail="Model failed to load. Please check logs.")
+
     print(f"πŸ–ŒοΈ Generating image for prompt: {prompt}")
 
-    # Generate image (your original approach)
-    image = pipe(prompt).images[0]
+    try:
+        # Generate image (your original approach)
+        image = pipe(prompt).images[0]
 
-    # Save image to temp file (your original approach)
-    filename = f"/tmp/{uuid.uuid4().hex}.png"
-    image.save(filename)
+        # Save image to temp file (your original approach)
+        filename = f"/tmp/{uuid.uuid4().hex}.png"
+        image.save(filename)
 
-    print(f"πŸ“¦ Image saved to {filename}")
+        print(f"πŸ“¦ Image saved to {filename}")
 
-    # Return image file as response (your original approach)
-    return FileResponse(filename, media_type="image/png")
+        # Return image file as response (your original approach)
+        return FileResponse(filename, media_type="image/png")
+
+    except Exception as e:
+        print(f"❌ Image generation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
 
 @app.post("/generate-json/")
 async def generate_image_json(
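For reference, a minimal client sketch for the /generate/ endpoint. The base URL is an assumption (7860 is the usual Hugging Face Spaces default), and the long timeout allows for the lazy model load on the first call:

    import requests

    resp = requests.post(
        "http://localhost:7860/generate/",  # assumed base URL and port
        data={"prompt": "organic honey jar label, earthy green and gold"},
        timeout=300,  # the first request may also trigger the model load
    )
    resp.raise_for_status()
    with open("packaging.png", "wb") as f:
        f.write(resp.content)  # the endpoint returns the PNG bytes directly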
@@ -75,6 +148,13 @@ async def generate_image_json(
     """
     Generate image and return as JSON with base64 data (for frontend integration).
     """
+    # Lazy load model on first request
+    if not load_model_if_needed():
+        if model_loading:
+            raise HTTPException(status_code=503, detail="Model is loading, please wait...")
+        else:
+            raise HTTPException(status_code=503, detail="Model failed to load. Please check logs.")
+
     print(f"πŸ–ŒοΈ Generating image for prompt: {prompt}")
 
     try:
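The same six-line availability guard now appears in all three POST endpoints. One way to centralize it is a FastAPI dependency; require_model here is a hypothetical helper, not part of this commit:

    from fastapi import Depends, HTTPException

    def require_model():
        """Hypothetical dependency replacing the per-endpoint guard."""
        if not load_model_if_needed():
            detail = ("Model is loading, please wait..." if model_loading
                      else "Model failed to load. Please check logs.")
            raise HTTPException(status_code=503, detail=detail)

    # Usage: @app.post("/generate/", dependencies=[Depends(require_model)])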
@@ -116,6 +196,13 @@ async def generate_packaging_specific(
     """
     Generate packaging with PromptAgro-specific prompt engineering
     """
+    # Lazy load model on first request
+    if not load_model_if_needed():
+        if model_loading:
+            raise HTTPException(status_code=503, detail="Model is loading, please wait...")
+        else:
+            raise HTTPException(status_code=503, detail="Model failed to load. Please check logs.")
+
     # Create professional prompt for agricultural packaging
     prompt = f"""Professional agricultural product packaging design for {product_name},
     modern clean style, {colors.replace(',', ' and ')} color scheme, premium typography,
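Since the health endpoint now reports ready_for_requests, a client can wait for the model instead of retrying 503 responses. A readiness-polling sketch; the base URL, timeout values, and helper name are assumptions:

    import time
    import requests

    BASE = "http://localhost:7860"  # assumed base URL

    def wait_until_ready(timeout_s=600.0, interval_s=5.0):
        """Poll the health endpoint until the pipeline reports loaded."""
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            try:
                if requests.get(f"{BASE}/", timeout=10).json().get("ready_for_requests"):
                    return True
            except requests.RequestException:
                pass  # server may still be starting
            time.sleep(interval_s)
        return False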