# paint/ui_old.py
import spaces
import os
import gradio as gr
import random
import torch
import logging
import numpy as np
from typing import Dict, Any, List
from diffusers import DiffusionPipeline
from api import PromptEnhancementSystem
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
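# FLUX.1-schnell is distributed with bfloat16 weights; fall back to float32 on CPU.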
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
print(f"Using device: {DEVICE}")
logger = logging.getLogger(__name__)
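# Assumption: no logging configuration happens elsewhere in the Space; a basic
# INFO-level setup keeps the logger.* calls below visible alongside the prints.
logging.basicConfig(level=logging.INFO)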
# Initialize model
try:
print("Loading model...")
pipe = DiffusionPipeline.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE
).to(DEVICE)
print("Model loaded successfully")
logger.info("Model loaded successfully")
except Exception as e:
print(f"Failed to load model: {str(e)}")
logger.error(f"Failed to load model: {str(e)}")
raise
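# On Hugging Face ZeroGPU Spaces, @spaces.GPU() allocates a GPU for the duration
# of each call to the decorated function.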
@spaces.GPU()
def generate_multiple_images_batch(
improvement_axes,
seed=42,
randomize_seed=False,
width=512,
height=512,
num_inference_steps=4,
progress=gr.Progress(track_tqdm=True)
):
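    """Generate up to four images in a single batched pipeline call.

    Each enhanced prompt in improvement_axes becomes one image; the result is
    padded with None so exactly four image slots plus the seed are returned.
    """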
try:
        # Extract prompts from the improvement axes; the JSON component may still
        # be None or empty if enhancement has not been run yet.
        prompts = [axis["enhanced_prompt"] for axis in (improvement_axes or []) if axis.get("enhanced_prompt")]
        if not prompts:
            return [None] * 4 + [seed]
if randomize_seed:
current_seed = random.randint(0, MAX_SEED)
else:
current_seed = seed
print(f"Generating images with {len(prompts)} prompts")
print(f"Using seed: {current_seed}")
# Generate all images in a single batch
generator = torch.Generator().manual_seed(current_seed)
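        # FLUX.1-schnell is timestep-distilled: it is designed to run in very few
        # steps (4 by default here) and ignores classifier-free guidance, hence
        # guidance_scale=0.0.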
images = pipe(
prompt=prompts, # Pass list of prompts directly
width=width,
height=height,
num_inference_steps=num_inference_steps,
generator=generator,
guidance_scale=0.0
).images
# Pad with None if we have fewer than 4 images
while len(images) < 4:
images.append(None)
print("All images generated successfully")
return images[:4] + [current_seed]
except Exception as e:
print(f"Image generation error: {str(e)}")
logger.error(f"Image generation error: {str(e)}")
raise
def handle_image_select(improvement_axes_data, image_index):
    """Return the enhanced prompt behind the clicked image (empty string on failure)."""
    try:
        if improvement_axes_data and isinstance(improvement_axes_data, list):
            if image_index < len(improvement_axes_data):
                return improvement_axes_data[image_index].get("enhanced_prompt", "")
        return ""
    except Exception as e:
        print(f"Error in handle_image_select: {str(e)}")
        return ""
def create_interface():
print("Creating interface...")
api_key = os.getenv("GROQ_API_KEY")
base_url = os.getenv("API_BASE_URL")
if not api_key:
print("GROQ_API_KEY not found in environment variables")
raise ValueError("GROQ_API_KEY not found in environment variables")
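    # PromptEnhancementSystem comes from api.py; it presumably wraps the Groq-hosted
    # LLM behind GROQ_API_KEY. Only start_session() and current_state are used here.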
system = PromptEnhancementSystem(api_key, base_url)
print("PromptEnhancementSystem initialized")
def update_interface(prompt):
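        """Run one enhancement session and fan the result out to the UI.

        Returns, in the order wired to start_btn.click: the original and enhanced
        prompts, the six analysis panels, the improvement axes, the technical
        recommendations, four cleared image slots, the full LLM response/state,
        and visibility updates for the four option buttons.
        """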
try:
print(f"\n=== Processing prompt: {prompt}")
state = system.start_session(prompt)
improvement_axes = state.get("improvement_axes", [])
initial_analysis = state.get("initial_analysis", {})
enhanced_prompt = ""
if improvement_axes and len(improvement_axes) > 0:
enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)
button_updates = []
for i in range(4):
if i < len(improvement_axes):
focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
button_updates.append(gr.update(visible=True, value=focus_area))
else:
button_updates.append(gr.update(visible=False))
return [prompt, enhanced_prompt] + [
initial_analysis.get(key, {}) for key in [
"subject_analysis",
"style_evaluation",
"technical_assessment",
"composition_review",
"context_evaluation",
"mood_assessment"
]
] + [
improvement_axes,
state.get("technical_recommendations", {}),
None, None, None, None, # Four None values for the four image outputs
state
] + button_updates
except Exception as e:
print(f"Error in update_interface: {str(e)}")
logger.error(f"Error in update_interface: {str(e)}")
empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
return [prompt, prompt] + [empty_analysis] * 6 + [{}, {}, None, None, None, None, {}] + [gr.update(visible=False)] * 4
def handle_option_click(option_num, input_prompt, current_text):
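        """Apply the enhanced prompt of the selected improvement axis.

        Reuses the cached system.current_state (no new enhancement request) and
        returns the 11 values wired to the option buttons' outputs.
        """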
try:
print(f"\n=== Processing option {option_num}")
state = system.current_state
if state and "improvement_axes" in state:
improvement_axes = state["improvement_axes"]
if option_num < len(improvement_axes):
selected_prompt = improvement_axes[option_num]["enhanced_prompt"]
return [
input_prompt,
selected_prompt,
state.get("initial_analysis", {}).get("subject_analysis", {}),
state.get("initial_analysis", {}).get("style_evaluation", {}),
state.get("initial_analysis", {}).get("technical_assessment", {}),
state.get("initial_analysis", {}).get("composition_review", {}),
state.get("initial_analysis", {}).get("context_evaluation", {}),
state.get("initial_analysis", {}).get("mood_assessment", {}),
improvement_axes,
state.get("technical_recommendations", {}),
state
]
return handle_error()
except Exception as e:
print(f"Error in handle_option_click: {str(e)}")
logger.error(f"Error in handle_option_click: {str(e)}")
return handle_error()
def handle_error():
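        """Return a neutral 11-value payload matching the option-button outputs."""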
empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
return ["", "", empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, [], {}, {}]
with gr.Blocks(
title="AI Prompt Enhancement System",
theme=gr.themes.Soft(),
css="footer {visibility: hidden}"
) as interface:
gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")
with gr.Row():
input_prompt = gr.Textbox(
label="Initial Prompt",
placeholder="Enter your prompt here...",
lines=3,
scale=1
)
current_prompt = gr.Textbox(
label="Current Prompt",
lines=3,
scale=1,
interactive=True
)
with gr.Row():
start_btn = gr.Button("Start Enhancement", variant="primary")
with gr.Row():
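            # One button per improvement axis; update_interface() reveals them and
            # relabels each with its focus_area once enhancement has run.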
option_buttons = [gr.Button("", visible=False) for _ in range(4)]
with gr.Tabs():
with gr.TabItem("Initial Analysis"):
with gr.Row():
with gr.Column():
subject_analysis = gr.JSON(label="Subject Analysis")
with gr.Column():
style_evaluation = gr.JSON(label="Style Evaluation")
with gr.Column():
technical_assessment = gr.JSON(label="Technical Assessment")
with gr.Row():
with gr.Column():
composition_review = gr.JSON(label="Composition Review")
with gr.Column():
context_evaluation = gr.JSON(label="Context Evaluation")
with gr.Column():
mood_assessment = gr.JSON(label="Mood Assessment")
with gr.TabItem("Generated Images"):
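                # Four image slots, one per improvement axis; clicking an image sends
                # its enhanced prompt back to the input box (see the select handlers below).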
with gr.Row():
generated_images = [
gr.Image(
label=f"Image {i+1}",
type="pil",
show_label=True,
height=256,
width=256,
interactive=True,
elem_id=f"image_{i}"
) for i in range(4)
]
with gr.Row():
finalize_btn = gr.Button("Generate All Images", variant="primary")
with gr.Accordion("Image Generation Settings", open=False):
with gr.Row():
seed = gr.Slider(
label="Seed",
minimum=0,
                    maximum=MAX_SEED,
step=1,
value=42
)
randomize_seed = gr.Checkbox(
label="Randomize seed",
value=True
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
                    maximum=MAX_IMAGE_SIZE,
step=256,
value=512
)
height = gr.Slider(
label="Height",
minimum=256,
                    maximum=MAX_IMAGE_SIZE,
step=256,
value=512
)
num_inference_steps = gr.Slider(
label="Steps",
minimum=1,
maximum=50,
step=1,
value=4
)
with gr.Accordion("Additional Information", open=False):
improvement_axes = gr.JSON(label="Improvement Axes")
technical_recommendations = gr.JSON(label="Technical Recommendations")
full_llm_response = gr.JSON(label="Full LLM Response")
        # Clicking a generated image copies its enhanced prompt back into the input
        # box. gr.SelectData.index on a gr.Image reports click coordinates rather
        # than the image's position, so the position is bound per component here.
        for i, img in enumerate(generated_images):
            img.select(
                fn=lambda improvement_axes_data, idx=i: handle_image_select(improvement_axes_data, idx),
                inputs=[improvement_axes],
                outputs=[input_prompt]
            )
start_btn.click(
update_interface,
inputs=[input_prompt],
outputs=[
input_prompt,
current_prompt,
subject_analysis,
style_evaluation,
technical_assessment,
composition_review,
context_evaluation,
mood_assessment,
improvement_axes,
technical_recommendations
] + generated_images + [full_llm_response] + option_buttons
)
        for i, btn in enumerate(option_buttons):
            btn.click(
                # Bind the option index directly; the previous hidden-Slider input
                # added invisible components to the layout just to carry an index.
                fn=lambda input_text, current_text, idx=i: handle_option_click(idx, input_text, current_text),
                inputs=[
                    input_prompt,
                    current_prompt
                ],
outputs=[
input_prompt,
current_prompt,
subject_analysis,
style_evaluation,
technical_assessment,
composition_review,
context_evaluation,
mood_assessment,
improvement_axes,
technical_recommendations,
full_llm_response
]
)
finalize_btn.click(
generate_multiple_images_batch,
inputs=[
improvement_axes,
seed,
randomize_seed,
width,
height,
num_inference_steps
],
outputs=generated_images + [seed]
)
print("Interface setup complete")
return interface
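
# On the Space this module is presumably imported and launched from a separate entry
# point (e.g. an app.py, not shown here); the guard below is a minimal sketch for
# running the UI standalone.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()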