Spaces:
Runtime error
Runtime error
import spaces | |
import os | |
import gradio as gr | |
import random | |
import torch | |
import logging | |
import numpy as np | |
from typing import Dict, Any, List | |
from diffusers import DiffusionPipeline | |
from api import PromptEnhancementSystem | |
# Constants
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer
MAX_IMAGE_SIZE = 2048
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Use the GPU with bfloat16 when CUDA is present; otherwise run on the CPU
# with float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

print(f"Using device: {DEVICE}")

logger = logging.getLogger(__name__)
# Initialize model: load the pipeline once at import time and move it to the
# selected device. Any failure is reported on stdout and through the module
# logger, then re-raised so start-up aborts.
try:
    print("Loading model...")
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
    pipe = pipe.to(DEVICE)
    print("Model loaded successfully")
    logger.info("Model loaded successfully")
except Exception as exc:
    failure_message = f"Failed to load model: {exc}"
    print(failure_message)
    logger.error(failure_message)
    raise
def generate_multiple_images_batch(
    improvement_axes,
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    num_inference_steps=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Render up to four enhanced prompts in one batched pipeline call.

    Args:
        improvement_axes: list of dicts; every entry with a truthy
            ``"enhanced_prompt"`` contributes one prompt. ``None`` or any
            non-list value is treated as "no prompts".
        seed: seed used when ``randomize_seed`` is false.
        randomize_seed: when true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: output width passed to the pipeline.
        height: output height passed to the pipeline.
        num_inference_steps: number of denoising steps.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        A list of five items: four image slots (entries of
        ``pipe(...).images``, padded with ``None``) followed by the seed
        that was used. With no prompts the slots are all ``None`` and the
        incoming ``seed`` is returned unchanged.

    Raises:
        Exception: anything raised during generation is logged and
            re-raised.
    """
    num_slots = 4
    try:
        # The UI's JSON component yields None before any session exists, so
        # anything that is not a list counts as "nothing to render".
        axes = improvement_axes if isinstance(improvement_axes, list) else []

        # Only four images can be shown; do not spend compute on prompts
        # whose results would be thrown away.
        prompts = [
            axis["enhanced_prompt"]
            for axis in axes
            if isinstance(axis, dict) and axis.get("enhanced_prompt")
        ][:num_slots]
        if not prompts:
            return [None] * num_slots + [seed]

        # Slider values may arrive as floats; the RNG and the pipeline need
        # integers.
        if randomize_seed:
            current_seed = random.randint(0, MAX_SEED)
        else:
            current_seed = int(seed)

        print(f"Generating images with {len(prompts)} prompts")
        print(f"Using seed: {current_seed}")

        # Generate all images in a single batch
        generator = torch.Generator().manual_seed(current_seed)
        images = list(
            pipe(
                prompt=prompts,  # Pass list of prompts directly
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                guidance_scale=0.0
            ).images
        )

        # Pad with None if we have fewer than four images
        images.extend([None] * (num_slots - len(images)))

        print("All images generated successfully")
        return images[:num_slots] + [current_seed]
    except Exception as e:
        print(f"Image generation error: {str(e)}")
        logger.error(f"Image generation error: {str(e)}")
        raise
def handle_image_select(evt: gr.SelectData, improvement_axes_data):
    """Return the enhanced prompt belonging to the image that was clicked.

    For image components ``evt.index`` describes where inside the picture
    the click landed, not which picture was clicked, so the slot number is
    taken from the triggering component's ``elem_id`` (``image_<n>``). A
    plain integer ``evt.index`` is used as a fallback when no such id is
    available.

    Args:
        evt: selection event supplied by Gradio.
        improvement_axes_data: list of axis dicts, each optionally holding
            an ``"enhanced_prompt"`` entry.

    Returns:
        The matching prompt string, or ``""`` when the slot cannot be
        determined, is out of range, or an error occurs.
    """
    try:
        if not improvement_axes_data or not isinstance(improvement_axes_data, list):
            return ""

        # Apply the same filter used when the images are generated, so that
        # slot k always maps to the prompt that produced image k.
        prompts = [
            axis["enhanced_prompt"]
            for axis in improvement_axes_data
            if isinstance(axis, dict) and axis.get("enhanced_prompt")
        ]

        selected_index = None
        elem_id = getattr(getattr(evt, "target", None), "elem_id", None)
        if isinstance(elem_id, str):
            prefix, _, suffix = elem_id.rpartition("_")
            if prefix == "image" and suffix.isdigit():
                selected_index = int(suffix)

        if selected_index is None:
            raw_index = getattr(evt, "index", None)
            # bool is an int subclass; never treat True/False as a slot.
            if isinstance(raw_index, int) and not isinstance(raw_index, bool):
                selected_index = raw_index

        if selected_index is not None and 0 <= selected_index < len(prompts):
            return prompts[selected_index]
        return ""
    except Exception as e:
        print(f"Error in handle_image_select: {str(e)}")
        return ""
def create_interface():
    """Build the Gradio Blocks UI for prompt enhancement and image generation.

    Reads ``GROQ_API_KEY`` (required) and ``API_BASE_URL`` from the
    environment, creates a ``PromptEnhancementSystem`` from them, lays out
    the widgets and wires the button/image events to their handlers.

    Returns:
        The assembled ``gr.Blocks`` object (not launched).

    Raises:
        ValueError: if ``GROQ_API_KEY`` is not set.
    """
    print("Creating interface...")
    api_key = os.getenv("GROQ_API_KEY")
    base_url = os.getenv("API_BASE_URL")
    if not api_key:
        print("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")
    system = PromptEnhancementSystem(api_key, base_url)
    print("PromptEnhancementSystem initialized")

    # Keys of the initial analysis, in the order of the six JSON panels.
    analysis_keys = (
        "subject_analysis",
        "style_evaluation",
        "technical_assessment",
        "composition_review",
        "context_evaluation",
        "mood_assessment",
    )
    num_options = 4  # option buttons shown under the prompt boxes
    num_images = 4   # image slots in the "Generated Images" tab

    def update_interface(prompt):
        """Start a session for ``prompt`` and return values for every output.

        Output order: both prompt boxes, six analysis panels, improvement
        axes, technical recommendations, four cleared image slots, the full
        state, and four option-button updates.
        """
        try:
            print(f"\n=== Processing prompt: {prompt}")
            state = system.start_session(prompt)
            axes = state.get("improvement_axes") or []
            initial_analysis = state.get("initial_analysis", {})
            enhanced_prompt = axes[0].get("enhanced_prompt", prompt) if axes else ""

            # Show one button per available axis, hide the rest.
            button_updates = []
            for i in range(num_options):
                if i < len(axes):
                    focus_area = axes[i].get("focus_area", f"Option {i+1}")
                    button_updates.append(gr.update(visible=True, value=focus_area))
                else:
                    button_updates.append(gr.update(visible=False))

            return (
                [prompt, enhanced_prompt]
                + [initial_analysis.get(key, {}) for key in analysis_keys]
                + [axes, state.get("technical_recommendations", {})]
                + [None] * num_images  # clear the four image outputs
                + [state]
                + button_updates
            )
        except Exception as e:
            print(f"Error in update_interface: {str(e)}")
            logger.error(f"Error in update_interface: {str(e)}")
            empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
            # The axes output is an empty list (as in handle_error) so that
            # downstream consumers always see the same container type.
            return (
                [prompt, prompt]
                + [empty_analysis] * len(analysis_keys)
                + [[], {}]
                + [None] * num_images
                + [{}]
                + [gr.update(visible=False)] * num_options
            )

    def handle_option_click(option_num, input_prompt, current_text):
        """Put the enhanced prompt of axis ``option_num`` into the prompt box.

        Returns the eleven values expected by the option-button outputs, or
        the ``handle_error`` placeholder values when the option is unknown
        or anything fails.
        """
        try:
            print(f"\n=== Processing option {option_num}")
            state = system.current_state
            if state and "improvement_axes" in state:
                axes = state["improvement_axes"]
                option_index = int(option_num)
                if 0 <= option_index < len(axes):
                    selected_prompt = axes[option_index]["enhanced_prompt"]
                    initial_analysis = state.get("initial_analysis", {})
                    return (
                        [input_prompt, selected_prompt]
                        + [initial_analysis.get(key, {}) for key in analysis_keys]
                        + [axes, state.get("technical_recommendations", {}), state]
                    )
            return handle_error()
        except Exception as e:
            print(f"Error in handle_option_click: {str(e)}")
            logger.error(f"Error in handle_option_click: {str(e)}")
            return handle_error()

    def handle_error():
        """Return placeholder values for the eleven option-button outputs."""
        empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
        return ["", ""] + [empty_analysis] * len(analysis_keys) + [[], {}, {}]

    with gr.Blocks(
        title="AI Prompt Enhancement System",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}"
    ) as interface:
        gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")

        with gr.Row():
            input_prompt = gr.Textbox(
                label="Initial Prompt",
                placeholder="Enter your prompt here...",
                lines=3,
                scale=1
            )
            current_prompt = gr.Textbox(
                label="Current Prompt",
                lines=3,
                scale=1,
                interactive=True
            )

        with gr.Row():
            start_btn = gr.Button("Start Enhancement", variant="primary")

        with gr.Row():
            option_buttons = [gr.Button("", visible=False) for _ in range(num_options)]

        with gr.Tabs():
            with gr.TabItem("Initial Analysis"):
                with gr.Row():
                    with gr.Column():
                        subject_analysis = gr.JSON(label="Subject Analysis")
                    with gr.Column():
                        style_evaluation = gr.JSON(label="Style Evaluation")
                    with gr.Column():
                        technical_assessment = gr.JSON(label="Technical Assessment")
                with gr.Row():
                    with gr.Column():
                        composition_review = gr.JSON(label="Composition Review")
                    with gr.Column():
                        context_evaluation = gr.JSON(label="Context Evaluation")
                    with gr.Column():
                        mood_assessment = gr.JSON(label="Mood Assessment")

            with gr.TabItem("Generated Images"):
                with gr.Row():
                    generated_images = [
                        gr.Image(
                            label=f"Image {i+1}",
                            type="pil",
                            show_label=True,
                            height=256,
                            width=256,
                            interactive=True,
                            elem_id=f"image_{i}"
                        ) for i in range(num_images)
                    ]
                with gr.Row():
                    finalize_btn = gr.Button("Generate All Images", variant="primary")
                with gr.Accordion("Image Generation Settings", open=False):
                    with gr.Row():
                        # The slider also receives the seed that was actually
                        # used, which can be any value up to MAX_SEED when
                        # randomisation is on.
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=256,
                            value=512
                        )
                        num_inference_steps = gr.Slider(
                            label="Steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=4
                        )

        with gr.Accordion("Additional Information", open=False):
            improvement_axes = gr.JSON(label="Improvement Axes")
            technical_recommendations = gr.JSON(label="Technical Recommendations")
            full_llm_response = gr.JSON(label="Full LLM Response")

        # Outputs shared by the start button and the option buttons.
        analysis_outputs = [
            subject_analysis,
            style_evaluation,
            technical_assessment,
            composition_review,
            context_evaluation,
            mood_assessment,
        ]

        # Add select events for each image
        for img in generated_images:
            img.select(
                fn=handle_image_select,
                inputs=[improvement_axes],
                outputs=[input_prompt]
            )

        start_btn.click(
            update_interface,
            inputs=[input_prompt],
            outputs=(
                [input_prompt, current_prompt]
                + analysis_outputs
                + [improvement_axes, technical_recommendations]
                + generated_images
                + [full_llm_response]
                + option_buttons
            )
        )

        for i, btn in enumerate(option_buttons):
            btn.click(
                handle_option_click,
                inputs=[
                    # Constant per-button index, carried without adding a
                    # hidden widget to the page.
                    gr.State(i),
                    input_prompt,
                    current_prompt
                ],
                outputs=(
                    [input_prompt, current_prompt]
                    + analysis_outputs
                    + [improvement_axes, technical_recommendations, full_llm_response]
                )
            )

        finalize_btn.click(
            generate_multiple_images_batch,
            inputs=[
                improvement_axes,
                seed,
                randomize_seed,
                width,
                height,
                num_inference_steps
            ],
            outputs=generated_images + [seed]
        )

    print("Interface setup complete")
    return interface