"""Gradio app: LLM-driven prompt enhancement plus FLUX.1-schnell image generation.

The module loads the diffusion pipeline at import time and exposes
``create_interface()`` which builds the Gradio Blocks UI.
"""

# NOTE(review): the original import order is preserved on purpose, with
# ``spaces`` first. Hugging Face ZeroGPU documentation asks for ``spaces`` to
# be imported before anything initialises CUDA -- confirm before regrouping.
import spaces
import os
import gradio as gr
import random
import torch
import logging
import numpy as np
from typing import Dict, Any, List
from diffusers import DiffusionPipeline
from api import PromptEnhancementSystem

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Number of option buttons / image slots shown in the UI.
NUM_IMAGE_SLOTS = 4

# Keys of ``initial_analysis``, in the order of the six analysis JSON panels.
ANALYSIS_KEYS = (
    "subject_analysis",
    "style_evaluation",
    "technical_assessment",
    "composition_review",
    "context_evaluation",
    "mood_assessment",
)

print(f"Using device: {DEVICE}")
logger = logging.getLogger(__name__)

# Initialize model
try:
    print("Loading model...")
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
    ).to(DEVICE)
    print("Model loaded successfully")
    logger.info("Model loaded successfully")
except Exception as e:
    # Top-level boundary: report, then re-raise so the app fails fast.
    print(f"Failed to load model: {str(e)}")
    logger.exception("Failed to load model: %s", e)
    raise


def _enhanced_prompts(improvement_axes):
    """Return the non-empty ``enhanced_prompt`` strings, one per image slot.

    Anything that is not a list (``None`` from an empty JSON component, the
    ``{}`` placeholder) yields an empty list. The result is capped at
    ``NUM_IMAGE_SLOTS`` because the UI cannot display more images than that.
    """
    if not isinstance(improvement_axes, list):
        return []
    prompts = [
        axis["enhanced_prompt"]
        for axis in improvement_axes
        if isinstance(axis, dict) and axis.get("enhanced_prompt")
    ]
    return prompts[:NUM_IMAGE_SLOTS]


@spaces.GPU()
def generate_multiple_images_batch(
    improvement_axes,
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image per enhanced prompt in a single pipeline call.

    Args:
        improvement_axes: List of dicts; the ``enhanced_prompt`` entry of each
            is used as a prompt. Entries without one are skipped.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a seed from ``[0, MAX_SEED]`` instead.
        width: Output image width passed to the pipeline.
        height: Output image height passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        A list of ``NUM_IMAGE_SLOTS`` images (padded with ``None``) followed by
        the seed that was actually used.

    Raises:
        Exception: Any pipeline error is logged and re-raised.
    """
    try:
        # Only the prompts that can be displayed are generated; the original
        # generated every prompt and then threw the surplus images away.
        prompts = _enhanced_prompts(improvement_axes)

        if not prompts:
            return [None] * NUM_IMAGE_SLOTS + [seed]

        # Slider values may arrive as floats; manual_seed() needs an int.
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        print(f"Generating images with {len(prompts)} prompts")
        print(f"Using seed: {current_seed}")

        # Generate all images in a single batch
        generator = torch.Generator().manual_seed(current_seed)
        images = list(
            pipe(
                prompt=prompts,  # Pass list of prompts directly
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                guidance_scale=0.0,
            ).images
        )

        # Pad with None so every image output component receives a value.
        images.extend([None] * (NUM_IMAGE_SLOTS - len(images)))

        print("All images generated successfully")
        return images[:NUM_IMAGE_SLOTS] + [current_seed]

    except Exception as e:
        print(f"Image generation error: {str(e)}")
        logger.exception("Image generation error: %s", e)
        raise


def handle_image_select(evt: gr.SelectData, improvement_axes_data, image_index=None):
    """Handle image selection event.

    Args:
        evt: Gradio select event.
        improvement_axes_data: Current value of the improvement-axes JSON.
        image_index: Slot number of the image that was clicked. When given,
            the prompt that produced that slot is returned. For ``gr.Image``
            the event's ``index`` is the clicked pixel position, not a slot
            number, so callers wiring image components must pass this.
            When ``None`` the legacy lookup based on ``evt.index`` is used.

    Returns:
        The matching enhanced prompt, or ``""`` when none can be resolved.
    """
    try:
        if image_index is not None:
            # Same filtering as image generation, so slot i maps to the prompt
            # that actually produced image i.
            prompts = _enhanced_prompts(improvement_axes_data)
            if 0 <= image_index < len(prompts):
                return prompts[image_index]
            return ""

        # Legacy path, kept for callers that do not pass ``image_index``.
        if improvement_axes_data and isinstance(improvement_axes_data, list):
            selected_index = evt.index[1] if isinstance(evt.index, tuple) else evt.index
            if selected_index < len(improvement_axes_data):
                return improvement_axes_data[selected_index].get("enhanced_prompt", "")
        return ""
    except Exception as e:
        # Best effort: a failed lookup must not break the UI.
        print(f"Error in handle_image_select: {str(e)}")
        return ""


def _image_select_handler(image_index):
    """Build a select callback bound to one image slot.

    A real nested function (not ``functools.partial``) is used so the
    ``gr.SelectData`` annotation stays visible to Gradio's event-data
    injection.
    """

    def _on_select(evt: gr.SelectData, improvement_axes_data):
        return handle_image_select(evt, improvement_axes_data, image_index)

    return _on_select


def _empty_analysis():
    """Return the placeholder shown in an analysis panel after an error."""
    return {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}


def create_interface():
    """Build and return the Gradio Blocks interface.

    Reads ``GROQ_API_KEY`` and ``API_BASE_URL`` from the environment and
    creates the ``PromptEnhancementSystem`` that backs the callbacks.

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set.
    """
    print("Creating interface...")

    api_key = os.getenv("GROQ_API_KEY")
    base_url = os.getenv("API_BASE_URL")

    if not api_key:
        print("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")

    system = PromptEnhancementSystem(api_key, base_url)
    print("PromptEnhancementSystem initialized")

    def handle_error():
        """Return neutral values for the 11 outputs of an option click."""
        return ["", ""] + [_empty_analysis() for _ in ANALYSIS_KEYS] + [[], {}, {}]

    def update_interface(prompt):
        """Start an enhancement session and return values for all 19 outputs.

        Output order: input prompt, current prompt, six analysis panels,
        improvement axes, technical recommendations, four image slots
        (cleared), full LLM response, four option-button updates.
        """
        try:
            print(f"\n=== Processing prompt: {prompt}")
            state = system.start_session(prompt)

            axes = state.get("improvement_axes", [])
            initial_analysis = state.get("initial_analysis", {})

            enhanced_prompt = axes[0].get("enhanced_prompt", prompt) if axes else ""

            # One button per improvement axis; unused buttons stay hidden.
            button_updates = []
            for i in range(NUM_IMAGE_SLOTS):
                if i < len(axes):
                    focus_area = axes[i].get("focus_area", f"Option {i+1}")
                    button_updates.append(gr.update(visible=True, value=focus_area))
                else:
                    button_updates.append(gr.update(visible=False))

            return (
                [prompt, enhanced_prompt]
                + [initial_analysis.get(key, {}) for key in ANALYSIS_KEYS]
                + [axes, state.get("technical_recommendations", {})]
                + [None] * NUM_IMAGE_SLOTS  # clear the image outputs
                + [state]
                + button_updates
            )
        except Exception as e:
            print(f"Error in update_interface: {str(e)}")
            logger.exception("Error in update_interface: %s", e)
            return (
                [prompt, prompt]
                + [_empty_analysis() for _ in ANALYSIS_KEYS]
                # [] (not {}) for the axes, matching handle_error().
                + [[], {}]
                + [None] * NUM_IMAGE_SLOTS
                + [{}]
                + [gr.update(visible=False)] * NUM_IMAGE_SLOTS
            )

    def handle_option_click(option_num, input_prompt, current_text):
        """Switch the current prompt to the improvement axis ``option_num``.

        ``current_text`` is accepted because the event wiring passes it, but
        it is not used.
        """
        try:
            print(f"\n=== Processing option {option_num}")
            state = system.current_state

            if state and "improvement_axes" in state:
                axes = state["improvement_axes"]
                # The index arrives through a hidden slider; make sure it is
                # usable as a list index.
                option_index = int(option_num)
                if option_index < len(axes):
                    selected_prompt = axes[option_index]["enhanced_prompt"]
                    initial_analysis = state.get("initial_analysis", {})
                    return (
                        [input_prompt, selected_prompt]
                        + [initial_analysis.get(key, {}) for key in ANALYSIS_KEYS]
                        + [axes, state.get("technical_recommendations", {}), state]
                    )

            return handle_error()
        except Exception as e:
            print(f"Error in handle_option_click: {str(e)}")
            logger.exception("Error in handle_option_click: %s", e)
            return handle_error()

    # NOTE(review): the source this was reconstructed from had lost its
    # indentation; the nesting of rows/tabs/accordions below is the most
    # plausible reading -- confirm against the deployed layout.
    with gr.Blocks(
        title="AI Prompt Enhancement System",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}",
    ) as interface:
        gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")

        with gr.Row():
            input_prompt = gr.Textbox(
                label="Initial Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
                scale=1,
            )
            current_prompt = gr.Textbox(
                label="Current Prompt",
                lines=3,
                scale=1,
                interactive=True,
            )

        with gr.Row():
            start_btn = gr.Button("Start Enhancement", variant="primary")

        with gr.Row():
            option_buttons = [
                gr.Button("", visible=False) for _ in range(NUM_IMAGE_SLOTS)
            ]

        with gr.Tabs():
            with gr.TabItem("Initial Analysis"):
                with gr.Row():
                    with gr.Column():
                        subject_analysis = gr.JSON(label="Subject Analysis")
                    with gr.Column():
                        style_evaluation = gr.JSON(label="Style Evaluation")
                    with gr.Column():
                        technical_assessment = gr.JSON(label="Technical Assessment")

                with gr.Row():
                    with gr.Column():
                        composition_review = gr.JSON(label="Composition Review")
                    with gr.Column():
                        context_evaluation = gr.JSON(label="Context Evaluation")
                    with gr.Column():
                        mood_assessment = gr.JSON(label="Mood Assessment")

            with gr.TabItem("Generated Images"):
                with gr.Row():
                    generated_images = [
                        gr.Image(
                            label=f"Image {i+1}",
                            type="pil",
                            show_label=True,
                            height=256,
                            width=256,
                            interactive=True,
                            elem_id=f"image_{i}",
                        )
                        for i in range(NUM_IMAGE_SLOTS)
                    ]

                with gr.Row():
                    finalize_btn = gr.Button("Generate All Images", variant="primary")

                with gr.Accordion("Image Generation Settings", open=False):
                    with gr.Row():
                        # The slider also receives the randomly drawn seed, so
                        # its range must cover [0, MAX_SEED], not just 2048.
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                        )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512,
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4,
                        )

        with gr.Accordion("Additional Information", open=False):
            improvement_axes = gr.JSON(label="Improvement Axes")
            technical_recommendations = gr.JSON(label="Technical Recommendations")
            full_llm_response = gr.JSON(label="Full LLM Response")

        # Outputs shared by the start button and the option buttons.
        analysis_outputs = [
            subject_analysis,
            style_evaluation,
            technical_assessment,
            composition_review,
            context_evaluation,
            mood_assessment,
        ]

        # Add select events for each image; each handler knows its own slot.
        for slot, img in enumerate(generated_images):
            img.select(
                fn=_image_select_handler(slot),
                inputs=[improvement_axes],
                outputs=[input_prompt],
            )

        start_btn.click(
            update_interface,
            inputs=[input_prompt],
            outputs=(
                [input_prompt, current_prompt]
                + analysis_outputs
                + [improvement_axes, technical_recommendations]
                + generated_images
                + [full_llm_response]
                + option_buttons
            ),
        )

        for i, btn in enumerate(option_buttons):
            btn.click(
                handle_option_click,
                inputs=[
                    # Hidden component that carries the button's index.
                    gr.Slider(value=i, visible=False),
                    input_prompt,
                    current_prompt,
                ],
                outputs=(
                    [input_prompt, current_prompt]
                    + analysis_outputs
                    + [
                        improvement_axes,
                        technical_recommendations,
                        full_llm_response,
                    ]
                ),
            )

        finalize_btn.click(
            generate_multiple_images_batch,
            inputs=[
                improvement_axes,
                seed,
                randomize_seed,
                width,
                height,
                num_inference_steps,
            ],
            outputs=generated_images + [seed],
        )

    print("Interface setup complete")
    return interface