Spaces:
Running
Running
import base64 | |
import os | |
import mimetypes | |
from google import genai | |
from google.genai import types | |
import gradio as gr | |
import io | |
from PIL import Image | |
import tempfile | |
def save_binary_file(file_name, data):
    """Write raw bytes to a file, replacing any existing content.

    Args:
        file_name: Path of the file to create or overwrite.
        data: Bytes-like payload to write.
    """
    # The context manager closes the handle even when write() raises,
    # which the previous open()/close() pair did not guarantee.
    with open(file_name, "wb") as f:
        f.write(data)
def generate_image(prompt, image=None, output_filename="generated_image"):
    """Ask Gemini to generate or edit an image and return it with a status.

    Args:
        prompt: Text instruction sent to the model.
        image: Optional PIL image sent along with the prompt. It is
            re-encoded as PNG before being attached to the request.
        output_filename: Base name (without extension) of the file the
            generated image is written to in the working directory.

    Returns:
        A ``(image, status)`` tuple. ``image`` is a PIL image when the model
        returned one, otherwise ``None``. ``status`` is a human-readable
        message: the saved file name, the model's text reply, or an
        explanation of why nothing was generated.
    """
    # Credentials must come from the environment, never from source control.
    # NOTE: the key that was previously embedded here has been published and
    # should be revoked.
    api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        return None, (
            "API key not configured: set the GEMINI_API_KEY environment variable."
        )

    client = genai.Client(api_key=api_key)
    model = "gemini-2.0-flash-exp-image-generation"

    parts = [types.Part.from_text(text=prompt)]
    if image is not None:
        # Serialize the PIL image to PNG bytes and attach it as inline data.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        parts.append(
            types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png")
        )

    contents = [
        types.Content(
            role="user",
            parts=parts,
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )

    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )

    # Text may be streamed before the image, so keep reading until an image
    # part shows up instead of returning on the first text chunk.
    text_fragments = []
    for chunk in response:
        candidates = chunk.candidates
        if (
            not candidates
            or not candidates[0].content
            or not candidates[0].content.parts
        ):
            continue
        for part in candidates[0].content.parts:
            inline_data = part.inline_data
            if inline_data and inline_data.data:
                # guess_extension() returns None for unknown MIME types;
                # fall back to ".png" rather than producing "nameNone".
                file_extension = (
                    mimetypes.guess_extension(inline_data.mime_type or "") or ".png"
                )
                filename = f"{output_filename}{file_extension}"
                save_binary_file(filename, inline_data.data)
                img = Image.open(io.BytesIO(inline_data.data))
                return img, f"Image saved as {filename}"
            if part.text:
                text_fragments.append(part.text)

    if text_fragments:
        # The model answered with text only; surface that reply as the status.
        return None, "".join(text_fragments)
    return None, "No image generated"
# Function to handle chat interaction | |
def _save_temp_png(image):
    """Write *image* to a new temporary PNG file and return the file path."""
    # delete=False keeps the file on disk after the handle is closed so the
    # path can still be read later. Saving through the open handle (rather
    # than re-opening by name) also works on platforms that forbid opening a
    # file twice.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        image.save(tmp_file, format="PNG")
    return tmp_file.name


def chat_handler(user_input, user_image, chat_history):
    """Handle one chat turn: record the user's input and the model's reply.

    Args:
        user_input: Prompt text from the textbox (may be empty).
        user_image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        chat_history: List of message dicts with ``role`` and ``content``
            keys. It is extended in place.

    Returns:
        A ``(chat_history, None, "")`` tuple: the updated history plus the
        values that clear the image and text inputs.
    """
    # Only replace a missing history; an existing (possibly empty) list must
    # keep its identity because it is mutated in place.
    if chat_history is None:
        chat_history = []

    has_image = user_image is not None
    if not user_input and not has_image:
        chat_history.append(
            {"role": "assistant", "content": "Please provide a prompt or an image."}
        )
        return chat_history, None, ""

    if has_image:
        # A {"path": ...} dict is rendered as a file in the messages format;
        # a bare path string would be shown as plain text.
        chat_history.append(
            {"role": "user", "content": {"path": _save_temp_png(user_image)}}
        )
    if user_input:
        chat_history.append({"role": "user", "content": user_input})

    img, status = generate_image(user_input or "Generate an image", user_image)

    if img is not None:
        chat_history.append(
            {"role": "assistant", "content": {"path": _save_temp_png(img)}}
        )
    # The status (saved file name, model text, or error) is always reported.
    chat_history.append({"role": "assistant", "content": status})
    return chat_history, None, ""
# Create Gradio interface with chatbot layout | |
# Assemble the chat-style Gradio UI: a conversation pane on top and a single
# input row (image upload, prompt box, run button) underneath.
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model")

    # Conversation thread, using the role/content "messages" format.
    chatbot = gr.Chatbot(label="Chat", height=300, type="messages", avatar_images=(None, None))

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil", scale=1, height=100)
        prompt_input = gr.Textbox(
            label="",
            placeholder="Type something",
            show_label=False,
            container=False,
            scale=3,
        )
        run_btn = gr.Button("Run", scale=1)

    # Per-session list holding the conversation so far.
    chat_state = gr.State([])

    # Clicking "Run" and pressing Enter in the prompt box trigger the same
    # handler with the same inputs and outputs.
    for _register in (run_btn.click, prompt_input.submit):
        _register(
            fn=chat_handler,
            inputs=[prompt_input, image_input, chat_state],
            outputs=[chatbot, image_input, prompt_input],
        )

if __name__ == "__main__":
    demo.launch()