bluenevus commited on
Commit
ae38572
·
1 Parent(s): 46404a1

Update app.py via AI Editor

Browse files
Files changed (1) hide show
  1. app.py +259 -1
app.py CHANGED
@@ -1 +1,259 @@
1
- --
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import dash
3
+ from dash import dcc, html, Input, Output, State
4
+ import dash_bootstrap_components as dbc
5
+ from dash.exceptions import PreventUpdate
6
+ import google.generativeai as genai
7
+ import requests
8
+ import logging
9
+ import threading
10
+ import time
11
+ import os
12
+
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
15
+
16
+ # Updated STYLES list
17
+ STYLES = [
18
+ "photographic", "3d-model", "analog-film", "anime", "cinematic", "comic-book",
19
+ "digital-art", "enhance", "fantasy-art", "isometric", "line-art", "low-poly",
20
+ "modeling-compound", "neon-punk", "origami", "pixel-art", "tile-texture"
21
+ ]
22
+
23
+ # Default negative prompt (hidden from UI)
24
+ DEFAULT_NEGATIVE_PROMPT = """
25
+ ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
26
+ extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,
27
+ cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face,
28
+ plastic, cartoonish, artificial, fake, unnatural, blurry, smooth, lack of detail, low quality
29
+ """
30
+
31
+ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
32
+
33
+ app.layout = dbc.Container([
34
+ html.H1("ImaGen", className="my-4"),
35
+ dbc.Row([
36
+ # Left column: Form entry
37
+ dbc.Col([
38
+ dbc.Card([
39
+ dbc.CardBody([
40
+ dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
41
+ dcc.Dropdown(
42
+ id="style",
43
+ options=[{"label": s.replace("-", " ").title(), "value": s} for s in STYLES],
44
+ value="photographic",
45
+ placeholder="Select style",
46
+ className="mb-3"
47
+ ),
48
+ dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
49
+ dbc.Accordion([
50
+ dbc.AccordionItem(
51
+ [
52
+ dbc.Label("Aspect Ratio"),
53
+ dcc.Dropdown(
54
+ id="aspect-ratio",
55
+ options=[
56
+ {"label": ar, "value": ar} for ar in
57
+ ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
58
+ ],
59
+ value="1:1"
60
+ ),
61
+ dbc.Label("Steps"),
62
+ dcc.Slider(id="steps", min=4, max=50, step=1, value=30, marks={4: '4', 25: '25', 50: '50'}),
63
+ ],
64
+ title="Advanced Settings",
65
+ ),
66
+ ], start_collapsed=True, className="mb-3"),
67
+ ])
68
+ ], className="mb-4"),
69
+ ], width=6),
70
+ # Right column: Image preview
71
+ dbc.Col([
72
+ dbc.Card([
73
+ dbc.CardBody([
74
+ dcc.Loading(
75
+ id="loading",
76
+ type="circle",
77
+ children=[
78
+ html.Div(id="status-message", className="mb-3"),
79
+ html.Img(id="image-output", className="img-fluid mb-3"),
80
+ html.Div(id="enhanced-prompt-output", className="mb-3"),
81
+ dbc.Button("Download Image", id="download-btn", color="secondary", className="mb-3", disabled=True),
82
+ dcc.Download(id="download-image")
83
+ ]
84
+ ),
85
+ ])
86
+ ]),
87
+ ], width=6),
88
+ ]),
89
+ ], fluid=True)
90
+
91
+ def enhance_prompt(google_api_key, prompt, style):
92
+ genai.configure(api_key=google_api_key)
93
+ model = genai.GenerativeModel("gemini-2.0-flash-lite")
94
+ enhanced_prompt_request = f"""
95
+ Task: Enhance the following prompt with details to match the specified style
96
+ Style: {style}
97
+ Original prompt: '{prompt}'
98
+
99
+ Instructions:
100
+ 1. Expand the prompt to be more detailed, vivid, and realistic with camera used and the setting for that camera like ISO etc.
101
+ 2. Incorporate elements of the specified style.
102
+ 3. Add details that enhance the scene to the specified style
103
+ 4. Emphasize natural lighting and enhance the realism of textures and colors based on the specified style.
104
+ 5. Avoid terms that might result in artificial or cartoonish appearance unless specified by user.
105
+ 6. Maintain the original intent of the prompt while significantly improving its descriptive quality with details.
106
+ 7. Provide ONLY the enhanced prompt, without any explanations or options.
107
+ 8. Keep the enhanced prompt concise, ideally under 100 words.
108
+
109
+ Enhanced prompt:
110
+ """
111
+
112
+ try:
113
+ response = model.generate_content(enhanced_prompt_request)
114
+
115
+ enhanced_prompt = response.text.strip()
116
+
117
+ prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
118
+ for prefix in prefixes_to_remove:
119
+ if enhanced_prompt.lower().startswith(prefix.lower()):
120
+ enhanced_prompt = enhanced_prompt[len(prefix):].strip()
121
+
122
+ logging.info(f"Enhanced prompt: {enhanced_prompt}")
123
+ return enhanced_prompt
124
+ except Exception as e:
125
+ logging.error(f"Error in enhance_prompt: {str(e)}")
126
+ raise
127
+
128
+ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, steps, aspect_ratio):
129
+ url = "https://api.stability.ai/v2beta/stable-image/generate/core"
130
+
131
+ headers = {
132
+ "Accept": "image/*",
133
+ "Authorization": f"Bearer {stability_api_key}"
134
+ }
135
+
136
+ data = {
137
+ "prompt": f"{enhanced_prompt}, Style: {style}, highly detailed, high quality, descriptive, sharp focus, intricate details",
138
+ "negative_prompt": negative_prompt,
139
+ "model": "sd3.5-large-turbo",
140
+ "output_format": "jpeg",
141
+ "num_images": 1,
142
+ "steps": steps,
143
+ "style_preset": style,
144
+ "aspect_ratio": aspect_ratio,
145
+ }
146
+
147
+ try:
148
+ response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=60)
149
+ response.raise_for_status()
150
+
151
+ logging.debug(f"Response headers: {response.headers}")
152
+ logging.debug(f"Response content type: {response.headers.get('content-type')}")
153
+
154
+ if response.headers.get('content-type').startswith('image/'):
155
+ image_data = response.content
156
+ if len(image_data) < 1000:
157
+ raise Exception("Received incomplete image data")
158
+ return image_data
159
+ else:
160
+ error_message = response.text
161
+ logging.error(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
162
+ raise Exception(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}")
163
+
164
+ except requests.exceptions.RequestException as e:
165
+ logging.error(f"Request failed: {str(e)}")
166
+ raise Exception(f"Request failed: {str(e)}")
167
+
168
+ def process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status):
169
+ try:
170
+ set_status("Enhancing prompt...")
171
+ enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
172
+
173
+ set_status("Generating image...")
174
+ max_attempts = 3
175
+ for attempt in range(max_attempts):
176
+ try:
177
+ image_bytes = generate_image(stability_api_key, enhanced_prompt, style, DEFAULT_NEGATIVE_PROMPT, steps, aspect_ratio)
178
+ set_status("Image generated successfully!")
179
+ return image_bytes, enhanced_prompt
180
+ except Exception as e:
181
+ if attempt < max_attempts - 1:
182
+ set_status(f"Attempt {attempt + 1} failed. Retrying...")
183
+ time.sleep(2)
184
+ else:
185
+ raise e
186
+ except Exception as e:
187
+ logging.error(f"Error in process_and_generate: {str(e)}")
188
+ set_status(f"Error: {str(e)}")
189
+ return None, str(e)
190
+
191
+ @app.callback(
192
+ [Output("image-output", "src"),
193
+ Output("enhanced-prompt-output", "children"),
194
+ Output("status-message", "children"),
195
+ Output("download-btn", "disabled")],
196
+ [Input("submit-btn", "n_clicks")],
197
+ [State("prompt", "value"),
198
+ State("style", "value"),
199
+ State("steps", "value"),
200
+ State("aspect-ratio", "value")],
201
+ prevent_initial_call=True
202
+ )
203
+ def update_output(n_clicks, prompt, style, steps, aspect_ratio):
204
+ if n_clicks is None:
205
+ raise PreventUpdate
206
+
207
+ google_api_key = os.getenv('GOOGLE_API_KEY')
208
+ stability_api_key = os.getenv('STABILITY_API_KEY')
209
+
210
+ if not google_api_key or not stability_api_key:
211
+ return "", "Error: API keys not found in environment variables", "API keys missing", True
212
+
213
+ logging.debug(f"Stability API Key (first 4 chars): {stability_api_key[:4]}...")
214
+
215
+ status = {"message": "Starting process..."}
216
+
217
+ def set_status(message):
218
+ status["message"] = message
219
+
220
+ def run_process():
221
+ image_bytes, enhanced_prompt = process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status)
222
+ if image_bytes:
223
+ encoded_image = base64.b64encode(image_bytes).decode('ascii')
224
+ return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"], False
225
+ else:
226
+ return "", f"Error: {enhanced_prompt}", status["message"], True
227
+
228
+ try:
229
+ thread = threading.Thread(target=run_process)
230
+ thread.start()
231
+ thread.join(timeout=90)
232
+
233
+ if thread.is_alive():
234
+ return "", "Error: Image generation timed out", "Process timed out", True
235
+
236
+ return run_process()
237
+ except Exception as e:
238
+ logging.error(f"Unexpected error in update_output: {str(e)}")
239
+ return "", f"Unexpected error: {str(e)}", "An unexpected error occurred", True
240
+
241
+ @app.callback(
242
+ Output("download-image", "data"),
243
+ Input("download-btn", "n_clicks"),
244
+ State("image-output", "src"),
245
+ prevent_initial_call=True
246
+ )
247
+ def download_image(n_clicks, image_src):
248
+ if n_clicks is None:
249
+ raise PreventUpdate
250
+
251
+ image_data = image_src.split(",")[1]
252
+ image_bytes = base64.b64decode(image_data)
253
+
254
+ return dcc.send_bytes(image_bytes, "generated_image.jpeg")
255
+
256
+ if __name__ == '__main__':
257
+ print("Starting the Dash application...")
258
+ app.run(debug=False, host='0.0.0.0', port=7860)
259
+ print("Dash application has finished running.")