Describer-Pro / app.py
mroccuper's picture
Update app.py
9abec8f verified
raw
history blame
10.2 kB
import os
import gradio as gr
import requests
import json
import time
import traceback
import io
import base64
from PIL import Image, ImageEnhance, ImageFilter
# Conditional imports
try:
import google.generativeai as genai
GENAI_AVAILABLE = True
except ImportError:
GENAI_AVAILABLE = False
print("Warning: google-generativeai not installed, will attempt on-demand import")
try:
import pyperclip
except ImportError:
pyperclip = None
# --- Environment Configuration ---
# Fallback Gemini API key, used when the UI key field is left empty.
GEMINI_KEY = os.getenv("GEMINI_KEY", "")
# Port the Gradio server binds to; overridable through the PORT variable.
DEFAULT_PORT = int(os.getenv("PORT", 7860))
API_TIMEOUT = 30  # seconds

# --- Style Template Optimization ---
# Shared preamble every style-specific instruction starts with.
BASE_TEMPLATE = "\n".join(
    [
        "Describe this design as a concise Flux 1.1 Pro prompt focusing on:",
        "- Key visual elements",
        "- Technical specifications",
        "- Style consistency",
        "- Functional requirements",
    ]
)

# Maps the style name shown in the UI to the full instruction text.
# "General" is the bare template; every other style appends one rule line.
STYLE_INSTRUCTIONS = {
    "General": BASE_TEMPLATE,
    **{
        style_name: f"{BASE_TEMPLATE}\n{rule_line}"
        for style_name, rule_line in (
            ("Realistic", "PHOTOREALISTIC RULES: Use photography terms, texture details, accurate lighting"),
            ("Kawaii", "KAWAII RULES: Rounded shapes, pastel colors, cute expressions"),
            ("Vector", "VECTOR RULES: Clean lines, geometric shapes, B&W gradients"),
            ("Silhouette", "SILHOUETTE RULES: High contrast, minimal details, strong outlines"),
        )
    },
}

# --- Flux Configuration ---
# Option lists offered by the advanced-settings dropdowns.
FLUX_SPECS = dict(
    aspect_ratios=["1:1", "16:9", "4:3", "9:16"],
    formats=["SVG", "PNG", "PDF"],
    color_modes=["B&W", "CMYK", "RGB"],
    dpi_options=[72, 150, 300, 600],
)
# --- Quality Control System ---
class QualityValidator:
    """Ask the generative model to grade a generated Flux prompt.

    The model is instructed to answer with a JSON object holding a numeric
    ``score`` (1-10), a list of ``issues`` and a list of ``suggestions``.
    """

    # The score scale is 1-10 so that it matches the "< 7" threshold applied
    # by generate_prompt. The example shows a plain integer because a
    # fraction such as ``x/10`` is not valid JSON.
    VALIDATION_TEMPLATE = """Analyze this Flux prompt:
1. Score style adherence (1-10)
2. List technical issues
3. Suggest improvements
Respond ONLY as JSON: {"score": <integer 1-10>, "issues": [], "suggestions": []}"""

    @staticmethod
    def _extract_json(text):
        """Parse the outermost ``{...}`` span found in *text*.

        Models frequently wrap JSON in Markdown code fences or add a short
        preamble, so only the text between the first ``{`` and the last
        ``}`` is handed to the JSON parser.

        Raises:
            ValueError: if *text* contains no brace-delimited span, or the
                span is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            raise ValueError("no JSON object found in validation response")
        return json.loads(text[start:end + 1])

    @staticmethod
    def _normalize_score(raw):
        """Coerce the model's score into a number; unusable values become 0.

        Accepts ints/floats as-is and strings such as ``"8"`` or ``"8/10"``.
        """
        if isinstance(raw, bool):
            return 0
        if isinstance(raw, (int, float)):
            return raw
        if isinstance(raw, str):
            try:
                return float(raw.split("/", 1)[0].strip())
            except ValueError:
                return 0
        return 0

    @classmethod
    def validate(cls, prompt, model):
        """Grade *prompt* with *model* and return the parsed report.

        Args:
            prompt: The generated prompt text to evaluate.
            model: Object exposing ``generate_content(parts)`` whose result
                has a ``text`` attribute.

        Returns:
            A dict that always contains ``score`` (number), ``issues`` and
            ``suggestions``. When the request fails or the reply cannot be
            parsed, the fixed failure report with score 0 is returned; this
            method never raises.
        """
        try:
            response = model.generate_content([cls.VALIDATION_TEMPLATE, prompt])
            report = cls._extract_json(response.text)
            if not isinstance(report, dict):
                raise ValueError("validation response is not a JSON object")
            report["score"] = cls._normalize_score(report.get("score", 0))
            # Guarantee the keys callers index directly.
            report.setdefault("issues", [])
            report.setdefault("suggestions", [])
            return report
        except Exception as e:
            # Best-effort step: report the problem and fall back to a
            # failure report instead of aborting prompt generation.
            print(f"Validation error: {str(e)}")
            return {"score": 0, "issues": ["Validation failed"], "suggestions": []}
# --- Lazy API Initialization ---
def _describe_init_failure(exc):
    """Translate an SDK exception into the user-facing initialization message."""
    detail = str(exc)
    lowered = detail.lower()
    # The Gemini backend reports bad keys as "API key not valid" /
    # API_KEY_INVALID, which the plain "authentication" check would miss.
    auth_markers = ("authentication", "unauthenticated", "api key", "api_key")
    # Deadline errors are worded "Deadline Exceeded", not "timeout".
    timeout_markers = ("timeout", "timed out", "deadline")
    if any(marker in lowered for marker in auth_markers):
        return "Invalid API key or authentication error"
    if any(marker in lowered for marker in timeout_markers):
        return "API connection timeout - check your internet connection"
    return f"API initialization error: {detail}"


def init_genai_api(api_key):
    """Configure the Gemini SDK and return a ready-to-use model.

    If the SDK could not be imported at module load, the import is retried
    here. The key is then verified with a minimal ``"test"`` request using a
    5 second timeout, so every call to this function costs one API request.

    Args:
        api_key: Gemini API key to configure the SDK with.

    Returns:
        A ``genai.GenerativeModel`` instance for ``gemini-1.5-pro``.

    Raises:
        ValueError: if the key is empty, the SDK is not installed, or the
            test request fails. The original exception is chained as
            ``__cause__``.
    """
    global genai

    if not api_key:
        raise ValueError("Invalid API key or authentication error")

    if not GENAI_AVAILABLE:
        try:
            # Attempt dynamic import; binds the module-level name on success.
            import google.generativeai as genai
        except ImportError as e:
            raise ValueError(
                "Failed to import google.generativeai. Install with: pip install google-generativeai"
            ) from e

    try:
        genai.configure(api_key=api_key)
        # Test connection with minimal request
        model = genai.GenerativeModel("gemini-1.5-pro")
        model.generate_content("test", request_options={"timeout": 5})
        return model
    except Exception as e:
        raise ValueError(_describe_init_failure(e)) from e
# --- Image Processing Pipeline ---
def preprocess_image(img):
    """Normalize an uploaded image before it is sent to the model.

    A string argument is treated as a file path and opened with PIL; any
    other value is used as the image object directly. The image is converted
    to RGB, its contrast is boosted by a factor of 1.2 and a sharpen filter
    is applied.

    Raises:
        ValueError: wrapping any exception raised during the steps above.
    """
    try:
        source = Image.open(img) if isinstance(img, str) else img
        rgb = source.convert("RGB")
        boosted = ImageEnhance.Contrast(rgb).enhance(1.2)
        return boosted.filter(ImageFilter.SHARPEN)
    except Exception as e:
        raise ValueError(f"Image processing error: {str(e)}")
# --- Core Generation Engine ---
def _is_timeout_error(exc):
    """Return True when *exc* looks like a request timeout.

    The check is by class name and message so that it covers the built-in
    ``TimeoutError`` as well as transport-specific classes (e.g. names
    containing "Timeout" or "DeadlineExceeded") without importing them.
    """
    if isinstance(exc, TimeoutError):
        return True
    kind = type(exc).__name__.lower()
    text = str(exc).lower()
    return any(
        marker in kind or marker in text
        for marker in ("timeout", "timed out", "deadline")
    )


def generate_prompt(image, api_key, style, creativity, neg_prompt, aspect, color_mode, dpi):
    """Generate a Flux prompt describing *image* using the Gemini model.

    Args:
        image: Uploaded image (PIL image or file path); ``None`` or an empty
            string means nothing was uploaded.
        api_key: Key typed in the UI; falls back to ``GEMINI_KEY`` if empty.
        style: Key of ``STYLE_INSTRUCTIONS`` selecting the instruction text.
        creativity: Passed to the model as the sampling ``temperature``.
        neg_prompt: Things the description should avoid (may be empty).
        aspect: Aspect ratio label added to the instruction.
        color_mode: Color mode label added to the instruction.
        dpi: Output DPI label added to the instruction.

    Returns:
        On success, a dict with ``prompt`` (str), ``validation`` (quality
        report dict) and ``stats`` (approximate ``input``/``output`` token
        counts). On any failure, ``{"error": <message>}``; this function
        does not raise.
    """
    try:
        # Validate inputs. An explicit None/empty check is used because the
        # truth value of an image object is not a reliable "was uploaded" test.
        if image is None or (isinstance(image, (str, bytes)) and not image):
            return {"error": "Please upload an image"}

        api_key = api_key or GEMINI_KEY
        if not api_key:
            return {"error": "API key required - set in env (GEMINI_KEY) or input field"}

        # Reject an unknown style before any API call is spent on it.
        if style not in STYLE_INSTRUCTIONS:
            return {"error": f"Unknown style: {style}"}

        # Initialize model with proper error handling
        try:
            model = init_genai_api(api_key)
        except ValueError as e:
            return {"error": str(e)}

        # Encode the preprocessed image as base64 PNG for the request.
        start_time = time.time()
        img = preprocess_image(image)
        img_bytes = io.BytesIO()
        img.save(img_bytes, format="PNG")
        img_b64 = base64.b64encode(img_bytes.getvalue()).decode()

        # Build instruction; the AVOID line is only sent when there is
        # something to avoid.
        instruction_lines = [STYLE_INSTRUCTIONS[style]]
        if neg_prompt and str(neg_prompt).strip():
            instruction_lines.append(f"AVOID: {neg_prompt}")
        instruction_lines.append(f"ASPECT: {aspect}, COLORS: {color_mode}, DPI: {dpi}")
        instruction = "\n".join(instruction_lines) + "\n"

        # Generate prompt with timeout protection
        try:
            response = model.generate_content(
                contents=[instruction, {"mime_type": "image/png", "data": img_b64}],
                generation_config={"temperature": creativity},
                request_options={"timeout": API_TIMEOUT},
            )
            raw_prompt = response.text
        except Exception as e:
            if _is_timeout_error(e):
                return {
                    "error": f"API request timed out (>{API_TIMEOUT}s). "
                             "Try a smaller image or check your connection."
                }
            return {"error": f"Generation failed: {str(e)}"}

        # Quality validation (skip if taking too long). The default report
        # is what the UI shows when validation is skipped or unusable.
        validation = {"score": 8, "issues": [], "suggestions": []}
        if time.time() - start_time < 20:  # Only validate if we have time
            try:
                report = QualityValidator.validate(raw_prompt, model)
                if isinstance(report, dict):
                    validation = report
                if validation.get("score", 0) < 7:
                    response = model.generate_content(
                        f"Improve this prompt: {raw_prompt}\nIssues: {validation.get('issues', [])}",
                        request_options={"timeout": 10},
                    )
                    improved = response.text
                    # Keep the first draft if the rewrite came back empty.
                    if improved and improved.strip():
                        raw_prompt = improved
            except Exception as e:
                # Best-effort step: keep the prompt we already have, but
                # leave a trace instead of failing silently.
                print(f"Quality pass skipped: {e}")

        # Token tracking
        input_tokens = len(img_b64) // 4  # Approximate base64 token count
        output_tokens = len(raw_prompt.split())
        return {
            "prompt": raw_prompt,
            "validation": validation,
            "stats": {"input": input_tokens, "output": output_tokens},
        }
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}
# --- UI Components ---
def create_advanced_controls():
    """Create the collapsed "Advanced Settings" accordion and its inputs.

    The label strings were stored as mis-decoded UTF-8 (mojibake); they are
    restored here to the emoji they encode.

    Returns:
        The created components in the order
        ``[creativity, neg_prompt, aspect, color_mode, dpi]``, which is the
        order of the matching ``generate_prompt`` parameters.
    """
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            creativity = gr.Slider(0.0, 1.0, 0.7, label="Creativity Level")
            neg_prompt = gr.Textbox(label="🚫 Negative Prompts", placeholder="What to avoid")
        with gr.Row():
            aspect = gr.Dropdown(FLUX_SPECS["aspect_ratios"], value="1:1", label="Aspect Ratio")
            color_mode = gr.Dropdown(FLUX_SPECS["color_modes"], value="RGB", label="Color Mode")
            # DPI choices are offered as strings so the default "300" matches
            # one of the dropdown entries.
            dpi = gr.Dropdown(
                [str(d) for d in FLUX_SPECS["dpi_options"]],
                value="300",
                label="Output DPI",
            )
    return [creativity, neg_prompt, aspect, color_mode, dpi]
# --- UI Response Formatting ---
def format_generation_response(result):
    """Turn a ``generate_prompt`` result dict into the three UI output values.

    Returns:
        ``(prompt, validation, stats)`` for a successful result, with missing
        keys replaced by ``""`` / ``{}`` / ``{}``. When the dict has an
        ``"error"`` key, ``(error_message, None, None)`` is returned so the
        message appears in the first output.
    """
    if "error" not in result:
        prompt_text = result.get("prompt", "")
        quality = result.get("validation", {})
        usage = result.get("stats", {})
        return prompt_text, quality, usage
    return result["error"], None, None
# --- Main Interface ---
def build_interface():
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    def _run_generation(*args):
        """Run generate_prompt and shape its result for the three outputs."""
        return format_generation_response(generate_prompt(*args))

    def _copy_to_clipboard(text):
        """Copy *text* with pyperclip and return a status line for the UI.

        pyperclip runs in the server process, so this targets the clipboard
        of the machine hosting the app. It raises when no clipboard
        mechanism is available, which is reported instead of propagated.
        """
        if pyperclip is None:
            return "Copy functionality not available (pyperclip not installed)"
        if not text:
            return "Nothing to copy yet"
        try:
            pyperclip.copy(text)
        except Exception as e:
            return f"Copy failed: {e}"
        return "Prompt copied to clipboard"

    with gr.Blocks(title="Flux Pro Generator", theme=gr.themes.Soft()) as app:
        # Header
        gr.Markdown("# 🎨 Flux Pro Prompt Generator")
        gr.Markdown("Generate optimized design prompts from images using Google's Gemini")

        # Security Section
        # The field is deliberately NOT pre-filled with GEMINI_KEY: a
        # component's initial value is delivered to every visitor's browser,
        # which would expose the server's key. generate_prompt already falls
        # back to GEMINI_KEY when the field is left empty.
        api_key = gr.Textbox(
            label="🔑 Gemini API Key",
            value="",
            type="password",
            placeholder=(
                "Leave empty to use the server-configured key"
                if GEMINI_KEY
                else "Paste your Gemini API key"
            ),
            info="Set GEMINI_KEY environment variable for production",
        )

        # Main Workflow
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                img_input = gr.Image(
                    label="🖼️ Upload Design",
                    type="pil",
                    sources=["upload"],
                    interactive=True,
                )
                style = gr.Dropdown(
                    list(STYLE_INSTRUCTIONS.keys()),
                    value="General",
                    label="🎨 Target Style",
                )
                adv_controls = create_advanced_controls()
                gen_btn = gr.Button("✨ Generate Prompt", variant="primary")
                status_msg = gr.Textbox(label="Status", visible=True)
            with gr.Column(scale=2):
                prompt_output = gr.Textbox(
                    label="📝 Optimized Prompt",
                    lines=8,
                    interactive=False,
                )
                with gr.Row():
                    copy_btn = gr.Button("📋 Copy")
                quality_report = gr.JSON(
                    label="🔍 Quality Report",
                    visible=True,
                )
                token_stats = gr.JSON(
                    label="🧮 Token Usage",
                    visible=True,
                )

        # Event Handling
        # Input order must match generate_prompt's parameter order:
        # image, api_key, style, then the advanced controls.
        gen_btn.click(
            _run_generation,
            inputs=[img_input, api_key, style] + adv_controls,
            outputs=[prompt_output, quality_report, token_stats],
            api_name="generate",
        )
        copy_btn.click(
            _copy_to_clipboard,
            inputs=prompt_output,
            outputs=status_msg,
        )
    return app
# --- Production Launch ---
# Runs only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    app = build_interface()
    # Bind on all interfaces at the configured port; no public share link.
    app.launch(
        server_name="0.0.0.0",
        server_port=DEFAULT_PORT,
        share=False,
    )