bluenevus commited on
Commit
2c6acee
·
verified ·
1 Parent(s): 4dc95a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -1,8 +1,5 @@
1
  import gradio as gr
2
  import google.generativeai as genai
3
- from PIL import Image
4
- import io
5
- import base64
6
  import requests
7
 
8
  # List of popular styles
@@ -20,7 +17,8 @@ extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, si
20
  cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face
21
  """
22
 
23
- def enhance_prompt(prompt, style):
 
24
  model = genai.GenerativeModel("gemini-2.0-flash-lite")
25
  enhanced_prompt_request = f"""
26
  Task: Enhance the following prompt for image generation.
@@ -39,10 +37,8 @@ def enhance_prompt(prompt, style):
39
 
40
  response = model.generate_content(enhanced_prompt_request)
41
 
42
- # Extract only the enhanced prompt, removing any potential explanations or extra text
43
  enhanced_prompt = response.text.strip()
44
 
45
- # If the response starts with "Enhanced prompt:" or similar, remove it
46
  prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
47
  for prefix in prefixes_to_remove:
48
  if enhanced_prompt.lower().startswith(prefix.lower()):
@@ -50,7 +46,6 @@ def enhance_prompt(prompt, style):
50
 
51
  return enhanced_prompt
52
 
53
-
54
  def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt):
55
  url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
56
 
@@ -79,12 +74,13 @@ def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt):
79
  else:
80
  return f"Image generation failed: {response.text}"
81
 
82
- def process_and_generate(google_api_key, stability_api_key, prompt, style, negative_prompt):
83
- genai.configure(api_key=google_api_key)
84
-
85
- enhanced_prompt = enhance_prompt(prompt, style)
 
86
  image_url = generate_image(stability_api_key, enhanced_prompt, style, negative_prompt)
87
- return image_url, enhanced_prompt
88
 
89
  with gr.Blocks() as demo:
90
  gr.Markdown("# Stability AI Image Generator with Google Gemini Prompt Enhancement")
@@ -96,16 +92,25 @@ with gr.Blocks() as demo:
96
  prompt = gr.Textbox(label="Prompt")
97
  style = gr.Dropdown(label="Style", choices=STYLES)
98
  negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
 
99
  generate_btn = gr.Button("Generate Image")
100
 
101
  with gr.Column(scale=1):
102
  image_output = gr.Image(label="Generated Image")
103
  enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt")
104
 
 
 
 
 
 
 
 
 
105
  generate_btn.click(
106
- process_and_generate,
107
- inputs=[google_api_key, stability_api_key, prompt, style, negative_prompt],
108
- outputs=[image_output, enhanced_prompt_output]
109
  )
110
 
111
  demo.launch()
 
1
  import gradio as gr
2
  import google.generativeai as genai
 
 
 
3
  import requests
4
 
5
  # List of popular styles
 
17
  cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face
18
  """
19
 
20
+ def enhance_prompt(google_api_key, prompt, style):
21
+ genai.configure(api_key=google_api_key)
22
  model = genai.GenerativeModel("gemini-2.0-flash-lite")
23
  enhanced_prompt_request = f"""
24
  Task: Enhance the following prompt for image generation.
 
37
 
38
  response = model.generate_content(enhanced_prompt_request)
39
 
 
40
  enhanced_prompt = response.text.strip()
41
 
 
42
  prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"]
43
  for prefix in prefixes_to_remove:
44
  if enhanced_prompt.lower().startswith(prefix.lower()):
 
46
 
47
  return enhanced_prompt
48
 
 
49
  def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt):
50
  url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
51
 
 
74
  else:
75
  return f"Image generation failed: {response.text}"
76
 
77
+ def process_prompt(google_api_key, prompt, style):
78
+ enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
79
+ return enhanced_prompt
80
+
81
+ def generate_from_enhanced(stability_api_key, enhanced_prompt, style, negative_prompt):
82
  image_url = generate_image(stability_api_key, enhanced_prompt, style, negative_prompt)
83
+ return image_url
84
 
85
  with gr.Blocks() as demo:
86
  gr.Markdown("# Stability AI Image Generator with Google Gemini Prompt Enhancement")
 
92
  prompt = gr.Textbox(label="Prompt")
93
  style = gr.Dropdown(label="Style", choices=STYLES)
94
  negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
95
+ enhance_btn = gr.Button("Enhance Prompt")
96
  generate_btn = gr.Button("Generate Image")
97
 
98
  with gr.Column(scale=1):
99
  image_output = gr.Image(label="Generated Image")
100
  enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt")
101
 
102
+ enhanced_prompt_state = gr.State()
103
+
104
+ enhance_btn.click(
105
+ process_prompt,
106
+ inputs=[google_api_key, prompt, style],
107
+ outputs=[enhanced_prompt_output, enhanced_prompt_state]
108
+ )
109
+
110
  generate_btn.click(
111
+ generate_from_enhanced,
112
+ inputs=[stability_api_key, enhanced_prompt_state, style, negative_prompt],
113
+ outputs=[image_output]
114
  )
115
 
116
  demo.launch()