Spaces:
Paused
Paused
import base64 | |
import dash | |
from dash import dcc, html, Input, Output, State | |
import dash_bootstrap_components as dbc | |
from dash.exceptions import PreventUpdate | |
import google.generativeai as genai | |
import requests | |
import logging | |
import threading | |
import io | |
# Root logging configuration: everything from DEBUG upward, one timestamped
# line per record ("<time> - <LEVEL> - <message>").
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
# Art styles offered in the "Select style" dropdown, in display order.
STYLES = [
    "Photorealistic",
    "Oil Painting",
    "Watercolor",
    "Anime",
    "Studio Ghibli",
    "Black and White",
    "Polaroid",
    "Sketch",
    "3D Render",
    "Pixel Art",
    "Cyberpunk",
    "Steampunk",
    "Art Nouveau",
    "Pop Art",
    "Minimalist",
]
# Initial contents of the negative-prompt textarea: artefacts the image model
# should avoid. The leading and trailing newlines are part of the value.
DEFAULT_NEGATIVE_PROMPT = (
    "\n"
    "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,\n"
    "extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,\n"
    "cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face\n"
)
# The single Dash application instance, styled with the stock Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])


def _build_layout():
    """Return the page layout: a title, an input-form card, and a results card.

    Component ids here are the ones the generation callback reads from and
    writes to, so they must not change.
    """
    style_choices = [{"label": name, "value": name} for name in STYLES]

    form_children = [
        dbc.Input(id="google-api-key", type="password",
                  placeholder="Enter Google AI API Key", className="mb-3"),
        dbc.Input(id="stability-api-key", type="password",
                  placeholder="Enter Stability AI API Key", className="mb-3"),
        dbc.Textarea(id="prompt", placeholder="Enter your prompt", className="mb-3"),
        dcc.Dropdown(id="style", options=style_choices,
                     placeholder="Select style", className="mb-3"),
        dbc.Textarea(id="negative-prompt", value=DEFAULT_NEGATIVE_PROMPT, className="mb-3"),
        dbc.Button("Generate Image", id="submit-btn", color="primary", className="mb-3"),
    ]
    result_children = [
        html.Img(id="image-output", className="img-fluid"),
        html.Div(id="enhanced-prompt-output", className="mt-3"),
    ]
    title = html.H1(
        "Stability AI SD3.5 Large Turbo Image Generator with Google Gemini Prompt Enhancement",
        className="my-4",
    )

    return dbc.Container(
        [
            title,
            dbc.Card([dbc.CardBody(form_children)], className="mb-4"),
            dbc.Card([dbc.CardBody(result_children)]),
        ],
        fluid=True,
    )


app.layout = _build_layout()
def enhance_prompt(google_api_key, prompt, style): | |
genai.configure(api_key=google_api_key) | |
model = genai.GenerativeModel("gemini-2.0-flash-lite") | |
enhanced_prompt_request = f""" | |
Task: Enhance the following prompt for image generation. | |
Style: {style} | |
Original prompt: '{prompt}' | |
Instructions: | |
1. Expand the prompt to be more detailed and vivid. | |
2. Incorporate elements of the specified style. | |
3. Maintain the original intent of the prompt. | |
4. Provide ONLY the enhanced prompt, without any explanations or options. | |
5. Keep the enhanced prompt concise, ideally under 100 words. | |
Enhanced prompt: | |
""" | |
try: | |
response = model.generate_content(enhanced_prompt_request) | |
enhanced_prompt = response.text.strip() | |
prefixes_to_remove = ["Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:"] | |
for prefix in prefixes_to_remove: | |
if enhanced_prompt.lower().startswith(prefix.lower()): | |
enhanced_prompt = enhanced_prompt[len(prefix):].strip() | |
logging.info(f"Enhanced prompt: {enhanced_prompt}") | |
return enhanced_prompt | |
except Exception as e: | |
logging.error(f"Error in enhance_prompt: {str(e)}") | |
raise | |
def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt): | |
url = "https://api.stability.ai/v2beta/stable-image/generate/sd3" | |
headers = { | |
"Accept": "image/*", | |
"Authorization": f"Bearer {stability_api_key}" | |
} | |
data = { | |
"prompt": f"{enhanced_prompt}, Style: {style}", | |
"negative_prompt": negative_prompt, | |
"model": "sd3.5-large-turbo", | |
"output_format": "jpeg", | |
"width": 1024, | |
"height": 1024, | |
"num_images": 1, | |
"steps": 4, # SD3.5 Large Turbo generates high-quality images in just 4 steps | |
} | |
try: | |
response = requests.post(url, headers=headers, files={"none": ''}, data=data) | |
response.raise_for_status() | |
logging.debug(f"Response headers: {response.headers}") | |
logging.debug(f"Response content type: {response.headers.get('content-type')}") | |
if response.headers.get('content-type').startswith('image/'): | |
return response.content | |
else: | |
error_message = response.text | |
logging.error(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}") | |
raise Exception(f"Unexpected content type: {response.headers.get('content-type')}. Response: {error_message}") | |
except requests.exceptions.RequestException as e: | |
logging.error(f"Request failed: {str(e)}") | |
raise Exception(f"Request failed: {str(e)}") | |
def process_and_generate(google_api_key, stability_api_key, prompt, style, negative_prompt):
    """Run prompt enhancement then image generation, turning failures into text.

    Returns:
        ``(image_bytes, enhanced_prompt)`` on success, or
        ``(None, error_message)`` if either step raised an ``Exception`` or
        the service returned an empty image body.
    """
    try:
        enhanced_prompt = enhance_prompt(google_api_key, prompt, style)
        image_bytes = generate_image(stability_api_key, enhanced_prompt, style, negative_prompt)
    except Exception as e:
        # UI boundary: report instead of crashing the callback, but keep the
        # traceback in the log (logging.error would drop it).
        logging.exception("Error in process_and_generate: %s", e)
        # Some exceptions have an empty message; fall back to the type name
        # so the user never sees a bare "Error: ".
        return None, str(e) or type(e).__name__
    if not image_bytes:
        # Without this, a caller that tests the bytes for truthiness would
        # display the enhanced prompt as if it were an error message.
        return None, "Stability AI returned an empty image"
    return image_bytes, enhanced_prompt
@app.callback(
    Output("image-output", "src"),
    Output("enhanced-prompt-output", "children"),
    Input("submit-btn", "n_clicks"),
    State("google-api-key", "value"),
    State("stability-api-key", "value"),
    State("prompt", "value"),
    State("style", "value"),
    State("negative-prompt", "value"),
    prevent_initial_call=True,
)
def update_output(n_clicks, google_api_key, stability_api_key, prompt, style, negative_prompt):
    """Handle a "Generate Image" click and fill the image and prompt outputs.

    Registered as the Dash callback for ``submit-btn``. Without the decorator,
    the button had no effect.

    Returns:
        ``(img_src, message)``: a base64 ``data:image/jpeg`` URI and the
        enhanced prompt on success, or ``("", "Error: ...")`` on failure.

    Raises:
        PreventUpdate: if the button has never been clicked.
    """
    if n_clicks is None:
        raise PreventUpdate

    # Fail fast with a readable message instead of calling the APIs with
    # None or blank inputs.
    required = (
        ("Google AI API key", google_api_key),
        ("Stability AI API key", stability_api_key),
        ("prompt", prompt),
    )
    missing = [label for label, value in required if not (value and str(value).strip())]
    if missing:
        return "", f"Error: Please enter the {', '.join(missing)}."

    # Call the pipeline exactly once and synchronously. The previous version
    # ran it in a thread, discarded that result, then ran it again, which
    # doubled every API call and the wait time.
    image_bytes, message = process_and_generate(
        google_api_key, stability_api_key, prompt, style, negative_prompt
    )
    if image_bytes:
        encoded_image = base64.b64encode(image_bytes).decode('ascii')
        return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {message}"
    return "", f"Error: {message}"
if __name__ == '__main__':
    import os

    # The server binds to all interfaces, so Flask/Werkzeug debug mode (which
    # includes an in-browser interactive debugger) must be opt-in, never the
    # default. Set DASH_DEBUG=1 (or true/yes/on) for local development.
    debug_mode = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    print("Starting the Dash application...")
    app.run(debug=debug_mode, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")