bluenevus commited on
Commit
57df966
·
verified ·
1 Parent(s): 77e7a61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -92
app.py CHANGED
@@ -31,81 +31,39 @@ plastic, cartoonish, artificial, fake, unnatural
31
  app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
32
 
33
  app.layout = dbc.Container([
34
- html.H1("Insta-Image", className="my-4"),
35
- dbc.Card([
36
- dbc.CardBody([
37
- dbc.Input(id="google-api-key", type="password", placeholder="Enter Google AI API Key", className="mb-3"),
38
- dbc.Input(id="stability-api-key", type="password", placeholder="Enter Stability AI API Key", className="mb-3"),
39
- dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
40
- dcc.Dropdown(id="style", options=[{"label": s, "value": s} for s in STYLES], placeholder="Select style", className="mb-3"),
41
- dbc.Accordion([
42
- dbc.AccordionItem(
43
- [
44
- dbc.Row([
45
- dbc.Col([
46
- html.Label("CFG Scale:", className="mr-2"),
47
- dcc.Slider(id="cfg-scale", min=1, max=30, step=0.5, value=7,
48
- marks={1: '1', 15: '15', 30: '30'},
49
- tooltip={"placement": "bottom", "always_visible": True}),
50
- dbc.Tooltip(
51
- "Controls the influence of the prompt. Higher values adhere more closely to the prompt.",
52
- target="cfg-scale",
53
- ),
54
- ], width=12, className="mb-3"),
55
- ]),
56
- dbc.Row([
57
- dbc.Col([
58
- html.Label("Steps:", className="mr-2"),
59
- dcc.Slider(id="steps", min=4, max=50, step=1, value=20,
60
- marks={4: '4', 25: '25', 50: '50'},
61
- tooltip={"placement": "bottom", "always_visible": True}),
62
- dbc.Tooltip(
63
- "Number of denoising steps. More steps can lead to higher quality but longer generation time.",
64
- target="steps",
65
- ),
66
- ], width=12, className="mb-3"),
67
- ]),
68
- dbc.Row([
69
- dbc.Col([
70
- html.Label("Sampler:", className="mr-2"),
71
- dcc.Dropdown(
72
- id="sampler",
73
- options=[
74
- {"label": "DDIM", "value": "DDIM"},
75
- {"label": "PLMS", "value": "PLMS"},
76
- {"label": "K_EULER", "value": "K_EULER"},
77
- {"label": "K_EULER_ANCESTRAL", "value": "K_EULER_ANCESTRAL"},
78
- {"label": "DPM_2", "value": "DPM_2"},
79
- {"label": "DPM_2_ANCESTRAL", "value": "DPM_2_ANCESTRAL"},
80
- ],
81
- value="K_EULER_ANCESTRAL",
82
- ),
83
- dbc.Tooltip(
84
- "The algorithm used for image generation. Different samplers can produce varying results.",
85
- target="sampler",
86
- ),
87
- ], width=12, className="mb-3"),
88
- ]),
89
- ],
90
- title="Advanced Settings",
91
- )
92
- ], start_collapsed=True, className="mb-3"),
93
- dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
94
- ])
95
- ], className="mb-4"),
96
- dbc.Card([
97
- dbc.CardBody([
98
- dcc.Loading(
99
- id="loading",
100
- type="circle",
101
- children=[
102
- html.Div(id="status-message", className="mb-3"),
103
- html.Img(id="image-output", className="img-fluid"),
104
- html.Div(id="enhanced-prompt-output", className="mt-3"),
105
- ]
106
- ),
107
- ])
108
- ])
109
  ], fluid=True)
110
 
111
  def enhance_prompt(google_api_key, prompt, style):
@@ -117,11 +75,11 @@ def enhance_prompt(google_api_key, prompt, style):
117
  Original prompt: '{prompt}'
118
 
119
  Instructions:
120
- 1. Expand the prompt to be more detailed, vivid, and realism and always include the right camera used for the shot
121
  2. Incorporate elements of the specified style, focusing on realism and natural appearances.
122
  3. Add details that enhance the realism of the scene, especially for elements like trees, textures, and lighting.
123
  4. Emphasize natural lighting and enhance the realism of textures and colors.
124
- 5. Avoid terms that might result in artificial or cartoonish appearances unless specifically requested
125
  6. Maintain the original intent of the prompt while significantly improving its descriptive quality.
126
  7. Provide ONLY the enhanced prompt, without any explanations or options.
127
  8. Keep the enhanced prompt concise, ideally under 100 words.
@@ -145,7 +103,7 @@ def enhance_prompt(google_api_key, prompt, style):
145
  logging.error(f"Error in enhance_prompt: {str(e)}")
146
  raise
147
 
148
- def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, cfg_scale, steps, sampler):
149
  url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
150
 
151
  headers = {
@@ -161,9 +119,8 @@ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, c
161
  "width": 1024,
162
  "height": 1024,
163
  "num_images": 1,
164
- "steps": steps,
165
- "cfg_scale": cfg_scale,
166
- "sampler": sampler,
167
  }
168
 
169
  try:
@@ -184,13 +141,13 @@ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, c
184
  logging.error(f"Request failed: {str(e)}")
185
  raise Exception(f"Request failed: {str(e)}")
186
 
187
- def process_and_generate(google_api_key, stability_api_key, prompt, style, cfg_scale, steps, sampler, set_status):
188
  try:
189
  set_status("Enhancing prompt...")
190
  enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
191
 
192
  set_status("Generating image...")
193
- image_bytes = generate_image(stability_api_key, enhanced_prompt, style, DEFAULT_NEGATIVE_PROMPT, cfg_scale, steps, sampler)
194
 
195
  set_status("Image generated successfully!")
196
  return image_bytes, enhanced_prompt
@@ -202,18 +159,16 @@ def process_and_generate(google_api_key, stability_api_key, prompt, style, cfg_s
202
  @app.callback(
203
  [Output("image-output", "src"),
204
  Output("enhanced-prompt-output", "children"),
205
- Output("status-message", "children")],
 
206
  [Input("submit-btn", "n_clicks")],
207
  [State("google-api-key", "value"),
208
  State("stability-api-key", "value"),
209
  State("prompt", "value"),
210
- State("style", "value"),
211
- State("cfg-scale", "value"),
212
- State("steps", "value"),
213
- State("sampler", "value")],
214
  prevent_initial_call=True
215
  )
216
- def update_output(n_clicks, google_api_key, stability_api_key, prompt, style, cfg_scale, steps, sampler):
217
  if n_clicks is None:
218
  raise PreventUpdate
219
 
@@ -225,12 +180,12 @@ def update_output(n_clicks, google_api_key, stability_api_key, prompt, style, cf
225
  status["message"] = message
226
 
227
  def run_process():
228
- image_bytes, enhanced_prompt = process_and_generate(google_api_key, stability_api_key, prompt, style, cfg_scale, steps, sampler, set_status)
229
  if image_bytes:
230
  encoded_image = base64.b64encode(image_bytes).decode('ascii')
231
- return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"]
232
  else:
233
- return "", f"Error: {enhanced_prompt}", status["message"]
234
 
235
  # Run the process in a separate thread
236
  thread = threading.Thread(target=run_process)
@@ -239,6 +194,24 @@ def update_output(n_clicks, google_api_key, stability_api_key, prompt, style, cf
239
 
240
  return run_process()
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  if __name__ == '__main__':
243
  print("Starting the Dash application...")
244
  app.run(debug=True, host='0.0.0.0', port=7860)
 
31
  app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
32
 
33
  app.layout = dbc.Container([
34
+ html.H1("Stability AI SD3.5 Large Turbo Image Generator with Google Gemini Prompt Enhancement", className="my-4"),
35
+ dbc.Row([
36
+ # Left column: Form entry
37
+ dbc.Col([
38
+ dbc.Card([
39
+ dbc.CardBody([
40
+ dbc.Input(id="google-api-key", type="password", placeholder="Enter Google AI API Key", className="mb-3"),
41
+ dbc.Input(id="stability-api-key", type="password", placeholder="Enter Stability AI API Key", className="mb-3"),
42
+ dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
43
+ dcc.Dropdown(id="style", options=[{"label": s, "value": s} for s in STYLES], placeholder="Select style", className="mb-3"),
44
+ dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
45
+ ])
46
+ ], className="mb-4"),
47
+ ], width=6),
48
+ # Right column: Image preview
49
+ dbc.Col([
50
+ dbc.Card([
51
+ dbc.CardBody([
52
+ dcc.Loading(
53
+ id="loading",
54
+ type="circle",
55
+ children=[
56
+ html.Div(id="status-message", className="mb-3"),
57
+ html.Img(id="image-output", className="img-fluid mb-3"),
58
+ html.Div(id="enhanced-prompt-output", className="mb-3"),
59
+ dbc.Button("Download Image", id="download-btn", color="secondary", className="mb-3", disabled=True),
60
+ dcc.Download(id="download-image")
61
+ ]
62
+ ),
63
+ ])
64
+ ]),
65
+ ], width=6),
66
+ ]),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  ], fluid=True)
68
 
69
  def enhance_prompt(google_api_key, prompt, style):
 
75
  Original prompt: '{prompt}'
76
 
77
  Instructions:
78
+ 1. Expand the prompt to be more detailed, vivid, and photorealistic.
79
  2. Incorporate elements of the specified style, focusing on realism and natural appearances.
80
  3. Add details that enhance the realism of the scene, especially for elements like trees, textures, and lighting.
81
  4. Emphasize natural lighting and enhance the realism of textures and colors.
82
+ 5. Avoid terms that might result in artificial or cartoonish appearances.
83
  6. Maintain the original intent of the prompt while significantly improving its descriptive quality.
84
  7. Provide ONLY the enhanced prompt, without any explanations or options.
85
  8. Keep the enhanced prompt concise, ideally under 100 words.
 
103
  logging.error(f"Error in enhance_prompt: {str(e)}")
104
  raise
105
 
106
+ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt):
107
  url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
108
 
109
  headers = {
 
119
  "width": 1024,
120
  "height": 1024,
121
  "num_images": 1,
122
+ "steps": 20,
123
+ "cfg_scale": 7,
 
124
  }
125
 
126
  try:
 
141
  logging.error(f"Request failed: {str(e)}")
142
  raise Exception(f"Request failed: {str(e)}")
143
 
144
+ def process_and_generate(google_api_key, stability_api_key, prompt, style, set_status):
145
  try:
146
  set_status("Enhancing prompt...")
147
  enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
148
 
149
  set_status("Generating image...")
150
+ image_bytes = generate_image(stability_api_key, enhanced_prompt, style, DEFAULT_NEGATIVE_PROMPT)
151
 
152
  set_status("Image generated successfully!")
153
  return image_bytes, enhanced_prompt
 
159
  @app.callback(
160
  [Output("image-output", "src"),
161
  Output("enhanced-prompt-output", "children"),
162
+ Output("status-message", "children"),
163
+ Output("download-btn", "disabled")],
164
  [Input("submit-btn", "n_clicks")],
165
  [State("google-api-key", "value"),
166
  State("stability-api-key", "value"),
167
  State("prompt", "value"),
168
+ State("style", "value")],
 
 
 
169
  prevent_initial_call=True
170
  )
171
+ def update_output(n_clicks, google_api_key, stability_api_key, prompt, style):
172
  if n_clicks is None:
173
  raise PreventUpdate
174
 
 
180
  status["message"] = message
181
 
182
  def run_process():
183
+ image_bytes, enhanced_prompt = process_and_generate(google_api_key, stability_api_key, prompt, style, set_status)
184
  if image_bytes:
185
  encoded_image = base64.b64encode(image_bytes).decode('ascii')
186
+ return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"], False
187
  else:
188
+ return "", f"Error: {enhanced_prompt}", status["message"], True
189
 
190
  # Run the process in a separate thread
191
  thread = threading.Thread(target=run_process)
 
194
 
195
  return run_process()
196
 
197
+ @app.callback(
198
+ Output("download-image", "data"),
199
+ Input("download-btn", "n_clicks"),
200
+ State("image-output", "src"),
201
+ prevent_initial_call=True
202
+ )
203
+ def download_image(n_clicks, image_src):
204
+ if n_clicks is None:
205
+ raise PreventUpdate
206
+
207
+ # Extract the base64 encoded image data
208
+ image_data = image_src.split(",")[1]
209
+
210
+ # Decode the base64 data
211
+ image_bytes = base64.b64decode(image_data)
212
+
213
+ return dcc.send_bytes(image_bytes, "generated_image.jpeg")
214
+
215
  if __name__ == '__main__':
216
  print("Starting the Dash application...")
217
  app.run(debug=True, host='0.0.0.0', port=7860)