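# Gradio app: LLM-assisted prompt enhancement plus FLUX.1-schnell image generation.
# Prompt analysis/enhancement is delegated to PromptEnhancementSystem (api.py).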
import spaces
import os
import gradio as gr
import random
import torch
import logging
import numpy as np
from typing import Dict, Any, List
from diffusers import DiffusionPipeline
from api import PromptEnhancementSystem

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print(f"Using device: {DEVICE}")

# Basic logging configuration so the logger.info/logger.error calls below are emitted
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize model
try:
    print("Loading model...")
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE
    ).to(DEVICE)
    print("Model loaded successfully")
    logger.info("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    logger.error(f"Failed to load model: {str(e)}")
    raise

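# `spaces` is imported above but never used. On a ZeroGPU Space, functions that run
# on the GPU are expected to carry this decorator; adding it here is an assumption
# about the intended deployment (the decorator is a no-op outside ZeroGPU).
@spaces.GPU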
def generate_multiple_images_batch(
    improvement_axes,
    current_gallery,
    seed=42,
    randomize_seed=False,
    width=512,
    height=512,
    num_inference_steps=4,
    current_prompt="",
    initial_prompt="",
    progress=gr.Progress(track_tqdm=True)
):
    try:
        # Use current_prompt if not empty, otherwise fall back to initial_prompt
        input_prompt = current_prompt if (current_prompt or "").strip() else initial_prompt

        # Extract prompts from the improvement axes (the JSON component may be None
        # before the first enhancement run), or fall back to the input prompt
        prompts = [
            axis["enhanced_prompt"]
            for axis in (improvement_axes or [])
            if isinstance(axis, dict) and axis.get("enhanced_prompt")
        ]
        if not prompts and input_prompt:
            prompts = [input_prompt]
        if not prompts:
            return [None] * 4 + [current_gallery] + [seed]

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
        print(f"Generating images with prompt: {input_prompt}")
        print(f"Using seed: {current_seed}")

        # Generate one image per prompt with a reproducible generator
        generator = torch.Generator().manual_seed(current_seed)
        images = pipe(
            prompt=prompts,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            generator=generator,
            max_sequence_length=256,  # Maximum allowed for schnell
            guidance_scale=0.0
        ).images

        # Pad with None if we have fewer than 4 images
        while len(images) < 4:
            images.append(None)

        # Update gallery with the new images and their prompts
        current_gallery = current_gallery or []
        new_gallery = current_gallery + [
            (img, f"Prompt: {prompt}") for img, prompt in zip(images, prompts) if img is not None
        ]

        print("All images generated successfully")
        return images[:4] + [new_gallery] + [current_seed]
    except Exception as e:
        print(f"Image generation error: {str(e)}")
        logger.error(f"Image generation error: {str(e)}")
        raise

def handle_image_select(improvement_axes_data, selected_index):
    # Return the enhanced prompt for the clicked image slot. The slot index is passed
    # explicitly because SelectData.index on a gr.Image click reports pixel
    # coordinates, not which of the four images was clicked.
    try:
        if improvement_axes_data and isinstance(improvement_axes_data, list):
            if selected_index < len(improvement_axes_data):
                return improvement_axes_data[selected_index].get("enhanced_prompt", "")
        return ""
    except Exception as e:
        print(f"Error in handle_image_select: {str(e)}")
        return ""

def handle_gallery_select(evt: gr.SelectData, gallery_data):
    try:
        if gallery_data and isinstance(evt.index, int) and evt.index < len(gallery_data):
            image, prompt = gallery_data[evt.index]
            # Remove "Prompt: " prefix if it exists
            prompt = prompt.replace("Prompt: ", "") if prompt else ""
            return {"prompt": prompt}, prompt
        return None, ""
    except Exception as e:
        print(f"Error in handle_gallery_select: {str(e)}")
        return None, ""

def clear_gallery():
    # Returns an empty gallery and clears the 4 image slots
    return [], None, None, None, None

def zip_gallery_images(gallery):
    try:
        if not gallery:
            return None

        import io
        import zipfile
        from datetime import datetime
        from PIL import Image

        # Create zip file in memory
        zip_buffer = io.BytesIO()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"gallery_images_{timestamp}.zip"

        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for i, (img_data, prompt) in enumerate(gallery):
                try:
                    if img_data is None:
                        continue
                    # Convert numpy array to PIL Image if needed
                    if isinstance(img_data, np.ndarray):
                        img = Image.fromarray(np.uint8(img_data))
                    elif isinstance(img_data, Image.Image):
                        img = img_data
                    else:
                        print(f"Skipping image {i}: invalid type {type(img_data)}")
                        continue

                    # Save image to bytes
                    img_buffer = io.BytesIO()
                    img.save(img_buffer, format='PNG')
                    img_buffer.seek(0)

                    # Create a filesystem-safe filename from the prompt
                    safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
                    img_filename = f"image_{i+1}_{safe_prompt}.png"

                    # Add to zip
                    zip_file.writestr(img_filename, img_buffer.getvalue())
                except Exception as img_error:
                    print(f"Error processing image {i}: {str(img_error)}")
                    continue

        # Prepare zip for download
        zip_buffer.seek(0)
        # Return the file data and name
        return {
            "name": filename,
            "data": zip_buffer.getvalue()
        }
    except Exception as e:
        print(f"Error creating zip: {str(e)}")
        return None

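# Note: zip_gallery_images is not wired to any UI element below, and Gradio file
# outputs (gr.File / gr.DownloadButton) expect a filesystem path rather than raw
# bytes. A minimal sketch of a path-returning wrapper; the temp-file location and
# the idea of exposing it via a download button are assumptions, not part of the
# original app:
def zip_gallery_to_path(gallery):
    import tempfile
    result = zip_gallery_images(gallery)
    if result is None:
        return None
    # Write the in-memory zip to a temporary file so Gradio can serve it for download
    path = os.path.join(tempfile.gettempdir(), result["name"])
    with open(path, "wb") as f:
        f.write(result["data"])
    return path
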
def create_interface():
    print("Creating interface...")
    api_key = os.getenv("GROQ_API_KEY")
    base_url = os.getenv("API_BASE_URL")
    if not api_key:
        print("GROQ_API_KEY not found in environment variables")
        raise ValueError("GROQ_API_KEY not found in environment variables")
    system = PromptEnhancementSystem(api_key, base_url)
    print("PromptEnhancementSystem initialized")

    def update_interface(prompt, user_directive):
        try:
            print(f"\n=== Processing prompt: {prompt}")
            print(f"User directive: {user_directive}")

            state = system.start_session(prompt, user_directive)
            improvement_axes = state.get("improvement_axes", [])
            initial_analysis = state.get("initial_analysis", {})

            enhanced_prompt = ""
            if improvement_axes:
                enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)

            # Show one option button per improvement axis, hide the rest
            button_updates = []
            for i in range(4):
                if i < len(improvement_axes):
                    focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
                    button_updates.append(gr.update(visible=True, value=focus_area))
                else:
                    button_updates.append(gr.update(visible=False))

            return [prompt, enhanced_prompt] + [
                initial_analysis.get(key, {}) for key in [
                    "subject_analysis",
                    "style_evaluation",
                    "technical_assessment",
                    "composition_review",
                    "context_evaluation",
                    "mood_assessment"
                ]
            ] + [
                improvement_axes,
                state.get("technical_recommendations", {}),
                state
            ] + button_updates
        except Exception as e:
            print(f"Error in update_interface: {str(e)}")
            logger.error(f"Error in update_interface: {str(e)}")
            empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
            return [prompt, prompt] + [empty_analysis] * 6 + [[], {}, {}] + [gr.update(visible=False)] * 4

    def handle_option_click(option_num, input_prompt, current_text, user_directive):
        try:
            print(f"\n=== Processing option {option_num}")
            state = system.current_state
            if state and "improvement_axes" in state:
                improvement_axes = state["improvement_axes"]
                if option_num < len(improvement_axes):
                    selected_prompt = improvement_axes[option_num]["enhanced_prompt"]
                    return [
                        input_prompt,
                        selected_prompt,
                        state.get("initial_analysis", {}).get("subject_analysis", {}),
                        state.get("initial_analysis", {}).get("style_evaluation", {}),
                        state.get("initial_analysis", {}).get("technical_assessment", {}),
                        state.get("initial_analysis", {}).get("composition_review", {}),
                        state.get("initial_analysis", {}).get("context_evaluation", {}),
                        state.get("initial_analysis", {}).get("mood_assessment", {}),
                        improvement_axes,
                        state.get("technical_recommendations", {}),
                        state
                    ]
            return handle_error()
        except Exception as e:
            print(f"Error in handle_option_click: {str(e)}")
            logger.error(f"Error in handle_option_click: {str(e)}")
            return handle_error()

    def handle_error():
        empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
        return ["", ""] + [empty_analysis] * 6 + [[], {}, {}]

    with gr.Blocks(
        title="AI Prompt Enhancement System",
        theme=gr.themes.Soft(),
        css="footer {visibility: hidden}"
    ) as interface:
        gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")

with gr.TabItem("Images Generation"): | |
            with gr.Row():
                input_prompt = gr.Textbox(
                    label="Initial Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    scale=1
                )
            with gr.Row():
                user_directive = gr.Textbox(
                    label="User Directive",
                    placeholder="Enter specific requirements...",
                    lines=2,
                    scale=1
                )
            with gr.Row():
                start_btn = gr.Button("Start Enhancement", variant="primary")
            with gr.Row():
                current_prompt = gr.Textbox(
                    label="Current Prompt",
                    lines=3,
                    scale=1,
                    interactive=True
                )
            with gr.Row():
                option_buttons = [gr.Button("", visible=False) for _ in range(4)]
            with gr.Row():
                finalize_btn = gr.Button("Generate Images", variant="primary")
            with gr.Row():
                generated_images = [
                    gr.Image(
                        label=f"Image {i+1}",
                        type="pil",
                        show_label=False,
                        height=256,
                        width=256,
                        interactive=False,
                        show_download_button=False,
                        elem_id=f"image_{i}"
                    ) for i in range(4)
                ]

with gr.TabItem("Images Gallery"): | |
            with gr.Row():
                image_gallery = gr.Gallery(
                    label="Generated Images History",
                    show_label=False,
                    columns=4,
                    rows=None,
                    height=800,
                    object_fit="contain"
                )
            with gr.Row():
                clear_gallery_btn = gr.Button("Clear Gallery", variant="secondary")
            with gr.Row():
                selected_image_data = gr.JSON(label="Selected Image Data", visible=True)
                copy_to_prompt_btn = gr.Button("Copy Prompt to Current", visible=True)

        with gr.TabItem("Image Generation Settings"):
            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=256,
                    value=512
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=256,
                    value=512
                )
                num_inference_steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4
                )

        with gr.TabItem("Initial Analysis"):
            with gr.Row():
                with gr.Column():
                    subject_analysis = gr.JSON(label="Subject Analysis")
                with gr.Column():
                    style_evaluation = gr.JSON(label="Style Evaluation")
                with gr.Column():
                    technical_assessment = gr.JSON(label="Technical Assessment")
            with gr.Row():
                with gr.Column():
                    composition_review = gr.JSON(label="Composition Review")
                with gr.Column():
                    context_evaluation = gr.JSON(label="Context Evaluation")
                with gr.Column():
                    mood_assessment = gr.JSON(label="Mood Assessment")
            with gr.Accordion("Additional Information", open=False):
                improvement_axes = gr.JSON(label="Improvement Axes")
                technical_recommendations = gr.JSON(label="Technical Recommendations")
                full_llm_response = gr.JSON(label="Full LLM Response")

        # Event handlers
        # Clicking one of the four generated images copies its enhanced prompt into
        # the current prompt box; each handler binds its own image index because a
        # gr.Image click only reports pixel coordinates.
        for i, img in enumerate(generated_images):
            img.select(
                fn=lambda axes, idx=i: handle_image_select(axes, idx),
                inputs=[improvement_axes],
                outputs=[current_prompt],
                show_progress=False
            )

        start_btn.click(
            update_interface,
            inputs=[input_prompt, user_directive],
            outputs=[
                input_prompt,
                current_prompt,
                subject_analysis,
                style_evaluation,
                technical_assessment,
                composition_review,
                context_evaluation,
                mood_assessment,
                improvement_axes,
                technical_recommendations,
                full_llm_response
            ] + option_buttons
        )

        for i, btn in enumerate(option_buttons):
            btn.click(
                handle_option_click,
                inputs=[
                    gr.State(value=i),  # hidden constant carrying this button's index
                    input_prompt,
                    current_prompt,
                    user_directive
                ],
                outputs=[
                    input_prompt,
                    current_prompt,
                    subject_analysis,
                    style_evaluation,
                    technical_assessment,
                    composition_review,
                    context_evaluation,
                    mood_assessment,
                    improvement_axes,
                    technical_recommendations,
                    full_llm_response
                ]
            )

        finalize_btn.click(
            generate_multiple_images_batch,
            inputs=[
                improvement_axes,
                image_gallery,
                seed,
                randomize_seed,
                width,
                height,
                num_inference_steps,
                current_prompt,
                input_prompt
            ],
            outputs=generated_images + [image_gallery] + [seed]
        )

        clear_gallery_btn.click(
            clear_gallery,
            inputs=[],
            outputs=[image_gallery] + generated_images
        )

        # Gallery selection: show the selected entry's data and copy its prompt
        image_gallery.select(
            fn=handle_gallery_select,
            inputs=[image_gallery],
            outputs=[selected_image_data, current_prompt]
        )

        # Copy button: guard against an empty or missing selection
        copy_to_prompt_btn.click(
            lambda x: x["prompt"] if x and isinstance(x, dict) and "prompt" in x else "",
            inputs=[selected_image_data],
            outputs=[current_prompt]
        )

print("Interface setup complete") | |
return interface | |
if __name__ == "__main__": | |
interface = create_interface() | |
interface.launch() |