# Flux-Trainer / app.py
# Daniel Jarvis
# app.py rename
# 83ab419
import gradio as gr
import time
import random
from PIL import Image, ImageDraw, ImageFont
import io
import base64
class ImageGeneratorBot:
    """Mock Midjourney-style backend that renders placeholder images per prompt.

    Each generated batch of four images is stored in ``current_images`` under
    a job id so the upscale (U) and variation (V) actions can refer back to it.
    """

    # Pastel background colours for mock images.
    _BACKGROUNDS = [(255, 200, 200), (200, 255, 200), (200, 200, 255),
                    (255, 255, 200), (255, 200, 255), (200, 255, 255)]

    def __init__(self, max_jobs=100):
        """Create a bot with no stored jobs.

        Args:
            max_jobs: Maximum number of image batches kept in
                ``current_images``; the oldest batches are evicted first.
                Every batch holds four full images, so an unbounded store
                keeps growing for as long as the app runs. ``None`` disables
                eviction.

        Raises:
            ValueError: If *max_jobs* is not ``None`` and is less than 1.
        """
        if max_jobs is not None and max_jobs < 1:
            raise ValueError("max_jobs must be at least 1 or None")
        self.conversation_history = []  # not used by this class
        # job_id -> {'prompt': str, 'images': [PIL images], 'original_prompt': str}
        self.current_images = {}
        self.job_counter = 0
        self.max_jobs = max_jobs

    @staticmethod
    def _wrap_words(text, max_chars=30):
        """Greedily pack the words of *text* into lines of at most *max_chars*.

        A single word longer than *max_chars* gets a line of its own (it is
        never split, and no empty line is emitted before it).
        """
        lines = []
        current = []
        for word in text.split():
            if current and len(" ".join(current + [word])) > max_chars:
                lines.append(" ".join(current))
                current = [word]
            else:
                current.append(word)
        if current:
            lines.append(" ".join(current))
        return lines

    @staticmethod
    def _encode_png(img):
        """Return *img* PNG-encoded as base64 text for a ``data:`` URI."""
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def _store_job(self, job_id, prompt, images, original_prompt):
        """Record a batch under *job_id*, evicting the oldest beyond ``max_jobs``."""
        self.current_images[job_id] = {
            'prompt': prompt,
            'images': images,
            'original_prompt': original_prompt,
        }
        if self.max_jobs is not None:
            # dicts keep insertion order, so the first key is the oldest job.
            while len(self.current_images) > self.max_jobs:
                del self.current_images[next(iter(self.current_images))]

    def generate_mock_image(self, prompt, width=512, height=512, seed=None):
        """Render a placeholder image: random rectangles plus the wrapped prompt.

        Args:
            prompt: Text drawn, word-wrapped and centred, over the shapes.
            width: Image width in pixels.
            height: Image height in pixels.
            seed: Optional seed; equal seeds give identical images. ``None``
                uses fresh entropy.

        Returns:
            A new RGB PIL image.
        """
        # A private generator: seeding it never disturbs the process-wide
        # ``random`` state, and a seed of 0 is honoured like any other value.
        rng = random.Random(seed)

        img = Image.new('RGB', (width, height), rng.choice(self._BACKGROUNDS))
        draw = ImageDraw.Draw(img)

        # Five rectangles, each from a corner in the top-left quadrant to one
        # in the bottom-right quadrant.
        for _ in range(5):
            x1, y1 = rng.randint(0, width // 2), rng.randint(0, height // 2)
            x2, y2 = rng.randint(width // 2, width), rng.randint(height // 2, height)
            shape_color = tuple(rng.randint(100, 255) for _ in range(3))
            draw.rectangle([x1, y1, x2, y2], fill=shape_color, outline=(0, 0, 0))

        try:
            font = ImageFont.load_default()
        except OSError:
            font = None  # font=None makes ImageDraw use its own default font

        lines = self._wrap_words(prompt)
        line_height = 20
        # Start half the text block above the middle so it is vertically centred.
        y_offset = height // 2 - len(lines) * line_height // 2
        for line in lines:
            left, _, right, _ = draw.textbbox((0, 0), line, font=font)
            x_offset = (width - (right - left)) // 2
            draw.text((x_offset, y_offset), line, fill=(0, 0, 0), font=font)
            y_offset += line_height
        return img

    def process_prompt(self, prompt, chat_history):
        """Generate four mock images for *prompt* and post them to the chat.

        Args:
            prompt: Text typed by the user.
            chat_history: Chatbot history (list of ``[user, bot]`` pairs);
                appended to in place and returned.

        Returns:
            ``(chat_history, "", update)``: the history, an empty string that
            clears the textbox, and a ``gr.update`` for the action buttons.
        """
        if chat_history is None:
            chat_history = []
        if not prompt or not prompt.strip():
            # Nothing generated: leave the action buttons as they are so an
            # earlier job stays actionable.
            return chat_history, "", gr.update()

        self.job_counter += 1
        job_id = f"job_{self.job_counter}"
        chat_history.append([prompt, None])

        images = [
            self.generate_mock_image(prompt, seed=hash(prompt + str(i)))
            for i in range(4)
        ]
        self._store_job(job_id, prompt, images, prompt)

        chat_history.append([None, self.create_image_response(job_id, prompt, images)])
        return chat_history, "", gr.update(visible=True)

    def create_image_response(self, job_id, prompt, images):
        """Build the chat HTML card showing a batch of images in a two-column grid.

        Args:
            job_id: Identifier displayed on the card.
            prompt: Prompt shown on the card; HTML-escaped since it is user input.
            images: PIL images, embedded inline as base64 PNG data URIs.

        Returns:
            The HTML string.
        """
        from html import escape

        # Lines are unindented and contain no blank lines: if the chat widget
        # passes messages through Markdown, a blank line would end the raw-HTML
        # block and a 4-space indent would start a code block.
        parts = [
            '<div style="border: 1px solid #ddd; padding: 15px; border-radius: 10px; background: #f9f9f9;">',
            '<h4 style="margin-top: 0; color: #333;">✅ Generated Images</h4>',
            f'<p style="margin: 5px 0; color: #666;"><strong>Prompt:</strong> {escape(prompt)}</p>',
            f'<p style="margin: 5px 0 15px 0; color: #666;"><strong>Job ID:</strong> {escape(job_id)}</p>',
            '<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px; max-width: 400px;">',
        ]
        for number, img in enumerate(images, start=1):
            parts.append(
                '<div style="text-align: center;">'
                f'<img src="data:image/png;base64,{self._encode_png(img)}" '
                'style="width: 100%; border-radius: 5px; border: 2px solid #ddd;">'
                f'<div style="margin-top: 5px; font-size: 12px; color: #666;">Image {number}</div>'
                '</div>'
            )
        parts += [
            '</div>',
            '<p style="margin: 15px 0 5px 0; color: #666; font-size: 14px;">'
            'Use the buttons below to upscale (U) or create variations (V) of specific images.'
            '</p>',
            '</div>',
        ]
        return "\n".join(parts)

    def upscale_image(self, job_id, image_index, chat_history):
        """Post a 2x-enlarged copy of image *image_index* from job *job_id*.

        Returns *chat_history* unchanged when the job is unknown (never
        created, or evicted) or the index is out of range.
        """
        from html import escape

        job = self.current_images.get(job_id)
        if job is None or not 0 <= image_index < len(job['images']):
            return chat_history
        original_img = job['images'][image_index]
        prompt = job['prompt']

        # Enlarge the image the user actually picked rather than rendering an
        # unrelated new one.
        width, height = original_img.size
        upscaled_img = original_img.resize((width * 2, height * 2))

        response_html = "\n".join([
            '<div style="border: 1px solid #ddd; padding: 15px; border-radius: 10px; background: #f0f8ff;">',
            f'<h4 style="margin-top: 0; color: #333;">🔍 Upscaled Image {image_index + 1}</h4>',
            f'<p style="margin: 5px 0; color: #666;"><strong>Original Prompt:</strong> {escape(prompt)}</p>',
            '<div style="text-align: center; margin: 15px 0;">'
            f'<img src="data:image/png;base64,{self._encode_png(upscaled_img)}" '
            'style="max-width: 100%; border-radius: 5px; border: 2px solid #4CAF50;">'
            '</div>',
            '<p style="margin: 10px 0 5px 0; color: #666; font-size: 14px;">'
            '✨ Image upscaled to higher resolution!'
            '</p>',
            '</div>',
        ])
        chat_history.append([None, response_html])
        return chat_history

    def create_variation(self, job_id, image_index, chat_history):
        """Generate four variations of image *image_index* as a new job.

        Returns *chat_history* unchanged when the job is unknown or the index
        is out of range; otherwise appends the new batch card and returns it.
        """
        job = self.current_images.get(job_id)
        if job is None or not 0 <= image_index < len(job['images']):
            return chat_history
        prompt = job['prompt']

        variation_images = [
            self.generate_mock_image(
                f"{prompt} (Variation)",
                seed=hash(prompt + str(image_index) + "variation" + str(i)),
            )
            for i in range(4)
        ]

        # Every batch takes a fresh job number; reusing the counter let a
        # second batch overwrite the first under the same id.
        self.job_counter += 1
        new_job_id = f"job_{self.job_counter}_var_{image_index + 1}"
        label = f"{prompt} (Variations of Image {image_index + 1})"
        self._store_job(
            new_job_id, label, variation_images, job.get('original_prompt', prompt)
        )

        chat_history.append(
            [None, self.create_image_response(new_job_id, label, variation_images)]
        )
        return chat_history
# Initialize the bot.
# NOTE(review): this single instance is shared by every browser session, so
# job ids and stored images are global; only ``current_job_id`` (a gr.State)
# is per-session. If two sessions generate at the same moment, the
# "newest stored job" lookups below may pick up the other session's job.
bot = ImageGeneratorBot()

# Create the Gradio interface
with gr.Blocks(title="AI Image Generator Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 AI Image Generator Chat
### Midjourney-style interface for image generation
Enter your prompt below and get 4 generated images. Use the U buttons to upscale specific images,
or V buttons to create variations of specific images.
""")
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False,
            )
            with gr.Row():
                prompt_input = gr.Textbox(
                    placeholder="Enter your image prompt (e.g., 'a cat wearing a space suit on Mars --ar 16:9')",
                    show_label=False,
                    scale=4,
                )
                generate_btn = gr.Button("Generate", variant="primary")

            # Starts hidden; the generate handlers' third output toggles it.
            action_buttons_container = gr.Column(visible=False)
            with action_buttons_container:
                gr.Markdown("### 🔧 Image Actions")
                gr.Markdown("**Upscale (U):** Get higher resolution version")
                upscale_row = gr.Row()
                with upscale_row:
                    upscale_buttons = [gr.Button(f"U{n}", size="sm") for n in range(1, 5)]
                gr.Markdown("**Variations (V):** Generate variations of specific image")
                variation_row = gr.Row()
                with variation_row:
                    variation_buttons = [
                        gr.Button(f"V{n}", size="sm", variant="secondary")
                        for n in range(1, 5)
                    ]
            # Keep the individual button names available at module level.
            u1_btn, u2_btn, u3_btn, u4_btn = upscale_buttons
            v1_btn, v2_btn, v3_btn, v4_btn = variation_buttons

        with gr.Column(scale=1):
            gr.Markdown("""
### 💡 Tips
- Be descriptive in your prompts
- Use aspect ratios like `--ar 16:9`
- Try artistic styles: "in the style of..."
- Add quality modifiers: "--quality high"
### 🎯 Example Prompts
- "a magical forest with glowing mushrooms"
- "cyberpunk city at night --ar 16:9"
- "portrait of a wise owl wearing glasses"
- "abstract art with vibrant colors"
""")

    # Per-session id of the job the U/V buttons act on.
    current_job_id = gr.State(None)

    # Event handlers
    def handle_generate(prompt, history, job_id=None):
        """Run a prompt; move the session to the new job only if one was posted."""
        history = history or []
        posted_before = len(history)
        new_history, cleared_prompt, buttons_update = bot.process_prompt(prompt, history)
        if len(new_history) > posted_before and bot.current_images:
            # A batch was posted; it is the most recently stored job.
            job_id = next(reversed(bot.current_images))
        return new_history, cleared_prompt, buttons_update, job_id

    def handle_upscale(job_id, image_idx, history):
        """Upscale image *image_idx* of the session's current job, if any."""
        if job_id:
            return bot.upscale_image(job_id, image_idx, history)
        return history

    def handle_variation(job_id, image_idx, history):
        """Create variations and, if a batch was posted, make it the current job."""
        if not job_id:
            return history, job_id
        history = history or []
        posted_before = len(history)
        new_history = bot.create_variation(job_id, image_idx, history)
        # Unknown/evicted jobs post nothing; keep the session's job unchanged
        # instead of jumping to whatever job happens to be newest.
        if len(new_history) > posted_before and bot.current_images:
            job_id = next(reversed(bot.current_images))
        return new_history, job_id

    def _upscale_handler(image_idx):
        """Return a (job_id, history) callback bound to *image_idx*."""
        return lambda job_id, history: handle_upscale(job_id, image_idx, history)

    def _variation_handler(image_idx):
        """Return a (job_id, history) callback bound to *image_idx*."""
        return lambda job_id, history: handle_variation(job_id, image_idx, history)

    # Connect events: the button and pressing Enter both generate.
    for trigger in (generate_btn.click, prompt_input.submit):
        trigger(
            handle_generate,
            inputs=[prompt_input, chatbot, current_job_id],
            outputs=[chatbot, prompt_input, action_buttons_container, current_job_id],
        )

    # U1..U4 / V1..V4 map to image indices 0..3.
    for idx, button in enumerate(upscale_buttons):
        button.click(_upscale_handler(idx), [current_job_id, chatbot], [chatbot])
    for idx, button in enumerate(variation_buttons):
        button.click(
            _variation_handler(idx),
            [current_job_id, chatbot],
            [chatbot, current_job_id],
        )

if __name__ == "__main__":
    demo.launch()