# Hugging Face Spaces page residue (Space status: Sleeping)
import base64
import html
import io
import random
import time
import zlib

import gradio as gr
from PIL import Image, ImageDraw, ImageFont
class ImageGeneratorBot:
    """Mock Midjourney-style image generator used by the chat UI.

    Every generation "job" is stored in ``current_images`` under a unique id
    (``job_<n>`` for prompts, ``job_<n>_var_<k>`` for variations) so that
    later upscale/variation requests can refer back to its images.
    """

    def __init__(self, max_jobs=50):
        """Create a bot with no stored jobs.

        Args:
            max_jobs: Maximum number of jobs kept in ``current_images``. Each
                job holds several in-memory images, so once the limit is
                exceeded the oldest jobs are dropped. Should be at least 1;
                ``None`` keeps every job (memory then grows without bound).
        """
        self.conversation_history = []
        # job_id -> {'prompt', 'images', 'original_prompt'}. Job ids are
        # never reused, so dict insertion order is creation order.
        self.current_images = {}
        self.job_counter = 0
        self.max_jobs = max_jobs

    @staticmethod
    def _stable_seed(*parts):
        """Return a deterministic integer seed derived from *parts*.

        Built-in ``hash()`` of a ``str`` is salted per interpreter process
        (``PYTHONHASHSEED``), so seeds built from it change on every restart.
        CRC-32 of the ``|``-joined parts is stable across runs; the separator
        keeps e.g. ``("a1", 0)`` and ``("a", 10)`` distinct.
        """
        key = "|".join(str(part) for part in parts)
        return zlib.crc32(key.encode("utf-8"))

    @staticmethod
    def _wrap_text(text, max_chars=30):
        """Greedily wrap *text* into lines of at most *max_chars* characters.

        Words are split on whitespace and never broken, so a single word longer
        than *max_chars* gets a line of its own (and no empty line is emitted
        before it). Returns ``[]`` for empty or whitespace-only text.
        """
        lines = []
        current = []
        for word in text.split():
            if current and len(" ".join(current + [word])) > max_chars:
                lines.append(" ".join(current))
                current = [word]
            else:
                current.append(word)
        if current:
            lines.append(" ".join(current))
        return lines

    @staticmethod
    def _image_to_base64(img):
        """Return *img* encoded as PNG, as base64 text for a ``data:`` URI."""
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("ascii")

    def _store_job(self, job_id, prompt, images, original_prompt):
        """Record a job's images, then evict the oldest jobs beyond ``max_jobs``."""
        self.current_images[job_id] = {
            'prompt': prompt,
            'images': images,
            'original_prompt': original_prompt,
        }
        if self.max_jobs is not None:
            while self.current_images and len(self.current_images) > self.max_jobs:
                # The first key in insertion order is the oldest job.
                del self.current_images[next(iter(self.current_images))]

    def generate_mock_image(self, prompt, width=512, height=512, seed=None):
        """Generate a placeholder image with the prompt text drawn on it.

        Args:
            prompt: Text rendered, word-wrapped, around the vertical middle.
            width: Image width in pixels.
            height: Image height in pixels.
            seed: Optional seed; equal seeds give identical images. A private
                ``random.Random`` is used, so the module-level RNG is untouched.

        Returns:
            A new RGB ``PIL.Image.Image``.
        """
        # random.Random(None) seeds itself from OS entropy; any other value
        # (including 0) is used as the seed.
        rng = random.Random(seed)

        # Random pastel background.
        colors = [(255, 200, 200), (200, 255, 200), (200, 200, 255),
                  (255, 255, 200), (255, 200, 255), (200, 255, 255)]
        bg_color = rng.choice(colors)
        img = Image.new('RGB', (width, height), bg_color)
        draw = ImageDraw.Draw(img)

        # Five rectangles, each running from the top-left quadrant to the
        # bottom-right quadrant.
        for _ in range(5):
            x1, y1 = rng.randint(0, width // 2), rng.randint(0, height // 2)
            x2, y2 = rng.randint(width // 2, width), rng.randint(height // 2, height)
            shape_color = tuple(rng.randint(100, 255) for _ in range(3))
            draw.rectangle([x1, y1, x2, y2], fill=shape_color, outline=(0, 0, 0))

        try:
            font = ImageFont.load_default()
        except OSError:
            # font=None lets ImageDraw fall back to its own default font.
            font = None

        line_height = 20
        lines = self._wrap_text(prompt)
        # Centre the block of text lines around the vertical midpoint.
        y_offset = height // 2 - len(lines) * line_height // 2
        for line in lines:
            left, _, right, _ = draw.textbbox((0, 0), line, font=font)
            x_offset = (width - (right - left)) // 2
            draw.text((x_offset, y_offset), line, fill=(0, 0, 0), font=font)
            y_offset += line_height
        return img

    def process_prompt(self, prompt, chat_history):
        """Generate four mock images for *prompt* and append them to the chat.

        Args:
            prompt: User prompt from the text box.
            chat_history: Chatbot history as a list of ``[user, bot]`` pairs.
                It is mutated in place; ``None`` is treated as empty.

        Returns:
            ``(chat_history, "", update)``: the updated history, an empty
            string that clears the text box, and a ``gr.update`` for the
            action-button container. The container is made visible after a
            generation and left unchanged for an empty prompt.
        """
        if chat_history is None:
            chat_history = []
        if not prompt or not prompt.strip():
            # Nothing to generate. Don't hide the buttons: the previous job's
            # images are still in the chat and still usable.
            return chat_history, "", gr.update()

        self.job_counter += 1
        job_id = f"job_{self.job_counter}"
        chat_history.append([prompt, None])

        images = [
            self.generate_mock_image(prompt, seed=self._stable_seed(prompt, i))
            for i in range(4)
        ]
        # Keep the images so the U/V buttons can refer back to this job.
        self._store_job(job_id, prompt, images, original_prompt=prompt)

        chat_history.append([None, self.create_image_response(job_id, prompt, images)])
        return chat_history, "", gr.update(visible=True)

    def create_image_response(self, job_id, prompt, images):
        """Build the chat HTML for a job: header, 2-column image grid, hint.

        The prompt and job id are HTML-escaped because the prompt is user
        input. Lines carry no leading indentation, since Markdown treats lines
        indented by four or more spaces as a code block.

        Args:
            job_id: Job id shown in the header.
            prompt: Prompt shown in the header.
            images: PIL images, embedded inline as base64 PNG ``data:`` URIs.

        Returns:
            The HTML fragment as a string.
        """
        safe_prompt = html.escape(str(prompt))
        safe_job_id = html.escape(str(job_id))
        parts = [
            '<div style="border: 1px solid #ddd; padding: 15px; border-radius: 10px; background: #f9f9f9;">',
            '<h4 style="margin-top: 0; color: #333;">✅ Generated Images</h4>',
            f'<p style="margin: 5px 0; color: #666;"><strong>Prompt:</strong> {safe_prompt}</p>',
            f'<p style="margin: 5px 0 15px 0; color: #666;"><strong>Job ID:</strong> {safe_job_id}</p>',
            '<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px; max-width: 400px;">',
        ]
        for number, img in enumerate(images, start=1):
            img_str = self._image_to_base64(img)
            parts += [
                '<div style="text-align: center;">',
                f'<img src="data:image/png;base64,{img_str}" '
                'style="width: 100%; border-radius: 5px; border: 2px solid #ddd;">',
                f'<div style="margin-top: 5px; font-size: 12px; color: #666;">Image {number}</div>',
                '</div>',
            ]
        parts += [
            '</div>',
            '<p style="margin: 15px 0 5px 0; color: #666; font-size: 14px;">',
            'Use the buttons below to upscale (U) or create variations (V) of specific images.',
            '</p>',
            '</div>',
        ]
        return "\n".join(parts)

    def upscale_image(self, job_id, image_index, chat_history):
        """Append a 2x-upscaled copy of one of a job's images to the chat.

        Args:
            job_id: Id of a job stored in ``current_images``.
            image_index: 0-based index of the image within that job.
            chat_history: Chatbot history (mutated in place; ``None`` = empty).

        Returns:
            The chat history, unchanged if the job or index is unknown.
        """
        if chat_history is None:
            chat_history = []
        job = self.current_images.get(job_id)
        if job is None or not 0 <= image_index < len(job['images']):
            return chat_history

        prompt = job['prompt']
        original_img = job['images'][image_index]
        # Resize the selected image itself, so the result matches what the
        # user picked.
        upscaled_img = original_img.resize((original_img.width * 2, original_img.height * 2))
        img_str = self._image_to_base64(upscaled_img)

        response_html = "\n".join([
            '<div style="border: 1px solid #ddd; padding: 15px; border-radius: 10px; background: #f0f8ff;">',
            f'<h4 style="margin-top: 0; color: #333;">🔍 Upscaled Image {image_index + 1}</h4>',
            f'<p style="margin: 5px 0; color: #666;"><strong>Original Prompt:</strong> {html.escape(str(prompt))}</p>',
            '<div style="text-align: center; margin: 15px 0;">',
            f'<img src="data:image/png;base64,{img_str}" '
            'style="max-width: 100%; border-radius: 5px; border: 2px solid #4CAF50;">',
            '</div>',
            '<p style="margin: 10px 0 5px 0; color: #666; font-size: 14px;">',
            '✨ Image upscaled to higher resolution!',
            '</p>',
            '</div>',
        ])
        chat_history.append([None, response_html])
        return chat_history

    def create_variation(self, job_id, image_index, chat_history):
        """Create four variations of one of a job's images as a new job.

        Args:
            job_id: Id of a job stored in ``current_images``.
            image_index: 0-based index of the image to vary.
            chat_history: Chatbot history (mutated in place; ``None`` = empty).

        Returns:
            The chat history, unchanged if the job or index is unknown.
        """
        if chat_history is None:
            chat_history = []
        job = self.current_images.get(job_id)
        if job is None or not 0 <= image_index < len(job['images']):
            return chat_history

        prompt = job['prompt']
        variation_images = [
            self.generate_mock_image(
                f"{prompt} (Variation)",
                seed=self._stable_seed(prompt, image_index, "variation", i),
            )
            for i in range(4)
        ]

        # Take a fresh job number. Reusing the current counter would give
        # repeated (or concurrent) variations the same id and overwrite jobs.
        self.job_counter += 1
        new_job_id = f"job_{self.job_counter}_var_{image_index + 1}"
        new_prompt = f"{prompt} (Variations of Image {image_index + 1})"
        self._store_job(new_job_id, new_prompt, variation_images, original_prompt=prompt)

        response_html = self.create_image_response(new_job_id, new_prompt, variation_images)
        chat_history.append([None, response_html])
        return chat_history
# One bot instance, used by every event handler below (and so by every session).
bot = ImageGeneratorBot()

# Create the Gradio interface
with gr.Blocks(title="AI Image Generator Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 AI Image Generator Chat
    ### Midjourney-style interface for image generation
    Enter your prompt below and get 4 generated images. Use the U buttons to upscale specific images,
    or V buttons to create variations of specific images.
    """)

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False,
            )
            with gr.Row():
                prompt_input = gr.Textbox(
                    placeholder="Enter your image prompt (e.g., 'a cat wearing a space suit on Mars --ar 16:9')",
                    show_label=False,
                    scale=4,
                )
                generate_btn = gr.Button("Generate", variant="primary")

            # Hidden until the first successful generation.
            action_buttons_container = gr.Column(visible=False)
            with action_buttons_container:
                gr.Markdown("### 🔧 Image Actions")
                gr.Markdown("**Upscale (U):** Get higher resolution version")
                upscale_row = gr.Row()
                with upscale_row:
                    u1_btn = gr.Button("U1", size="sm")
                    u2_btn = gr.Button("U2", size="sm")
                    u3_btn = gr.Button("U3", size="sm")
                    u4_btn = gr.Button("U4", size="sm")
                gr.Markdown("**Variations (V):** Generate variations of specific image")
                variation_row = gr.Row()
                with variation_row:
                    v1_btn = gr.Button("V1", size="sm", variant="secondary")
                    v2_btn = gr.Button("V2", size="sm", variant="secondary")
                    v3_btn = gr.Button("V3", size="sm", variant="secondary")
                    v4_btn = gr.Button("V4", size="sm", variant="secondary")

        with gr.Column(scale=1):
            gr.Markdown("""
            ### 💡 Tips
            - Be descriptive in your prompts
            - Use aspect ratios like `--ar 16:9`
            - Try artistic styles: "in the style of..."
            - Add quality modifiers: "--quality high"
            ### 🎯 Example Prompts
            - "a magical forest with glowing mushrooms"
            - "cyberpunk city at night --ar 16:9"
            - "portrait of a wise owl wearing glasses"
            - "abstract art with vibrant colors"
            """)

    # Per-session id of the job the U/V buttons act on.
    current_job_id = gr.State(None)

    # Event handlers
    def _new_job_ids(before):
        """Return job ids in ``bot.current_images`` that are not in *before*."""
        return [key for key in bot.current_images if key not in before]

    def handle_generate(prompt, history, job_id=None):
        """Generate images and point the session at the job just created.

        If no job was created (e.g. an empty prompt), keep the session's
        current job rather than jumping to whichever job is newest in the
        shared bot.
        """
        before = set(bot.current_images)
        result = bot.process_prompt(prompt, history)
        created = _new_job_ids(before)
        return result + (created[-1] if created else job_id,)

    def handle_upscale(job_id, image_idx, history):
        """Upscale image *image_idx* of *job_id*; no-op without a job."""
        if job_id:
            return bot.upscale_image(job_id, image_idx, history)
        return history

    def handle_variation(job_id, image_idx, history):
        """Create variations and switch the session to the new job, if any."""
        if not job_id:
            return history, job_id
        before = set(bot.current_images)
        new_history = bot.create_variation(job_id, image_idx, history)
        created = _new_job_ids(before)
        # Only switch jobs when a variation job was actually added.
        return new_history, (created[-1] if created else job_id)

    def _make_upscale_handler(image_idx):
        """Return a click handler bound to *image_idx* (avoids late binding)."""
        def _on_click(job_id, history):
            return handle_upscale(job_id, image_idx, history)
        return _on_click

    def _make_variation_handler(image_idx):
        """Return a click handler bound to *image_idx* (avoids late binding)."""
        def _on_click(job_id, history):
            return handle_variation(job_id, image_idx, history)
        return _on_click

    # Connect events
    generate_inputs = [prompt_input, chatbot, current_job_id]
    generate_outputs = [chatbot, prompt_input, action_buttons_container, current_job_id]
    generate_btn.click(handle_generate, inputs=generate_inputs, outputs=generate_outputs)
    prompt_input.submit(handle_generate, inputs=generate_inputs, outputs=generate_outputs)

    for idx, button in enumerate((u1_btn, u2_btn, u3_btn, u4_btn)):
        button.click(_make_upscale_handler(idx), [current_job_id, chatbot], [chatbot])
    for idx, button in enumerate((v1_btn, v2_btn, v3_btn, v4_btn)):
        button.click(_make_variation_handler(idx), [current_job_id, chatbot], [chatbot, current_job_id])


if __name__ == "__main__":
    demo.launch()