Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import requests | |
import json | |
import time | |
import traceback | |
import io | |
import base64 | |
from PIL import Image, ImageEnhance, ImageFilter | |
# Conditional imports | |
try: | |
import google.generativeai as genai | |
GENAI_AVAILABLE = True | |
except ImportError: | |
GENAI_AVAILABLE = False | |
print("Warning: google-generativeai not installed, will attempt on-demand import") | |
try: | |
import pyperclip | |
except ImportError: | |
pyperclip = None | |
# --- Environment Configuration ---
# Gemini API key read from the environment; "" when the variable is unset.
GEMINI_KEY: str = os.getenv("GEMINI_KEY", "")
# Port the Gradio server binds to; overridable through the PORT variable.
DEFAULT_PORT: int = int(os.getenv("PORT", "7860"))
# Request timeout, in seconds.
API_TIMEOUT: int = 30
# --- Style Template Optimization ---
# Shared preamble for every style, one requirement per line (no trailing newline).
BASE_TEMPLATE = "\n".join(
    [
        "Describe this design as a concise Flux 1.1 Pro prompt focusing on:",
        "- Key visual elements",
        "- Technical specifications",
        "- Style consistency",
        "- Functional requirements",
    ]
)

# Extra rule line appended (after a newline) to the base template, per style.
_STYLE_RULES = {
    "Realistic": "PHOTOREALISTIC RULES: Use photography terms, texture details, accurate lighting",
    "Kawaii": "KAWAII RULES: Rounded shapes, pastel colors, cute expressions",
    "Vector": "VECTOR RULES: Clean lines, geometric shapes, B&W gradients",
    "Silhouette": "SILHOUETTE RULES: High contrast, minimal details, strong outlines",
}

# Style name -> full instruction text. "General" is the bare base template.
# Insertion order matters: the style dropdown lists the keys in this order.
STYLE_INSTRUCTIONS = {
    "General": BASE_TEMPLATE,
    **{name: f"{BASE_TEMPLATE}\n{rule}" for name, rule in _STYLE_RULES.items()},
}
# --- Flux Configuration ---
# Option lists for output specifications. aspect_ratios, color_modes and
# dpi_options feed the "Advanced Settings" dropdowns.
# NOTE(review): "formats" is not referenced anywhere else in the visible code.
FLUX_SPECS = dict(
    aspect_ratios=["1:1", "16:9", "4:3", "9:16"],
    formats=["SVG", "PNG", "PDF"],
    color_modes=["B&W", "CMYK", "RGB"],
    dpi_options=[72, 150, 300, 600],
)
# --- Quality Control System ---
class QualityValidator:
    """Grade a generated Flux prompt with the model and parse its JSON report.

    ``validate`` returns a dict with a ``score`` on a 1-10 scale plus
    ``issues`` and ``suggestions`` lists. A score of 0 is reserved for
    "validation failed", so it can never be mistaken for a real grade.
    """

    # Ask for a 1-10 score and show a literal JSON shape so the reply parses.
    VALIDATION_TEMPLATE = (
        "Analyze this Flux prompt:\n"
        "1. Score style adherence (1-10)\n"
        "2. List technical issues\n"
        "3. Suggest improvements\n"
        'Respond ONLY as JSON: {"score": <integer 1-10>, "issues": [], "suggestions": []}'
    )

    @staticmethod
    def _failure_report():
        """Return a fresh report signalling that validation could not run."""
        return {"score": 0, "issues": ["Validation failed"], "suggestions": []}

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (e.g. ```json ... ```), if any."""
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line, which may carry a language tag.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    @classmethod
    def validate(cls, prompt, model):
        """Ask ``model`` to grade ``prompt`` and return the parsed report.

        Args:
            prompt: The generated Flux prompt text to grade.
            model: Object exposing ``generate_content(contents)`` whose result
                has a ``.text`` attribute (e.g. a Gemini ``GenerativeModel``).

        Returns:
            The model's JSON object as a dict, with ``issues`` and
            ``suggestions`` defaulted to empty lists when missing. On any
            failure (API error, non-JSON reply, or JSON that is not an
            object), returns
            ``{"score": 0, "issues": ["Validation failed"], "suggestions": []}``.
        """
        try:
            response = model.generate_content([cls.VALIDATION_TEMPLATE, prompt])
            report = json.loads(cls._strip_code_fence(response.text))
        except Exception as e:  # best-effort: grading must never break generation
            print(f"Validation error: {str(e)}")
            return cls._failure_report()
        if not isinstance(report, dict):
            print(f"Validation error: expected a JSON object, got {type(report).__name__}")
            return cls._failure_report()
        report.setdefault("issues", [])
        report.setdefault("suggestions", [])
        return report
# --- Lazy API Initialization ---
def init_genai_api(api_key):
    """Configure the Gemini client and return a connection-checked model.

    If the module-level ``google.generativeai`` import failed, the import is
    retried here. A tiny probe request is then sent, so that bad keys or
    connectivity problems surface before the real generation call.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.

    Returns:
        A ``genai.GenerativeModel`` for ``"gemini-1.5-pro"``.

    Raises:
        ValueError: If the SDK cannot be imported, or the probe request fails
            (authentication problem, timeout, or any other API error).
    """
    global genai, GENAI_AVAILABLE
    if not GENAI_AVAILABLE:
        try:
            import google.generativeai as genai  # rebinds the module-level name
        except ImportError as e:
            raise ValueError(
                "Failed to import google.generativeai. Install with: pip install google-generativeai"
            ) from e
        GENAI_AVAILABLE = True  # don't retry the import on later calls

    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel("gemini-1.5-pro")
        # Minimal probe: one output token is enough to prove the key and the
        # connection work, and it keeps this per-call check cheap.
        model.generate_content(
            "test",
            generation_config={"max_output_tokens": 1},
            request_options={"timeout": 5},
        )
    except Exception as e:
        message = str(e).lower()
        # Invalid keys are not necessarily reported with the word
        # "authentication" (e.g. "API key not valid" or "permission denied"
        # are possible), so several phrasings are matched.
        auth_markers = ("authentication", "unauthenticated", "api key", "api_key", "permission denied")
        timeout_markers = ("timeout", "timed out", "deadline")
        if any(marker in message for marker in auth_markers):
            raise ValueError("Invalid API key or authentication error") from e
        if any(marker in message for marker in timeout_markers):
            raise ValueError("API connection timeout - check your internet connection") from e
        raise ValueError(f"API initialization error: {str(e)}") from e
    return model
# --- Image Processing Pipeline ---
def preprocess_image(img):
    """Load (when given a path), normalize and enhance an uploaded image.

    Args:
        img: A PIL image, or a filesystem path (``str`` or ``os.PathLike``)
            to an image file.

    Returns:
        A new RGB PIL image with contrast boosted by a factor of 1.2 and a
        sharpen filter applied.

    Raises:
        ValueError: If ``img`` is ``None`` or any loading/processing step fails.
    """
    if img is None:
        raise ValueError("Image processing error: no image provided")
    try:
        if isinstance(img, (str, os.PathLike)):
            # A context manager closes the file handle deterministically
            # instead of leaving it to garbage collection.
            with Image.open(img) as opened:
                rgb = opened.convert("RGB")
        else:
            rgb = img.convert("RGB")
        rgb = ImageEnhance.Contrast(rgb).enhance(1.2)
        return rgb.filter(ImageFilter.SHARPEN)
    except Exception as e:
        raise ValueError(f"Image processing error: {str(e)}") from e
# --- Core Generation Engine ---
def _encode_png(img):
    """Serialize a PIL image to raw PNG bytes for inline upload to Gemini."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


def _build_instruction(style, neg_prompt, aspect, color_mode, dpi):
    """Compose the text instruction sent alongside the image.

    Unknown or empty ``style`` values fall back to the "General" template. The
    AVOID line is omitted when ``neg_prompt`` is blank.
    """
    template = STYLE_INSTRUCTIONS.get(style, STYLE_INSTRUCTIONS["General"])
    lines = [template]
    avoid = str(neg_prompt or "").strip()
    if avoid:
        lines.append(f"AVOID: {avoid}")
    lines.append(f"ASPECT: {aspect}, COLORS: {color_mode}, DPI: {dpi}")
    return "\n".join(lines) + "\n"


def _is_timeout(exc):
    """Best-effort check whether ``exc`` represents a request timeout.

    The Gemini SDK may raise its own exception types for timeouts (for example
    a "DeadlineExceeded" error) rather than ``requests`` exceptions. The
    exception's type name and message are therefore inspected as well.
    """
    if isinstance(exc, (requests.exceptions.Timeout, TimeoutError)):
        return True
    text = f"{type(exc).__name__} {exc}".lower()
    return any(marker in text for marker in ("timeout", "timed out", "deadline"))


def _token_counts(response, png_bytes, prompt_text):
    """Return ``(input_tokens, output_tokens)`` for the generation call.

    Uses the response's ``usage_metadata`` when the SDK provides it. Otherwise
    falls back to rough estimates: base64 length / 4 for the image (which
    equals ceil(len(bytes) / 3)) and the word count of the prompt text.
    """
    usage = getattr(response, "usage_metadata", None)
    input_tokens = getattr(usage, "prompt_token_count", 0) or (len(png_bytes) + 2) // 3
    output_tokens = getattr(usage, "candidates_token_count", 0) or len(prompt_text.split())
    return input_tokens, output_tokens


def _validate_and_refine(prompt_text, model):
    """Grade ``prompt_text`` and ask the model to rewrite it once if it scores below 7.

    Best-effort: any error is logged and the current prompt is kept. A score of
    0 (the QualityValidator failure sentinel) or a non-numeric score is not a
    real grade, so it never triggers a rewrite.

    Returns:
        ``(validation_report, prompt_text)``.
    """
    validation = {"score": 8, "issues": [], "suggestions": []}
    try:
        report = QualityValidator.validate(prompt_text, model)
        if isinstance(report, dict):
            validation = report
        score = validation.get("score", 0)
        if isinstance(score, (int, float)) and 0 < score < 7:
            response = model.generate_content(
                f"Improve this prompt: {prompt_text}\nIssues: {validation.get('issues', [])}",
                request_options={"timeout": 10},
            )
            refined = response.text
            if refined and refined.strip():
                prompt_text = refined
    except Exception as e:
        print(f"Validation skipped: {str(e)}")
    return validation, prompt_text


def generate_prompt(image, api_key, style, creativity, neg_prompt, aspect, color_mode, dpi):
    """Turn an uploaded design image into an optimized Flux prompt via Gemini.

    Args:
        image: PIL image (or image file path) from the upload widget.
        api_key: Gemini API key; falls back to ``GEMINI_KEY`` when empty.
        style: Key of ``STYLE_INSTRUCTIONS``; unknown values use "General".
        creativity: Sampling temperature for the generation call.
        neg_prompt: Things the prompt should avoid; omitted when blank.
        aspect, color_mode, dpi: Output specs embedded in the instruction text.

    Returns:
        On success, ``{"prompt": str, "validation": dict, "stats": {"input":
        int, "output": int}}``. The stats describe the primary image-to-prompt
        call. On failure, ``{"error": str}``. Never raises.
    """
    try:
        if image is None or (isinstance(image, str) and not image):
            return {"error": "Please upload an image"}
        api_key = api_key or GEMINI_KEY
        if not api_key:
            return {"error": "API key required - set in env (GEMINI_KEY) or input field"}
        try:
            model = init_genai_api(api_key)
        except ValueError as e:
            return {"error": str(e)}

        # Validation is skipped if the steps below already took too long.
        start_time = time.time()
        png_bytes = _encode_png(preprocess_image(image))
        instruction = _build_instruction(style, neg_prompt, aspect, color_mode, dpi)

        try:
            response = model.generate_content(
                contents=[instruction, {"mime_type": "image/png", "data": png_bytes}],
                generation_config={"temperature": creativity},
                request_options={"timeout": API_TIMEOUT},
            )
            raw_prompt = response.text
        except Exception as e:
            if _is_timeout(e):
                return {
                    "error": f"API request timed out (>{API_TIMEOUT}s). "
                    "Try a smaller image or check your connection."
                }
            return {"error": f"Generation failed: {str(e)}"}
        if not raw_prompt or not raw_prompt.strip():
            return {"error": "Generation failed: the model returned an empty prompt"}

        input_tokens, output_tokens = _token_counts(response, png_bytes, raw_prompt)

        validation = {"score": 8, "issues": [], "suggestions": []}
        if time.time() - start_time < 20:  # only validate while time remains
            validation, raw_prompt = _validate_and_refine(raw_prompt, model)

        return {
            "prompt": raw_prompt,
            "validation": validation,
            "stats": {"input": input_tokens, "output": output_tokens},
        }
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}
# --- UI Components ---
def create_advanced_controls():
    """Build the collapsible "Advanced Settings" panel.

    Intended to be called while a ``gr.Blocks`` layout is open, so the
    components are placed inside it.

    Returns:
        ``[creativity, neg_prompt, aspect, color_mode, dpi]``. This is the
        order in which ``generate_prompt`` takes these arguments after
        ``image, api_key, style``.
    """
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            creativity = gr.Slider(0.0, 1.0, 0.7, label="Creativity Level")
            neg_prompt = gr.Textbox(label="🚫 Negative Prompts", placeholder="What to avoid")
        with gr.Row():
            aspect = gr.Dropdown(FLUX_SPECS["aspect_ratios"], value="1:1", label="Aspect Ratio")
            color_mode = gr.Dropdown(FLUX_SPECS["color_modes"], value="RGB", label="Color Mode")
            # DPI choices are offered as strings; the value is only interpolated
            # into the instruction text, so no numeric conversion is needed.
            dpi = gr.Dropdown(
                [str(d) for d in FLUX_SPECS["dpi_options"]],
                value="300",
                label="Output DPI",
            )
    return [creativity, neg_prompt, aspect, color_mode, dpi]
# --- UI Response Formatting ---
def format_generation_response(result):
    """Map a ``generate_prompt`` result dict onto the three UI outputs.

    Returns a ``(prompt_text, validation, stats)`` tuple. If ``result`` has an
    ``"error"`` key, its value goes to the prompt box and both JSON panels
    are cleared with ``None``. Otherwise, missing keys default to ``""``,
    ``{}`` and ``{}`` respectively.
    """
    if "error" not in result:
        prompt_text = result.get("prompt", "")
        validation = result.get("validation", {})
        stats = result.get("stats", {})
        return prompt_text, validation, stats
    return result["error"], None, None
# --- Main Interface ---
def build_interface():
    """Assemble the Gradio Blocks app for the Flux prompt generator.

    Returns:
        The (not yet launched) ``gr.Blocks`` app. Generation is exposed as the
        ``generate`` API endpoint, with outputs
        ``(prompt, quality report, token stats)``.
    """
    # Client-side copy handler: it runs in the visitor's browser, so the text
    # lands on *their* clipboard (a server-side clipboard library would copy
    # on the host instead). The input value arrives as the first argument, and
    # the returned string is shown in the status box.
    copy_js = (
        "async (text) => {"
        " if (!text) { return 'Nothing to copy yet'; }"
        " try { await navigator.clipboard.writeText(text); return 'Prompt copied to clipboard'; }"
        " catch (err) { return 'Copy failed: ' + err; }"
        " }"
    )

    with gr.Blocks(title="Flux Pro Generator", theme=gr.themes.Soft()) as app:
        # Header
        gr.Markdown("# 🎨 Flux Pro Prompt Generator")
        gr.Markdown("Generate optimized design prompts from images using Google's Gemini")

        # Security: never pre-fill the server's key. A component's value is
        # sent to every visitor's browser; type="password" only masks it on
        # screen. generate_prompt already falls back to GEMINI_KEY when this
        # field is left blank.
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            value="",
            type="password",
            placeholder="Enter your Gemini API key",
            info="Leave blank to use the server's GEMINI_KEY environment variable",
        )

        # Main Workflow
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                img_input = gr.Image(
                    label="🖼️ Upload Design",
                    type="pil",
                    sources=["upload"],
                    interactive=True,
                )
                style = gr.Dropdown(
                    list(STYLE_INSTRUCTIONS.keys()),
                    value="General",
                    label="🎨 Target Style",
                )
                adv_controls = create_advanced_controls()
                gen_btn = gr.Button("✨ Generate Prompt", variant="primary")
                status_msg = gr.Textbox(label="Status", visible=True)
            with gr.Column(scale=2):
                prompt_output = gr.Textbox(
                    label="📝 Optimized Prompt",
                    lines=8,
                    interactive=False,
                )
                with gr.Row():
                    copy_btn = gr.Button("📋 Copy")
                quality_report = gr.JSON(
                    label="📊 Quality Report",
                    visible=True,
                )
                token_stats = gr.JSON(
                    label="🧮 Token Usage",
                    visible=True,
                )

        def on_generate(*args):
            """Run generation and map the result onto the three output panels."""
            return format_generation_response(generate_prompt(*args))

        # Event Handling
        gen_btn.click(
            on_generate,
            inputs=[img_input, api_key, style] + adv_controls,
            outputs=[prompt_output, quality_report, token_stats],
            api_name="generate",
        )
        copy_btn.click(
            fn=None,
            inputs=prompt_output,
            outputs=status_msg,
            js=copy_js,
        )
    return app
# --- Production Launch ---
if __name__ == "__main__":
    # Listen on all interfaces (0.0.0.0) at the configured port, with Gradio's
    # public share link turned off.
    app = build_interface()
    app.launch(
        share=False,
        server_port=DEFAULT_PORT,
        server_name="0.0.0.0",
    )