Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import json | |
import time | |
import traceback | |
import io | |
import base64 | |
from PIL import Image, ImageEnhance, ImageFilter | |
# --- Environment Configuration ---
# Gemini API key read from the environment; empty string when the variable is unset.
GEMINI_KEY = os.getenv("GEMINI_KEY", "")
# Port handed to the Gradio server; PORT from the environment, otherwise 7860.
DEFAULT_PORT = int(os.getenv("PORT", 7860))
# Timeout value, in seconds.
API_TIMEOUT = 120
# --- Style Template Optimization ---
# Instruction shared by every style; style-specific rules are appended on a new line.
BASE_TEMPLATE = "\n".join(
    [
        "Describe this design as a concise Flux 1.1 Pro prompt focusing on:",
        "- Key visual elements",
        "- Technical specifications",
        "- Style consistency",
        "- Functional requirements",
    ]
)

# Extra rule line for each non-general style, in the order shown in the UI.
_STYLE_RULES = {
    "Realistic": "PHOTOREALISTIC RULES: Use photography terms, texture details, accurate lighting",
    "Kawaii": "KAWAII RULES: Rounded shapes, pastel colors, cute expressions",
    "Vector": "VECTOR RULES: Clean lines, geometric shapes, B&W gradients",
    "Silhouette": "SILHOUETTE RULES: High contrast, minimal details, strong outlines",
}

# Style name -> full instruction text sent to the model.
STYLE_INSTRUCTIONS = {"General": BASE_TEMPLATE}
for _style_name, _rule_line in _STYLE_RULES.items():
    STYLE_INSTRUCTIONS[_style_name] = f"{BASE_TEMPLATE}\n{_rule_line}"
# --- Flux Configuration ---
# Option lists offered for the output settings.
FLUX_SPECS = dict(
    aspect_ratios=["1:1", "16:9", "4:3", "9:16"],
    formats=["SVG", "PNG", "PDF"],
    color_modes=["B&W", "CMYK", "RGB"],
    dpi_options=[72, 150, 300, 600],
)
# --- Image Processing Pipeline --- | |
def preprocess_image(img):
    """Convert an uploaded image to RGB and apply light enhancement.

    Args:
        img: A PIL image, or a filesystem path (``str`` or ``os.PathLike``)
            to an image file.

    Returns:
        A new RGB PIL image with contrast scaled by 1.2 and the SHARPEN
        filter applied.

    Raises:
        ValueError: If the image cannot be opened or processed. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        if isinstance(img, (str, os.PathLike)):
            # Image.open keeps the file open until the pixel data is read.
            # convert() loads the data and returns a separate image, so the
            # file can be closed as soon as the conversion is done.
            with Image.open(img) as opened:
                img = opened.convert("RGB")
        else:
            img = img.convert("RGB")
        img = ImageEnhance.Contrast(img).enhance(1.2)
        return img.filter(ImageFilter.SHARPEN)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Image processing error: {e}") from e
# --- Core Generation Engine --- | |
def generate_prompt(image, api_key, style, creativity, neg_prompt, aspect, color_mode, dpi):
    """Ask Gemini to describe an uploaded design as a Flux prompt.

    Args:
        image: PIL image or image file path; ``None`` (or an empty path)
            when nothing was uploaded.
        api_key: Gemini API key entered by the user. When blank, the
            module-level ``GEMINI_KEY`` is used instead.
        style: Key of ``STYLE_INSTRUCTIONS``; an unknown or missing style
            falls back to the "General" instruction.
        creativity: Passed to the model as the sampling temperature.
        neg_prompt: Things the prompt should avoid; added as an ``AVOID:``
            line only when non-blank.
        aspect: Aspect ratio text echoed into the instruction.
        color_mode: Color mode text echoed into the instruction.
        dpi: DPI value echoed into the instruction.

    Returns:
        On success, a 3-tuple ``(prompt_text, validation, token_stats)``
        where ``validation`` is a fixed placeholder report and
        ``token_stats`` holds rough ``"input"`` / ``"output"`` estimates.
        On any failure, a dict ``{"error": message}``.
    """
    # NOTE(review): the leading glyph in the user-facing messages below looks
    # like a mis-encoded emoji; confirm against the original file encoding.
    try:
        # Explicit check instead of truthiness: only "nothing uploaded"
        # should be reported here.
        if image is None or (isinstance(image, str) and not image):
            return {"error": "โ Please upload an image"}

        api_key = (api_key or "").strip() or GEMINI_KEY
        if not api_key:
            return {"error": "โ API key required - set GEMINI_KEY environment variable or use input field"}

        # Lazy import for Gemini so the app can start without the SDK installed
        try:
            import google.generativeai as genai
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel("gemini-1.5-pro")
        except ImportError:
            return {"error": "โ Install Gemini SDK: pip install google-generativeai"}
        except Exception as e:
            if "authentication" in str(e).lower():
                return {"error": "โ Invalid API key or authentication error"}
            return {"error": f"โ API initialization error: {str(e)}"}

        # Image processing: enhance, then serialise to PNG bytes
        img = preprocess_image(image)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

        # Prompt instruction
        template = STYLE_INSTRUCTIONS.get(style) or STYLE_INSTRUCTIONS["General"]
        instruction_lines = [template]
        if neg_prompt and neg_prompt.strip():
            instruction_lines.append(f"AVOID: {neg_prompt}")
        instruction_lines.append(f"ASPECT: {aspect}, COLORS: {color_mode}, DPI: {dpi}")
        instruction = "\n".join(instruction_lines) + "\n"

        # Gemini generation; the inline image part carries the raw PNG bytes
        try:
            response = model.generate_content(
                contents=[instruction, {"mime_type": "image/png", "data": png_bytes}],
                generation_config={"temperature": creativity},
            )
            raw_prompt = response.text
        except Exception as e:
            return {"error": f"โ Prompt generation failed: {str(e)}"}

        # Basic validation (static placeholder report)
        validation = {"score": 8, "issues": [], "suggestions": []}
        # Approximate: a quarter of the image's base64-encoded length,
        # which equals ceil(len(png_bytes) / 3).
        input_tokens = (len(png_bytes) + 2) // 3
        output_tokens = len(raw_prompt.split())
        return raw_prompt, validation, {
            "input": input_tokens,
            "output": output_tokens,
        }
    except Exception as e:
        # Last-resort boundary: log the traceback and report the message to the UI.
        traceback.print_exc()
        return {"error": str(e)}
# --- Response Formatter --- | |
def format_generation_response(result):
    """Format the response from generate_prompt for the UI.

    Args:
        result: One of
            - a dict with an ``"error"`` key,
            - a dict with ``"prompt"`` / ``"validation"`` / ``"stats"`` keys,
            - a ``(prompt, validation, stats)`` tuple or list, or
            - ``None``.

    Returns:
        A ``(prompt_text, validation, stats)`` tuple. On error the message is
        placed in the prompt slot and both dicts are empty.
    """
    if isinstance(result, (tuple, list)):
        # Positional form: pad so a short sequence cannot raise on unpacking.
        prompt, validation, stats = (list(result) + [None, None, None])[:3]
        return prompt or "", validation or {}, stats or {}
    if not result:
        return "", {}, {}
    if "error" in result:
        return result["error"], {}, {}
    return result.get("prompt", ""), result.get("validation", {}), result.get("stats", {})
def update_status(result):
    """Return the status-line text for a generation result mapping.

    A result containing a ``"prompt"`` key is reported as a success;
    otherwise the ``"error"`` entry is shown, with a generic fallback when
    that key is absent too.
    """
    if "prompt" not in result:
        return result.get("error", "โ Unknown error")
    return "โ Prompt generated successfully!"
# --- Main Interface --- | |
def build_interface():
    """Build the Gradio Blocks UI for the prompt generator.

    Lays out the API-key field, the image/style/advanced inputs and the
    prompt/status/report outputs, and wires the generate button to a single
    handler that fills all four outputs.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    # NOTE(review): the leading glyphs in the labels below look like
    # mis-encoded emoji; confirm against the original file encoding.

    def _run_generation(image, key, style_name, temperature, avoid, ratio, colors, dpi_value):
        """Run generate_prompt and fan its result out to the output widgets."""
        result = generate_prompt(image, key, style_name, temperature, avoid, ratio, colors, dpi_value)
        if isinstance(result, (tuple, list)):
            # Success arrives positionally; give it the keyed shape that the
            # formatter and the status helper read.
            prompt, validation, stats = result
            result = {"prompt": prompt, "validation": validation, "stats": stats}
        prompt_text, report, stats = format_generation_response(result)
        return prompt_text, report, stats, update_status(result)

    with gr.Blocks(title="Flux Pro Generator") as app:
        # Header
        gr.Markdown("# ๐จ Flux Pro Prompt Generator")
        gr.Markdown("Generate optimized design prompts from images using Google's Gemini")

        # API Key
        # Deliberately not pre-filled with GEMINI_KEY: a component's initial
        # value is sent to every visitor's browser. generate_prompt falls
        # back to the environment key when this field is left blank.
        api_key = gr.Textbox(
            label="๐ Gemini API Key",
            type="password",
            info="Set GEMINI_KEY environment variable for production",
        )

        # Inputs
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(label="๐ผ๏ธ Upload Design", type="pil", sources=["upload"], interactive=True)
                style = gr.Dropdown(list(STYLE_INSTRUCTIONS.keys()), value="General", label="๐จ Target Style")
                with gr.Accordion("โ๏ธ Advanced Settings", open=False):
                    creativity = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Creativity Level")
                    neg_prompt = gr.Textbox(label="๐ซ Negative Prompts", placeholder="What to avoid")
                    aspect = gr.Dropdown(FLUX_SPECS["aspect_ratios"], value="1:1", label="Aspect Ratio")
                    color_mode = gr.Dropdown(FLUX_SPECS["color_modes"], value="RGB", label="Color Mode")
                    dpi = gr.Dropdown([str(d) for d in FLUX_SPECS["dpi_options"]], value="300", label="Output DPI")
                gen_btn = gr.Button("โจ Generate Prompt", variant="primary")
            with gr.Column(scale=2):
                prompt_output = gr.Textbox(label="๐ Optimized Prompt", lines=8, interactive=True, show_copy_button=True)
                status_msg = gr.Textbox(label="Status", visible=True)
                quality_report = gr.JSON(label="๐ Quality Report", visible=True)
                token_stats = gr.JSON(label="๐งฎ Token Usage", visible=True)

        # Event bindings: one handler produces every output, so nothing has
        # to be passed between chained steps.
        gen_btn.click(
            fn=_run_generation,
            inputs=[img_input, api_key, style, creativity, neg_prompt, aspect, color_mode, dpi],
            outputs=[prompt_output, quality_report, token_stats, status_msg],
        )
    return app
# --- Launch App --- | |
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on the configured port.
    build_interface().launch(server_port=DEFAULT_PORT)