bluenevus commited on
Commit
37386e2
·
1 Parent(s): ae38572

Update app.py via AI Editor

Browse files
Files changed (1) hide show
  1. app.py +160 -106
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import base64
2
  import dash
3
- from dash import dcc, html, Input, Output, State
4
  import dash_bootstrap_components as dbc
5
  from dash.exceptions import PreventUpdate
6
  import google.generativeai as genai
@@ -9,18 +9,18 @@ import logging
9
  import threading
10
  import time
11
  import os
 
 
12
 
13
  # Set up logging
14
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
 
16
- # Updated STYLES list
17
  STYLES = [
18
  "photographic", "3d-model", "analog-film", "anime", "cinematic", "comic-book",
19
  "digital-art", "enhance", "fantasy-art", "isometric", "line-art", "low-poly",
20
  "modeling-compound", "neon-punk", "origami", "pixel-art", "tile-texture"
21
  ]
22
 
23
- # Default negative prompt (hidden from UI)
24
  DEFAULT_NEGATIVE_PROMPT = """
25
  ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
26
  extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,
@@ -28,65 +28,43 @@ cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, di
28
  plastic, cartoonish, artificial, fake, unnatural, blurry, smooth, lack of detail, low quality
29
  """
30
 
31
- app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
 
 
32
 
33
- app.layout = dbc.Container([
34
- html.H1("ImaGen", className="my-4"),
35
- dbc.Row([
36
- # Left column: Form entry
37
- dbc.Col([
38
- dbc.Card([
39
- dbc.CardBody([
40
- dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
41
- dcc.Dropdown(
42
- id="style",
43
- options=[{"label": s.replace("-", " ").title(), "value": s} for s in STYLES],
44
- value="photographic",
45
- placeholder="Select style",
46
- className="mb-3"
47
- ),
48
- dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
49
- dbc.Accordion([
50
- dbc.AccordionItem(
51
- [
52
- dbc.Label("Aspect Ratio"),
53
- dcc.Dropdown(
54
- id="aspect-ratio",
55
- options=[
56
- {"label": ar, "value": ar} for ar in
57
- ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
58
- ],
59
- value="1:1"
60
- ),
61
- dbc.Label("Steps"),
62
- dcc.Slider(id="steps", min=4, max=50, step=1, value=30, marks={4: '4', 25: '25', 50: '50'}),
63
- ],
64
- title="Advanced Settings",
65
- ),
66
- ], start_collapsed=True, className="mb-3"),
67
- ])
68
- ], className="mb-4"),
69
- ], width=6),
70
- # Right column: Image preview
71
- dbc.Col([
72
- dbc.Card([
73
- dbc.CardBody([
74
- dcc.Loading(
75
- id="loading",
76
- type="circle",
77
- children=[
78
- html.Div(id="status-message", className="mb-3"),
79
- html.Img(id="image-output", className="img-fluid mb-3"),
80
- html.Div(id="enhanced-prompt-output", className="mb-3"),
81
- dbc.Button("Download Image", id="download-btn", color="secondary", className="mb-3", disabled=True),
82
- dcc.Download(id="download-image")
83
- ]
84
- ),
85
- ])
86
- ]),
87
- ], width=6),
88
- ]),
89
- ], fluid=True)
90
 
91
  def enhance_prompt(google_api_key, prompt, style):
92
  genai.configure(api_key=google_api_key)
@@ -108,17 +86,13 @@ def enhance_prompt(google_api_key, prompt, style):
108
 
109
  Enhanced prompt:
110
  """
111
-
112
  try:
113
  response = model.generate_content(enhanced_prompt_request)
114
-
115
  enhanced_prompt = response.text.strip()
116
-
117
  prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
118
  for prefix in prefixes_to_remove:
119
  if enhanced_prompt.lower().startswith(prefix.lower()):
120
  enhanced_prompt = enhanced_prompt[len(prefix):].strip()
121
-
122
  logging.info(f"Enhanced prompt: {enhanced_prompt}")
123
  return enhanced_prompt
124
  except Exception as e:
@@ -127,12 +101,10 @@ def enhance_prompt(google_api_key, prompt, style):
127
 
128
  def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, steps, aspect_ratio):
129
  url = "https://api.stability.ai/v2beta/stable-image/generate/core"
130
-
131
  headers = {
132
  "Accept": "image/*",
133
  "Authorization": f"Bearer {stability_api_key}"
134
  }
135
-
136
  data = {
137
  "prompt": f"{enhanced_prompt}, Style: {style}, highly detailed, high quality, descriptive, sharp focus, intricate details",
138
  "negative_prompt": negative_prompt,
@@ -143,14 +115,9 @@ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, s
143
  "style_preset": style,
144
  "aspect_ratio": aspect_ratio,
145
  }
146
-
147
  try:
148
  response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=60)
149
  response.raise_for_status()
150
-
151
- logging.debug(f"Response headers: {response.headers}")
152
- logging.debug(f"Response content type: {response.headers.get('content-type')}")
153
-
154
  if response.headers.get('content-type').startswith('image/'):
155
  image_data = response.content
156
  if len(image_data) < 1000:
@@ -160,7 +127,6 @@ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, s
160
  error_message = response.text
161
  logging.error(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
162
  raise Exception(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
163
-
164
  except requests.exceptions.RequestException as e:
165
  logging.error(f"Request failed: {str(e)}")
166
  raise Exception(f"Request failed: {str(e)}")
@@ -169,7 +135,6 @@ def process_and_generate(google_api_key, stability_api_key, prompt, style, steps
169
  try:
170
  set_status("Enhancing prompt...")
171
  enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
172
-
173
  set_status("Generating image...")
174
  max_attempts = 3
175
  for attempt in range(max_attempts):
@@ -188,72 +153,161 @@ def process_and_generate(google_api_key, stability_api_key, prompt, style, steps
188
  set_status(f"Error: {str(e)}")
189
  return None, str(e)
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  @app.callback(
192
- [Output("image-output", "src"),
193
- Output("enhanced-prompt-output", "children"),
194
- Output("status-message", "children"),
195
- Output("download-btn", "disabled")],
 
 
196
  [Input("submit-btn", "n_clicks")],
197
- [State("prompt", "value"),
198
- State("style", "value"),
199
- State("steps", "value"),
200
- State("aspect-ratio", "value")],
 
 
201
  prevent_initial_call=True
202
  )
203
  def update_output(n_clicks, prompt, style, steps, aspect_ratio):
 
204
  if n_clicks is None:
205
  raise PreventUpdate
206
-
 
 
 
 
207
  google_api_key = os.getenv('GOOGLE_API_KEY')
208
  stability_api_key = os.getenv('STABILITY_API_KEY')
209
-
210
  if not google_api_key or not stability_api_key:
211
  return "", "Error: API keys not found in environment variables", "API keys missing", True
212
-
213
- logging.debug(f"Stability API Key (first 4 chars): {stability_api_key[:4]}...")
214
-
215
  status = {"message": "Starting process..."}
216
-
217
  def set_status(message):
218
  status["message"] = message
219
-
220
  def run_process():
221
- image_bytes, enhanced_prompt = process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status)
222
- if image_bytes:
223
- encoded_image = base64.b64encode(image_bytes).decode('ascii')
224
- return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"], False
225
- else:
226
- return "", f"Error: {enhanced_prompt}", status["message"], True
227
-
 
 
 
 
228
  try:
229
  thread = threading.Thread(target=run_process)
230
  thread.start()
231
  thread.join(timeout=90)
232
-
233
  if thread.is_alive():
 
 
 
 
234
  return "", "Error: Image generation timed out", "Process timed out", True
235
-
236
  return run_process()
237
  except Exception as e:
238
  logging.error(f"Unexpected error in update_output: {str(e)}")
 
 
 
 
239
  return "", f"Unexpected error: {str(e)}", "An unexpected error occurred", True
240
 
241
  @app.callback(
242
  Output("download-image", "data"),
243
- Input("download-btn", "n_clicks"),
244
- State("image-output", "src"),
245
  prevent_initial_call=True
246
  )
247
  def download_image(n_clicks, image_src):
 
248
  if n_clicks is None:
249
  raise PreventUpdate
250
-
251
- image_data = image_src.split(",")[1]
252
- image_bytes = base64.b64decode(image_data)
253
-
254
- return dcc.send_bytes(image_bytes, "generated_image.jpeg")
 
 
 
255
 
256
  if __name__ == '__main__':
257
  print("Starting the Dash application...")
258
- app.run(debug=False, host='0.0.0.0', port=7860)
259
  print("Dash application has finished running.")
 
1
  import base64
2
  import dash
3
+ from dash import dcc, html, Input, Output, State, callback_context
4
  import dash_bootstrap_components as dbc
5
  from dash.exceptions import PreventUpdate
6
  import google.generativeai as genai
 
9
  import threading
10
  import time
11
  import os
12
+ import flask
13
+ import uuid
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
17
 
 
18
  STYLES = [
19
  "photographic", "3d-model", "analog-film", "anime", "cinematic", "comic-book",
20
  "digital-art", "enhance", "fantasy-art", "isometric", "line-art", "low-poly",
21
  "modeling-compound", "neon-punk", "origami", "pixel-art", "tile-texture"
22
  ]
23
 
 
24
  DEFAULT_NEGATIVE_PROMPT = """
25
  ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
26
  extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,
 
28
  plastic, cartoonish, artificial, fake, unnatural, blurry, smooth, lack of detail, low quality
29
  """
30
 
31
+ external_stylesheets = [dbc.themes.BOOTSTRAP]
32
+ server = flask.Flask(__name__)
33
+ app = dash.Dash(__name__, server=server, external_stylesheets=external_stylesheets)
34
 
35
+ app.title = "ImaGen"
36
+
37
+ # Global in-memory session storage and locks
38
+ SESSION_DATA = {}
39
+ SESSION_LOCKS = {}
40
+
41
+ def get_session_id():
42
+ if hasattr(flask.g, "session_id"):
43
+ return flask.g.session_id
44
+ session_id = flask.request.cookies.get("session_id")
45
+ if not session_id:
46
+ session_id = str(uuid.uuid4())
47
+ flask.g.session_id = session_id
48
+ return session_id
49
+
50
+ @app.server.before_request
51
+ def ensure_session_id():
52
+ session_id = flask.request.cookies.get("session_id")
53
+ if not session_id:
54
+ session_id = str(uuid.uuid4())
55
+ flask.g.set_cookie = session_id
56
+ flask.g.session_id = session_id or flask.g.get("set_cookie", None)
57
+ # Ensure session state and lock
58
+ if session_id not in SESSION_DATA:
59
+ SESSION_DATA[session_id] = {'image': None, 'enhanced_prompt': None, 'status': None}
60
+ if session_id not in SESSION_LOCKS:
61
+ SESSION_LOCKS[session_id] = threading.Lock()
62
+
63
+ @app.server.after_request
64
+ def set_session_cookie(response):
65
+ if hasattr(flask.g, "set_cookie"):
66
+ response.set_cookie("session_id", flask.g.set_cookie, httponly=True, samesite='Lax')
67
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def enhance_prompt(google_api_key, prompt, style):
70
  genai.configure(api_key=google_api_key)
 
86
 
87
  Enhanced prompt:
88
  """
 
89
  try:
90
  response = model.generate_content(enhanced_prompt_request)
 
91
  enhanced_prompt = response.text.strip()
 
92
  prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
93
  for prefix in prefixes_to_remove:
94
  if enhanced_prompt.lower().startswith(prefix.lower()):
95
  enhanced_prompt = enhanced_prompt[len(prefix):].strip()
 
96
  logging.info(f"Enhanced prompt: {enhanced_prompt}")
97
  return enhanced_prompt
98
  except Exception as e:
 
101
 
102
  def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, steps, aspect_ratio):
103
  url = "https://api.stability.ai/v2beta/stable-image/generate/core"
 
104
  headers = {
105
  "Accept": "image/*",
106
  "Authorization": f"Bearer {stability_api_key}"
107
  }
 
108
  data = {
109
  "prompt": f"{enhanced_prompt}, Style: {style}, highly detailed, high quality, descriptive, sharp focus, intricate details",
110
  "negative_prompt": negative_prompt,
 
115
  "style_preset": style,
116
  "aspect_ratio": aspect_ratio,
117
  }
 
118
  try:
119
  response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=60)
120
  response.raise_for_status()
 
 
 
 
121
  if response.headers.get('content-type').startswith('image/'):
122
  image_data = response.content
123
  if len(image_data) < 1000:
 
127
  error_message = response.text
128
  logging.error(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
129
  raise Exception(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
 
130
  except requests.exceptions.RequestException as e:
131
  logging.error(f"Request failed: {str(e)}")
132
  raise Exception(f"Request failed: {str(e)}")
 
135
  try:
136
  set_status("Enhancing prompt...")
137
  enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
 
138
  set_status("Generating image...")
139
  max_attempts = 3
140
  for attempt in range(max_attempts):
 
153
  set_status(f"Error: {str(e)}")
154
  return None, str(e)
155
 
156
+ app.layout = dbc.Container([
157
+ dbc.Row([
158
+ dbc.Col([
159
+ html.H1("ImaGen", className="text-center mb-4")
160
+ ], width=12)
161
+ ]),
162
+ dbc.Row([
163
+ dbc.Col([
164
+ dbc.Card([
165
+ dbc.CardBody([
166
+ dbc.Textarea(
167
+ id="prompt",
168
+ placeholder="Tell me what image you want me to build.",
169
+ className="mb-3",
170
+ style={"height": "120px", "whiteSpace": "pre-wrap", "wordWrap": "break-word"}
171
+ ),
172
+ dcc.Dropdown(
173
+ id="style",
174
+ options=[{"label": s.replace("-", " ").title(), "value": s} for s in STYLES],
175
+ value="photographic",
176
+ placeholder="Select style",
177
+ className="mb-3"
178
+ ),
179
+ # Advanced settings always open, below dropdown
180
+ dbc.FormGroup([
181
+ dbc.Label("Aspect Ratio"),
182
+ dcc.Dropdown(
183
+ id="aspect-ratio",
184
+ options=[
185
+ {"label": ar, "value": ar} for ar in
186
+ ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
187
+ ],
188
+ value="1:1",
189
+ className="mb-3"
190
+ ),
191
+ dbc.Label("Steps"),
192
+ dcc.Slider(
193
+ id="steps",
194
+ min=4,
195
+ max=50,
196
+ step=1,
197
+ value=30,
198
+ marks={4: '4', 25: '25', 50: '50'},
199
+ className="mb-3"
200
+ ),
201
+ ]),
202
+ dbc.Button("Generate Image", id="submit-btn", color="primary", className="mt-2 mb-2", style={"width": "100%"})
203
+ ])
204
+ ])
205
+ ], width=4, style={"minWidth": "300px", "maxWidth": "420px", "flex": "0 0 30%"}),
206
+ dbc.Col([
207
+ dbc.Card([
208
+ dbc.CardBody([
209
+ dbc.Button("Download Image", id="download-btn", color="secondary", className="mb-3", disabled=True, style={"width": "100%"}),
210
+ dcc.Loading(
211
+ id="loading",
212
+ type="default",
213
+ children=[
214
+ html.Div(id="status-message", className="mb-3"),
215
+ html.Img(id="image-output", className="img-fluid mb-3"),
216
+ html.Div(id="enhanced-prompt-output", className="mb-3"),
217
+ dcc.Download(id="download-image")
218
+ ],
219
+ style={"display": "block", "margin": "auto"}
220
+ ),
221
+ ])
222
+ ])
223
+ ], width=8, style={"flex": "0 0 70%"})
224
+ ], align="start")
225
+ ], fluid=True)
226
+
227
  @app.callback(
228
+ [
229
+ Output("image-output", "src"),
230
+ Output("enhanced-prompt-output", "children"),
231
+ Output("status-message", "children"),
232
+ Output("download-btn", "disabled")
233
+ ],
234
  [Input("submit-btn", "n_clicks")],
235
+ [
236
+ State("prompt", "value"),
237
+ State("style", "value"),
238
+ State("steps", "value"),
239
+ State("aspect-ratio", "value")
240
+ ],
241
  prevent_initial_call=True
242
  )
243
  def update_output(n_clicks, prompt, style, steps, aspect_ratio):
244
+ ctx = callback_context
245
  if n_clicks is None:
246
  raise PreventUpdate
247
+ session_id = flask.request.cookies.get("session_id")
248
+ if not session_id:
249
+ session_id = str(uuid.uuid4())
250
+ lock = SESSION_LOCKS.setdefault(session_id, threading.Lock())
251
+ session_data = SESSION_DATA.setdefault(session_id, {'image': None, 'enhanced_prompt': None, 'status': None})
252
  google_api_key = os.getenv('GOOGLE_API_KEY')
253
  stability_api_key = os.getenv('STABILITY_API_KEY')
 
254
  if not google_api_key or not stability_api_key:
255
  return "", "Error: API keys not found in environment variables", "API keys missing", True
 
 
 
256
  status = {"message": "Starting process..."}
 
257
  def set_status(message):
258
  status["message"] = message
259
+ session_data['status'] = message
260
  def run_process():
261
+ with lock:
262
+ image_bytes, enhanced_prompt = process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status)
263
+ if image_bytes:
264
+ encoded_image = base64.b64encode(image_bytes).decode('ascii')
265
+ session_data['image'] = encoded_image
266
+ session_data['enhanced_prompt'] = enhanced_prompt
267
+ return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"], False
268
+ else:
269
+ session_data['image'] = None
270
+ session_data['enhanced_prompt'] = None
271
+ return "", f"Error: {enhanced_prompt}", status["message"], True
272
  try:
273
  thread = threading.Thread(target=run_process)
274
  thread.start()
275
  thread.join(timeout=90)
 
276
  if thread.is_alive():
277
+ with lock:
278
+ session_data['status'] = "Process timed out"
279
+ session_data['image'] = None
280
+ session_data['enhanced_prompt'] = None
281
  return "", "Error: Image generation timed out", "Process timed out", True
 
282
  return run_process()
283
  except Exception as e:
284
  logging.error(f"Unexpected error in update_output: {str(e)}")
285
+ with lock:
286
+ session_data['status'] = "An unexpected error occurred"
287
+ session_data['image'] = None
288
+ session_data['enhanced_prompt'] = None
289
  return "", f"Unexpected error: {str(e)}", "An unexpected error occurred", True
290
 
291
  @app.callback(
292
  Output("download-image", "data"),
293
+ [Input("download-btn", "n_clicks")],
294
+ [State("image-output", "src")],
295
  prevent_initial_call=True
296
  )
297
  def download_image(n_clicks, image_src):
298
+ ctx = callback_context
299
  if n_clicks is None:
300
  raise PreventUpdate
301
+ session_id = flask.request.cookies.get("session_id")
302
+ lock = SESSION_LOCKS.setdefault(session_id, threading.Lock())
303
+ session_data = SESSION_DATA.setdefault(session_id, {'image': None, 'enhanced_prompt': None, 'status': None})
304
+ with lock:
305
+ if not session_data.get('image'):
306
+ raise PreventUpdate
307
+ image_bytes = base64.b64decode(session_data['image'])
308
+ return dcc.send_bytes(image_bytes, "generated_image.jpeg")
309
 
310
  if __name__ == '__main__':
311
  print("Starting the Dash application...")
312
+ app.run(debug=True, host='0.0.0.0', port=7860, threaded=True)
313
  print("Dash application has finished running.")