Spaces:
Running
Running
import os | |
import gradio as gr | |
import json | |
import time | |
import traceback | |
import io | |
import base64 | |
from PIL import Image, ImageEnhance, ImageFilter | |
# --- Environment Configuration --- | |
# Gemini API key taken from the environment; an empty string means "unset".
GEMINI_KEY = os.getenv("GEMINI_KEY", "")
# Port read from the PORT environment variable, falling back to 7860.
DEFAULT_PORT = int(os.getenv("PORT", "7860"))
API_TIMEOUT = 120  # seconds
# --- Style Template Optimization --- | |
# Instruction text shared by every style; style-specific rules are appended
# to it on a new line.
BASE_TEMPLATE = (
    "Describe this design as a concise Flux 1.1 Pro prompt focusing on:\n"
    "- Key visual elements\n"
    "- Technical specifications\n"
    "- Style consistency\n"
    "- Functional requirements"
)

# Maps a style name to the full instruction text for that style.
# "General" uses the base template unchanged.
STYLE_INSTRUCTIONS = {
    "General": BASE_TEMPLATE,
    **{
        style_name: BASE_TEMPLATE + "\n" + style_rules
        for style_name, style_rules in (
            ("Realistic", "PHOTOREALISTIC RULES: Use photography terms, texture details, accurate lighting"),
            ("Kawaii", "KAWAII RULES: Rounded shapes, pastel colors, cute expressions"),
            ("Vector", "VECTOR RULES: Clean lines, geometric shapes, B&W gradients"),
            ("Silhouette", "SILHOUETTE RULES: High contrast, minimal details, strong outlines"),
        )
    },
}
# --- Flux Configuration --- | |
# Option lists offered for the Flux output settings.
FLUX_SPECS = dict(
    aspect_ratios=["1:1", "16:9", "4:3", "9:16"],
    formats=["SVG", "PNG", "PDF"],
    color_modes=["B&W", "CMYK", "RGB"],
    dpi_options=[72, 150, 300, 600],
)
# --- Image Processing Pipeline --- | |
def preprocess_image(img):
    """Convert an uploaded image to RGB and apply mild enhancement.

    Args:
        img: Either a PIL image or a string path to an image file.

    Returns:
        A PIL image in RGB mode with contrast boosted by a factor of 1.2
        and a sharpen filter applied.

    Raises:
        ValueError: If the image cannot be opened or processed. The original
            exception is attached as ``__cause__``.
    """
    try:
        if isinstance(img, str):  # Handle file paths
            # Open inside a context manager so the file handle is released
            # as soon as the pixel data has been converted.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        else:
            img = img.convert("RGB")
        img = ImageEnhance.Contrast(img).enhance(1.2)
        img = img.filter(ImageFilter.SHARPEN)
        return img
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise ValueError(f"Image processing error: {str(e)}") from e
# --- Core Generation Engine --- | |
def generate_prompt(image, api_key, style, creativity, neg_prompt, aspect, color_mode, dpi):
    """Ask Gemini to describe an image as a Flux prompt.

    Args:
        image: The uploaded image (PIL image or file path); falsy means
            nothing was uploaded.
        api_key: Gemini API key; when empty, GEMINI_KEY is used instead.
        style: Key into STYLE_INSTRUCTIONS selecting the instruction text.
        creativity: Passed to the model as the generation temperature.
        neg_prompt: Text inserted after "AVOID:" in the instruction.
        aspect: Aspect ratio label inserted into the instruction.
        color_mode: Color mode label inserted into the instruction.
        dpi: DPI value inserted into the instruction.

    Returns:
        On success, a dict with keys "prompt" (model text), "validation"
        (a fixed placeholder report) and "stats" (approximate token counts).
        On any failure, a dict with a single "error" key holding a message.
        This function does not raise; unexpected exceptions are printed and
        reported through the "error" key.
    """
    try:
        # Validate inputs
        if not image:
            return {"error": "Please upload an image"}
        api_key = api_key or GEMINI_KEY
        if not api_key:
            return {"error": "API key required - set in env (GEMINI_KEY) or input field"}

        # Import and configure Gemini only when needed
        try:
            import google.generativeai as genai
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel("gemini-1.5-pro")
        except ImportError:
            return {"error": "Failed to import google.generativeai. Install with: pip install google-generativeai"}
        except Exception as e:
            if "authentication" in str(e).lower():
                return {"error": "Invalid API key or authentication error"}
            return {"error": f"API initialization error: {str(e)}"}

        # Encode the preprocessed image as base64 PNG text
        img = preprocess_image(image)
        img_bytes = io.BytesIO()
        img.save(img_bytes, format="PNG")
        img_b64 = base64.b64encode(img_bytes.getvalue()).decode()

        # Build instruction
        instruction = f"{STYLE_INSTRUCTIONS[style]}\nAVOID: {neg_prompt}\n"
        instruction += f"ASPECT: {aspect}, COLORS: {color_mode}, DPI: {dpi}\n"

        # Generate prompt; API_TIMEOUT bounds how long the request may take
        # so a stalled call cannot hang the UI indefinitely.
        try:
            response = model.generate_content(
                contents=[instruction, {"mime_type": "image/png", "data": img_b64}],
                generation_config={"temperature": creativity},
                request_options={"timeout": API_TIMEOUT},
            )
            raw_prompt = response.text
        except Exception as e:
            return {"error": f"Generation failed: {str(e)}"}

        # Simple quality validation (static placeholder values)
        validation = {"score": 8, "issues": [], "suggestions": []}

        # Token tracking: rough estimates, not counts reported by the API
        input_tokens = len(img_b64) // 4  # Approximate base64 token count
        output_tokens = len(raw_prompt.split())
        return {
            "prompt": raw_prompt,
            "validation": validation,
            "stats": {"input": input_tokens, "output": output_tokens},
        }
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}
# --- UI Response Formatting --- | |
def format_generation_response(result):
    """Unpack a generate_prompt result dict into three UI output values.

    Returns a tuple ``(prompt_text, validation, stats)``. When the dict has
    an "error" key, the error message takes the prompt slot and the other
    two values are None.
    """
    try:
        message = result["error"]
    except KeyError:
        prompt_text = result.get("prompt", "")
        report = result.get("validation", {})
        usage = result.get("stats", {})
        return prompt_text, report, usage
    return message, None, None
# Modern copy function using Gradio's JavaScript API | |
def copy_text(text):
    """Build the UI updates shown after the copy button is pressed.

    Args:
        text: Current contents of the prompt textbox; None is treated as
            an empty string.

    Returns:
        A 3-tuple: an update that keeps the textbox value, a status message
        quoting the first 20 characters, and an update that sets the button
        variant to "secondary".
    """
    # An empty textbox may arrive as None; slicing None would raise TypeError.
    text = text or ""
    # gr.update() is used instead of the per-component Button.update()
    # classmethod, which is not available in Gradio 4.
    return (
        gr.update(value=text),
        f"โ Copied: '{text[:20]}...'",
        gr.update(variant="secondary"),
    )
# --- Main Interface --- | |
def build_interface():
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    def _run_generation(*args):
        # generate_prompt returns a dict; the click handler must return one
        # value per output component, so unpack it into a 3-tuple.
        return format_generation_response(generate_prompt(*args))

    with gr.Blocks(title="Flux Pro Generator", theme="soft") as app:
        # Header
        gr.Markdown("# ๐จ Flux Pro Prompt Generator")
        gr.Markdown("Generate optimized design prompts from images using Google's Gemini")

        # Security Section
        # The field starts empty on purpose: pre-filling it with GEMINI_KEY
        # would send the server's secret to every browser. generate_prompt
        # falls back to GEMINI_KEY when this field is left blank.
        api_key = gr.Textbox(
            label="๐ Gemini API Key",
            value="",
            type="password",
            info="Set GEMINI_KEY environment variable for production",
        )

        # Main Workflow
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(
                    label="๐ผ๏ธ Upload Design",
                    type="pil",
                    sources=["upload"],
                    interactive=True,
                )
                style = gr.Dropdown(
                    list(STYLE_INSTRUCTIONS.keys()),
                    value="General",
                    label="๐จ Target Style",
                )
                # Advanced Settings
                with gr.Accordion("โ๏ธ Advanced Settings", open=False):
                    creativity = gr.Slider(0.0, 1.0, 0.7, label="Creativity Level")
                    neg_prompt = gr.Textbox(label="๐ซ Negative Prompts", placeholder="What to avoid")
                    aspect = gr.Dropdown(FLUX_SPECS["aspect_ratios"], value="1:1", label="Aspect Ratio")
                    color_mode = gr.Dropdown(FLUX_SPECS["color_modes"], value="RGB", label="Color Mode")
                    dpi = gr.Dropdown([str(d) for d in FLUX_SPECS["dpi_options"]], value="300", label="Output DPI")
                gen_btn = gr.Button("โจ Generate Prompt", variant="primary")
            with gr.Column(scale=2):
                prompt_output = gr.Textbox(
                    label="๐ Optimized Prompt",
                    lines=8,
                    interactive=True,
                    show_copy_button=True,  # Modern Gradio has built-in copy button
                )
                status_msg = gr.Textbox(label="Status", visible=True)
                with gr.Row():
                    copy_btn = gr.Button("๐ Copy to Clipboard", variant="secondary")
                quality_report = gr.JSON(
                    label="๐ Quality Report",
                    visible=True,
                )
                token_stats = gr.JSON(
                    label="๐งฎ Token Usage",
                    visible=True,
                )

        # Event Handling
        gen_btn.click(
            fn=_run_generation,
            inputs=[
                img_input, api_key, style, creativity,
                neg_prompt, aspect, color_mode, dpi,
            ],
            outputs=[prompt_output, quality_report, token_stats],
            api_name="generate",
        ).then(
            fn=lambda: gr.update(value="Generation complete!"),
            outputs=status_msg,
        )

        # Modern copy implementation: the JS snippet writes to the clipboard
        # in the browser, then copy_text updates the status widgets.
        copy_btn.click(
            fn=copy_text,
            inputs=prompt_output,
            outputs=[prompt_output, status_msg, copy_btn],
            js="(text) => { if(text) { navigator.clipboard.writeText(text); } return [text]; }",
        )
    return app
# --- Production Launch --- | |
if __name__ == "__main__":
    app = build_interface()
    # Bind to DEFAULT_PORT so the PORT environment variable read at module
    # level is actually honored by the server.
    app.launch(server_port=DEFAULT_PORT)