# gemini-image / app.py
# Author: Deadmon — last commit 1a2d550 ("Update app.py", verified), 8.34 kB.
import base64
import os
import mimetypes
from google import genai
from google.genai import types
import gradio as gr # Correct import for Gradio
import io
from PIL import Image
def save_binary_file(file_name, data):
    """Write ``data`` to ``file_name`` in binary mode, replacing any existing file.

    Args:
        file_name: Path of the file to create or overwrite.
        data: Bytes-like object to write.
    """
    # The context manager guarantees the handle is closed even if write() raises.
    with open(file_name, "wb") as f:
        f.write(data)
def generate_image(prompt, image=None, output_filename="generated_image"):
    """Generate (or edit) an image with Gemini and save the first image returned.

    Sends ``prompt`` (plus ``image`` when given) to the
    ``gemini-2.0-flash-exp-image-generation`` model as a streamed request and
    scans every part of every chunk for inline image data.

    Args:
        prompt: Text instruction, sent as the first content part.
        image: Optional PIL image. It is re-encoded as PNG and sent inline
            after the prompt.
        output_filename: Base name (without extension) for the saved image.
            Directory components are stripped so a user-supplied name cannot
            write outside the working directory.

    Returns:
        ``(PIL.Image.Image, "Image saved as <file>")`` when the model returned
        an image. Otherwise ``(None, text)`` with all text the model streamed,
        or ``(None, "No image generated")`` if the stream held neither.

    The API key is read from ``GEMINI_API_KEY`` (falling back to
    ``GOOGLE_API_KEY``). Credentials must not be hard-coded in source.
    """
    # If neither variable is set, api_key is None and the SDK is left to
    # resolve (or reject) the missing key itself.
    client = genai.Client(
        api_key=os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY"),
    )
    model = "gemini-2.0-flash-exp-image-generation"

    parts = [types.Part.from_text(text=prompt)]
    if image is not None:
        # Re-encode the PIL image as PNG bytes for inline transport.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        parts.append(types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png"))

    contents = [types.Content(role="user", parts=parts)]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )

    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )

    # Keep only the final path component of the user-supplied name.
    base_name = os.path.basename(str(output_filename or "")) or "generated_image"

    # Text may arrive in chunks before the image does, so consume the whole
    # stream instead of returning on the first chunk that has content.
    texts = []
    for chunk in response:
        candidate = chunk.candidates[0] if chunk.candidates else None
        if candidate is None or not candidate.content or not candidate.content.parts:
            continue
        for part in candidate.content.parts:
            inline_data = part.inline_data
            if inline_data and inline_data.data:
                # guess_extension returns None for unknown types; avoid "…None" names.
                extension = mimetypes.guess_extension(inline_data.mime_type or "") or ""
                filename = f"{base_name}{extension}"
                save_binary_file(filename, inline_data.data)
                # Decode the bytes into a PIL image for Gradio display.
                img = Image.open(io.BytesIO(inline_data.data))
                return img, f"Image saved as {filename}"
            if part.text:
                texts.append(part.text)

    if texts:
        return None, "".join(texts)
    return None, "No image generated"
# Function to handle chat interaction
def chat_handler(prompt, user_image, chat_history, output_filename="generated_image"):
    """Process one chat turn: record the prompt, generate an image, record the outcome.

    Args:
        prompt: The user's text (may be empty).
        user_image: Optional uploaded image, forwarded to ``generate_image``.
        chat_history: List of ``{"role", "content"}`` message dicts. It is
            appended to in place, and the same list object is returned.
        output_filename: Base filename handed through to ``generate_image``.

    Returns:
        ``(chat_history, user_image, generated_image_or_None, "")``. The
        trailing empty string clears the prompt textbox.
    """
    has_prompt = bool(prompt)
    if has_prompt:
        chat_history.append({"role": "user", "content": prompt})

    if not (has_prompt or user_image):
        # Nothing to work with: reply with a hint and skip the model call.
        chat_history.append(
            {"role": "assistant", "content": "Please provide a prompt or an image."}
        )
        return chat_history, user_image, None, ""

    effective_prompt = prompt if has_prompt else "Generate an image"
    generated, status_message = generate_image(effective_prompt, user_image, output_filename)
    chat_history.append({"role": "assistant", "content": status_message})
    return chat_history, user_image, generated, ""
# Function to update chat history with thumbnails
def update_chat_with_thumbnails(chat_history, uploaded_image_url, generated_image_url):
    """Return a copy of ``chat_history`` with HTML thumbnail messages appended.

    Args:
        chat_history: List of ``{"role", "content"}`` message dicts. It is
            not modified.
        uploaded_image_url: URL of the user's image. A ``user`` thumbnail
            message is appended when it is non-blank.
        generated_image_url: URL of the generated image. An ``assistant``
            thumbnail message is appended when it is non-blank.

    Returns:
        A new list: the original messages followed by zero, one or two
        thumbnail messages (user first, then assistant).

    URLs are HTML-escaped before being placed in the ``src`` attribute, so a
    value containing quotes or angle brackets cannot break out of the tag.
    """
    import html

    updated_history = list(chat_history)
    for role, url in (("user", uploaded_image_url), ("assistant", generated_image_url)):
        if url and url.strip():
            src = html.escape(url, quote=True)
            thumbnail_html = f'<img src="{src}" width="100px" style="margin: 5px;" />'
            updated_history.append({"role": role, "content": thumbnail_html})
    return updated_history
# Create Gradio interface
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model")
    # Chat transcript in "messages" format: a list of {"role", "content"} dicts.
    chatbot = gr.Chatbot(
        label="Chat",
        height=300,
        type="messages",
        avatar_images=(None, None),
    )
    # Side-by-side previews of the input image and the model's output.
    with gr.Row():
        uploaded_image_output = gr.Image(label="Uploaded Image")
        generated_image_output = gr.Image(label="Generated Image")
    # Input area
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                scale=1,
                height=100,
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your image description here...",
                lines=3,
            )
            filename_input = gr.Textbox(
                label="Output Filename",
                value="generated_image",
                placeholder="Enter desired filename (without extension)",
            )
            generate_btn = gr.Button("Generate Image")
    # Per-session chat history (list of message dicts).
    chat_state = gr.State([])

    # NOTE: update_chat_with_thumbnails is not wired to any event. Chat
    # thumbnails would have to be added from a Python callback, because
    # Gradio does not execute <script> tags placed inside gr.HTML.

    def _run_turn(prompt, user_image, history, output_filename):
        """Run chat_handler and also write the updated history back to chat_state."""
        history, shown_upload, generated, cleared_prompt = chat_handler(
            prompt, user_image, history, output_filename
        )
        # The history is returned twice: once to render in the Chatbot and
        # once to store explicitly in chat_state.
        return history, history, shown_upload, generated, cleared_prompt

    turn_wiring = dict(
        fn=_run_turn,
        inputs=[prompt_input, image_input, chat_state, filename_input],
        outputs=[chat_state, chatbot, uploaded_image_output, generated_image_output, prompt_input],
    )
    # The button and the Enter key in the prompt box trigger the same turn.
    generate_btn.click(**turn_wiring)
    prompt_input.submit(**turn_wiring)
# Launch the Gradio app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()