Spaces:
Paused
Paused
import base64 | |
import dash | |
from dash import dcc, html, Input, Output, State, callback_context | |
import dash_bootstrap_components as dbc | |
from dash.exceptions import PreventUpdate | |
import google.generativeai as genai | |
import requests | |
import logging | |
import threading | |
import time | |
import os | |
import flask | |
import uuid | |
# Root logger: DEBUG level, timestamped "time - LEVEL - message" records.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Style identifiers offered in the style dropdown and forwarded to the
# image API as `style_preset`.
STYLES = [
    "photographic",
    "3d-model",
    "analog-film",
    "anime",
    "cinematic",
    "comic-book",
    "digital-art",
    "enhance",
    "fantasy-art",
    "isometric",
    "line-art",
    "low-poly",
    "modeling-compound",
    "neon-punk",
    "origami",
    "pixel-art",
    "tile-texture",
]

# Negative prompt sent with every generation request. The leading and
# trailing newlines are part of the value.
DEFAULT_NEGATIVE_PROMPT = """
ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame,
extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature,
cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face,
plastic, cartoonish, artificial, fake, unnatural, blurry, smooth, lack of detail, low quality
"""
# Bootstrap theme for the dash-bootstrap-components widgets.
external_stylesheets = [dbc.themes.BOOTSTRAP]

# Dash is mounted on an explicit Flask server so Flask request/response
# facilities (cookies, flask.g) are available to the session helpers.
server = flask.Flask(__name__)
app = dash.Dash(
    __name__,
    server=server,
    external_stylesheets=external_stylesheets,
)
app.title = "ImaGen"

# In-memory, per-process session state keyed by session id.
# SESSION_DATA[sid] -> {'image', 'enhanced_prompt', 'status'}
# SESSION_LOCKS[sid] -> threading.Lock guarding that entry
# NOTE(review): nothing visible here evicts old entries — confirm whether
# unbounded growth is acceptable for this deployment.
SESSION_DATA = {}
SESSION_LOCKS = {}
def get_session_id():
    """Return the session id for the current request.

    Resolution order: the id already cached on ``flask.g``, then the
    ``session_id`` request cookie, then a freshly generated UUID4 string.
    Whatever is resolved is cached on ``flask.g.session_id`` before returning.
    """
    try:
        # Already resolved earlier in this request.
        return flask.g.session_id
    except AttributeError:
        pass

    resolved = flask.request.cookies.get("session_id") or str(uuid.uuid4())
    flask.g.session_id = resolved
    return resolved
def ensure_session_id():
    """Make sure the current request has a session id and backing state.

    Reads the ``session_id`` cookie; when it is missing, mints a UUID4 and
    records it on ``flask.g.set_cookie`` so ``set_session_cookie`` can send
    it back to the client. The id is cached on ``flask.g.session_id`` and
    the session's entries in ``SESSION_DATA`` / ``SESSION_LOCKS`` are
    created if absent.

    NOTE(review): no hook registration (e.g. a Flask ``before_request``
    decorator) is visible in this view — confirm this function is wired up.
    """
    session_id = flask.request.cookies.get("session_id")
    if not session_id:
        session_id = str(uuid.uuid4())
        flask.g.set_cookie = session_id
    flask.g.session_id = session_id

    # setdefault is a single dict operation, so two concurrent requests for
    # the same new session cannot each install (and then hold) a different
    # Lock object, which the previous check-then-insert allowed.
    SESSION_DATA.setdefault(session_id, {'image': None, 'enhanced_prompt': None, 'status': None})
    SESSION_LOCKS.setdefault(session_id, threading.Lock())
def set_session_cookie(response):
    """Attach the ``session_id`` cookie to ``response`` when one was minted.

    ``flask.g.set_cookie`` is only present when ``ensure_session_id``
    generated a new id for this request; otherwise the response is returned
    untouched. The cookie is HTTP-only with ``SameSite=Lax``.
    """
    if not hasattr(flask.g, "set_cookie"):
        return response

    response.set_cookie(
        "session_id",
        flask.g.set_cookie,
        httponly=True,
        samesite='Lax',
    )
    return response
def enhance_prompt(google_api_key, prompt, style):
    """Ask the Gemini model to rewrite ``prompt`` in the given ``style``.

    Configures ``genai`` with ``google_api_key``, sends an instruction
    template containing the style and original prompt, and returns the
    model's text with surrounding whitespace and any known lead-in phrase
    ("Enhanced prompt:", ...) removed.

    Any exception from the model call or response handling is logged and
    re-raised unchanged.
    """
    # Lead-in phrases the model sometimes emits before the actual prompt;
    # matched case-insensitively, in this order.
    lead_ins = ("Enhanced prompt:", "Here's the enhanced prompt:", "The enhanced prompt is:")

    genai.configure(api_key=google_api_key)
    gemini = genai.GenerativeModel("gemini-2.0-flash-lite")
    instruction = f"""
Task: Enhance the following prompt with details to match the specified style
Style: {style}
Original prompt: '{prompt}'
Instructions:
1. Expand the prompt to be more detailed, vivid, and realistic with camera used and the setting for that camera like ISO etc.
2. Incorporate elements of the specified style.
3. Add details that enhance the scene to the specified style
4. Emphasize natural lighting and enhance the realism of textures and colors based on the specified style.
5. Avoid terms that might result in artificial or cartoonish appearance unless specified by user.
6. Maintain the original intent of the prompt while significantly improving its descriptive quality with details.
7. Provide ONLY the enhanced prompt, without any explanations or options.
8. Keep the enhanced prompt concise, ideally under 100 words.
Enhanced prompt:
"""
    try:
        reply = gemini.generate_content(instruction)
        cleaned = reply.text.strip()
        for lead_in in lead_ins:
            if cleaned.lower().startswith(lead_in.lower()):
                cleaned = cleaned[len(lead_in):].strip()
        logging.info(f"Enhanced prompt: {cleaned}")
        return cleaned
    except Exception as e:
        logging.error(f"Error in enhance_prompt: {str(e)}")
        raise
def generate_image(stability_api_key, enhanced_prompt, style, negative_prompt, steps, aspect_ratio):
    """Request one JPEG from the Stability AI image endpoint.

    Args:
        stability_api_key: Bearer token for the Stability API.
        enhanced_prompt: Prompt text; style and quality keywords are appended.
        style: Style identifier, sent both in the prompt and as ``style_preset``.
        negative_prompt: Text describing what the image should avoid.
        steps: Value forwarded as the ``steps`` form field.
        aspect_ratio: Aspect ratio string such as ``"1:1"``.

    Returns:
        The raw image bytes from the response body.

    Raises:
        Exception: if the HTTP request fails (chained to the underlying
            ``requests`` error), if the response is not an image, or if the
            image payload is shorter than 1000 bytes.
    """
    url = "https://api.stability.ai/v2beta/stable-image/generate/core"
    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {stability_api_key}"
    }
    # NOTE(review): confirm against the API reference that this endpoint
    # accepts the "model", "num_images" and "steps" fields.
    data = {
        "prompt": f"{enhanced_prompt}, Style: {style}, highly detailed, high quality, descriptive, sharp focus, intricate details",
        "negative_prompt": negative_prompt,
        "model": "sd3.5-large-turbo",
        "output_format": "jpeg",
        "num_images": 1,
        "steps": steps,
        "style_preset": style,
        "aspect_ratio": aspect_ratio,
    }

    try:
        # The dummy `files` entry makes requests encode the body as
        # multipart/form-data instead of urlencoded.
        response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=60)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logging.error(f"Request failed: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"Request failed: {str(e)}") from e

    # A response without a Content-Type header must take the "unexpected
    # content type" path rather than crash on None.startswith.
    raw_content_type = response.headers.get('content-type')
    if not (raw_content_type or '').startswith('image/'):
        error_message = response.text
        logging.error(f"Unexpected content type: {raw_content_type}. Response: {error_message}")
        raise Exception(f"Unexpected content type: {raw_content_type}. Response: {error_message}")

    image_data = response.content
    if len(image_data) < 1000:
        # Too small to be a real JPEG; treat as a truncated response.
        raise Exception("Received incomplete image data")
    return image_data
def process_and_generate(google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status):
    """Enhance the prompt, then generate the image with up to three attempts.

    Progress is reported through ``set_status(message)``. Image generation
    is retried after a 2-second pause on any exception; the third failure
    is treated as final.

    Returns:
        ``(image_bytes, enhanced_prompt)`` on success, or
        ``(None, error_message)`` when enhancement or the final generation
        attempt fails. This function does not raise; failures are logged
        and reported via ``set_status``.
    """
    max_attempts = 3
    try:
        set_status("Enhancing prompt...")
        enhanced_prompt = enhance_prompt(google_api_key, prompt, style)

        set_status("Generating image...")
        for attempt_number in range(1, max_attempts + 1):
            try:
                image_bytes = generate_image(
                    stability_api_key,
                    enhanced_prompt,
                    style,
                    DEFAULT_NEGATIVE_PROMPT,
                    steps,
                    aspect_ratio,
                )
                set_status("Image generated successfully!")
                return image_bytes, enhanced_prompt
            except Exception:
                if attempt_number == max_attempts:
                    # Out of retries: hand the failure to the outer handler.
                    raise
                set_status(f"Attempt {attempt_number} failed. Retrying...")
                time.sleep(2)
    except Exception as e:
        logging.error(f"Error in process_and_generate: {str(e)}")
        set_status(f"Error: {str(e)}")
        return None, str(e)
def _build_layout():
    """Assemble the page: a title row above a controls column and an output column."""
    style_options = [
        {"label": s.replace("-", " ").title(), "value": s} for s in STYLES
    ]
    aspect_ratio_options = [
        {"label": ar, "value": ar} for ar in
        ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
    ]

    # --- Title -----------------------------------------------------------
    title_row = dbc.Row([
        dbc.Col([
            html.H1("ImaGen", className="text-center mb-4")
        ], width=12)
    ])

    # --- Left column: prompt and generation settings ---------------------
    prompt_box = dbc.Textarea(
        id="prompt",
        placeholder="Tell me what image you want me to build.",
        className="mb-3",
        style={"height": "120px", "whiteSpace": "pre-wrap", "wordWrap": "break-word"}
    )
    style_picker = dcc.Dropdown(
        id="style",
        options=style_options,
        value="photographic",
        placeholder="Select style",
        className="mb-3"
    )
    generation_settings = html.Div([
        dbc.Label("Aspect Ratio"),
        dcc.Dropdown(
            id="aspect-ratio",
            options=aspect_ratio_options,
            value="1:1",
            className="mb-3"
        ),
        dbc.Label("Steps"),
        dcc.Slider(
            id="steps",
            min=4,
            max=50,
            step=1,
            value=30,
            marks={4: '4', 25: '25', 50: '50'},
            className="mb-3"
        ),
    ])
    submit_button = dbc.Button(
        "Generate Image",
        id="submit-btn",
        color="primary",
        className="mt-2 mb-2",
        style={"width": "100%"}
    )
    controls_column = dbc.Col([
        dbc.Card([
            dbc.CardBody([
                prompt_box,
                style_picker,
                generation_settings,
                submit_button
            ])
        ])
    ], width=4, style={"minWidth": "300px", "maxWidth": "420px", "flex": "0 0 30%"})

    # --- Right column: download button and generated output --------------
    download_button = dbc.Button(
        "Download Image",
        id="download-btn",
        color="secondary",
        className="mb-3",
        disabled=True,
        style={"width": "100%"}
    )
    results_area = dcc.Loading(
        id="loading",
        type="default",
        children=[
            html.Div(id="status-message", className="mb-3"),
            html.Img(id="image-output", className="img-fluid mb-3"),
            html.Div(id="enhanced-prompt-output", className="mb-3"),
            dcc.Download(id="download-image")
        ],
        style={"display": "block", "margin": "auto"}
    )
    output_column = dbc.Col([
        dbc.Card([
            dbc.CardBody([
                download_button,
                results_area,
            ])
        ])
    ], width=8, style={"flex": "0 0 70%"})

    return dbc.Container([
        title_row,
        dbc.Row([
            controls_column,
            output_column
        ], align="start")
    ], fluid=True)


app.layout = _build_layout()
def update_output(n_clicks, prompt, style, steps, aspect_ratio):
    """Run one prompt-enhancement + image-generation cycle for this session.

    The work runs once on a worker thread that is given 90 seconds. The
    result is stored in the session's ``SESSION_DATA`` entry (base64 image,
    enhanced prompt, status) for later download.

    Args:
        n_clicks: Click count of the submit button; ``None`` means no click.
        prompt: User prompt text.
        style: Selected style identifier.
        steps: Selected step count.
        aspect_ratio: Selected aspect ratio string.

    Returns:
        A 4-tuple ``(image_src, enhanced_prompt_text, status_text,
        download_disabled)``. On success ``image_src`` is a
        ``data:image/jpeg;base64,...`` URI and ``download_disabled`` is
        ``False``; on any failure ``image_src`` is ``""`` and
        ``download_disabled`` is ``True``.
        NOTE(review): the callback decorator is not visible in this view —
        confirm the tuple order matches its declared outputs.

    Raises:
        PreventUpdate: when ``n_clicks`` is ``None``.
    """
    if n_clicks is None:
        raise PreventUpdate
    if not prompt or not prompt.strip():
        # Nothing to enhance; avoid spending API calls on an empty prompt.
        return "", "Error: Please enter a prompt", "Prompt missing", True

    # Same resolution as the rest of the file (flask.g, then cookie, then a
    # new id), so state is stored under the id the client will present later.
    session_id = get_session_id()
    lock = SESSION_LOCKS.setdefault(session_id, threading.Lock())
    session_data = SESSION_DATA.setdefault(session_id, {'image': None, 'enhanced_prompt': None, 'status': None})

    google_api_key = os.getenv('GOOGLE_API_KEY')
    stability_api_key = os.getenv('STABILITY_API_KEY')
    if not google_api_key or not stability_api_key:
        return "", "Error: API keys not found in environment variables", "API keys missing", True

    status = {"message": "Starting process..."}
    outcome = {}                     # filled in by the worker thread
    abandoned = threading.Event()    # set once the request has given up on the worker

    def set_status(message):
        status["message"] = message
        with lock:
            # A worker that outlives the timeout must not overwrite the
            # "timed out" state already recorded for the session.
            if not abandoned.is_set():
                session_data['status'] = message

    def store(image, enhanced, message=None):
        # The lock is held only for these short writes, never across the
        # slow API calls, so the timeout path below can always acquire it.
        with lock:
            session_data['image'] = image
            session_data['enhanced_prompt'] = enhanced
            if message is not None:
                session_data['status'] = message

    def run_process():
        try:
            outcome["result"] = process_and_generate(
                google_api_key, stability_api_key, prompt, style, steps, aspect_ratio, set_status
            )
        except Exception as e:
            outcome["error"] = e

    try:
        # Daemon thread: an abandoned worker must not keep the process alive.
        worker = threading.Thread(target=run_process, daemon=True)
        worker.start()
        worker.join(timeout=90)

        if worker.is_alive():
            with lock:
                abandoned.set()
                session_data['status'] = "Process timed out"
                session_data['image'] = None
                session_data['enhanced_prompt'] = None
            return "", "Error: Image generation timed out", "Process timed out", True

        if "error" in outcome:
            raise outcome["error"]

        # Use the worker's result; do not run the generation a second time.
        image_bytes, enhanced_prompt = outcome["result"]
        if image_bytes:
            encoded_image = base64.b64encode(image_bytes).decode('ascii')
            store(encoded_image, enhanced_prompt)
            return f"data:image/jpeg;base64,{encoded_image}", f"Enhanced Prompt: {enhanced_prompt}", status["message"], False

        # On failure the second element carries the error message.
        store(None, None)
        return "", f"Error: {enhanced_prompt}", status["message"], True
    except Exception as e:
        logging.error(f"Unexpected error in update_output: {str(e)}")
        store(None, None, "An unexpected error occurred")
        return "", f"Unexpected error: {str(e)}", "An unexpected error occurred", True
def download_image(n_clicks, image_src):
    """Send the session's most recently generated image as a file download.

    Args:
        n_clicks: Click count of the download button; ``None`` means no click.
        image_src: Unused; the image is read from the server-side session
            store rather than from the page.

    Returns:
        The ``dcc.send_bytes`` payload for ``generated_image.jpeg``.

    Raises:
        PreventUpdate: when there was no click, the session is unknown, or
            the session holds no image.
    """
    if n_clicks is None:
        raise PreventUpdate

    session_id = get_session_id()
    # Look up without inserting: a request with no (or an unknown) session
    # id must not create a shared placeholder entry such as a None key.
    session_data = SESSION_DATA.get(session_id)
    lock = SESSION_LOCKS.get(session_id)
    if session_data is None or lock is None:
        raise PreventUpdate

    with lock:
        encoded_image = session_data.get('image')
    if not encoded_image:
        raise PreventUpdate

    image_bytes = base64.b64decode(encoded_image)
    return dcc.send_bytes(image_bytes, "generated_image.jpeg")
if __name__ == '__main__':
    # The server binds to all interfaces, so the interactive debugger and
    # dev tools must not be on by default; enable them explicitly with
    # DASH_DEBUG=1 (or true/yes/on) for local development.
    debug_enabled = os.getenv("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    print("Starting the Dash application...")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860, threaded=True)
    print("Dash application has finished running.")