import base64
import os
import mimetypes
from google import genai
from google.genai import types
import gradio as gr
import io
from PIL import Image
# The Gemini API key must be set in the environment before launching the app
GEMINI_KEY = os.environ.get("GEMINI_API_KEY")
def save_binary_file(file_name, data):
    """Persist raw bytes to disk so the generated image is kept as a file."""
    with open(file_name, "wb") as f:
        f.write(data)
def generate_image(prompt, image=None):
    """Send a text prompt (and optional PIL image) to Gemini; return (image, status text)."""
    # Initialize the client with the API key
    client = genai.Client(api_key=GEMINI_KEY)
model = "gemini-2.0-flash-exp-image-generation"
parts = [types.Part.from_text(text=prompt)]
    # If an input image is provided, send it along with the text prompt
    if image:
        # Serialize the PIL image to PNG bytes
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_bytes = img_byte_arr.getvalue()
        # Attach the bytes as a typed inline-data part (rather than a raw dict)
        parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))
contents = [
types.Content(
role="user",
parts=parts,
),
]
generate_content_config = types.GenerateContentConfig(
temperature=1,
top_p=0.95,
top_k=40,
max_output_tokens=8192,
response_modalities=[
"image",
"text",
],
safety_settings=[
types.SafetySetting(
category="HARM_CATEGORY_CIVIC_INTEGRITY",
threshold="OFF",
),
],
response_mime_type="text/plain",
)
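    # Note: response_modalities must include "image" for this model to return
    # image parts; "text" keeps any commentary the model streams alongside it.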
    # Stream the response so the first image part can be returned as soon as it arrives
response = client.models.generate_content_stream(
model=model,
contents=contents,
config=generate_content_config,
)
    full_text_response = ""  # Accumulate any text the model streams back
    # Process the stream: return on the first image part; otherwise collect
    # whatever text the model sends instead (e.g. a refusal or commentary)
    for chunk in response:
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        part = chunk.candidates[0].content.parts[0]
        if part.inline_data:
            inline_data = part.inline_data
            # guess_extension can return None for unknown MIME types, so fall back to .png
            file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".png"
            filename = f"generated_image{file_extension}"  # Fixed name; each run overwrites the previous image
            save_binary_file(filename, inline_data.data)
            # Convert the binary payload to a PIL Image for the Gradio output
            img = Image.open(io.BytesIO(inline_data.data))
            return img, f"Image saved as {filename}"
        elif chunk.text:
            full_text_response += chunk.text
            print("Chunk text response:", chunk.text)  # Debug: per-chunk text
    print("Full text response from Gemini:", full_text_response)  # Debug: no image was returned
    return None, full_text_response
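# Example of calling the helper directly, outside the Gradio UI (assumes
# GEMINI_API_KEY is set in the environment):
#   img, status = generate_image("A watercolor fox in a forest")
#   if img:
#       img.show()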
# Handle one round of chat: append the user's turn, call Gemini, append the reply
def chat_handler(prompt, user_image, chat_history):
    # Add the user's text prompt to the chat history (any image gets its own message)
    if prompt:
        chat_history.append({"role": "user", "content": prompt})
    if user_image is not None:
        # Show the uploaded image as its own chat message: the full-quality
        # image constrained to a small container (the exact size is illustrative)
        buffered = io.BytesIO()
        user_image.save(buffered, format="PNG")
        user_image_base64 = base64.b64encode(buffered.getvalue()).decode()
        user_image_data_uri = f"data:image/png;base64,{user_image_base64}"
        chat_history.append({
            "role": "user",
            "content": gr.HTML(
                f'<img src="{user_image_data_uri}" style="max-width: 150px; max-height: 150px;">'
            ),
        })
# If no input, return early
if not prompt and not user_image:
chat_history.append({"role": "assistant", "content": "Please provide a prompt or an image."})
return chat_history, user_image, None, ""
# Generate image based on user input
img, status = generate_image(prompt or "Generate an image", user_image)
    thumbnail_data_uri = None
if img:
# Use full-resolution image in a smaller container
img = img.convert("RGB") # Force RGB mode for consistency
buffered = io.BytesIO()
img.save(buffered, format="PNG")
thumbnail_base64 = base64.b64encode(buffered.getvalue()).decode()
thumbnail_data_uri = f"data:image/png;base64,{thumbnail_base64}"
print("Image Data URI:", thumbnail_data_uri) # Print to console
        assistant_message_content = gr.HTML(
            f'<img src="{thumbnail_data_uri}" style="max-width: 300px; max-height: 300px;">'
        )  # Full-resolution image in a small container (the exact size is illustrative)
else:
assistant_message_content = status # If no image, send text status
# Update chat history - Assistant message is now EITHER gr.HTML or text
chat_history.append({"role": "assistant", "content": assistant_message_content})
return chat_history, user_image, img, ""
# Create Gradio interface
with gr.Blocks(title="Image Editing Chatbot") as demo:
gr.Markdown("# Image Editing Chatbot")
gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model")
# Chatbot display area for text messages
chatbot = gr.Chatbot(
label="Chat",
height=300,
type="messages",
avatar_images=(None, None)
)
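    # With type="messages", the history is a list of {"role", "content"} dicts,
    # which is exactly the shape chat_handler appends.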
# Separate image outputs
with gr.Row():
uploaded_image_output = gr.Image(label="Uploaded Image")
generated_image_output = gr.Image(label="Generated Image")
# Input area
with gr.Row():
        with gr.Column(scale=2):  # Wider column so the inputs have breathing room
            image_input = gr.Image(
                label="Upload Image",
                type="pil",
                scale=1,
                height=150,
                container=True,  # Wrap the component so the padding class applies
                elem_classes="p-4"  # Padding around the image input
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your image description here...",
                lines=3,
                elem_classes="mt-2"  # Margin-top to separate it from the image input
            )
            generate_btn = gr.Button("Generate Image", elem_classes="mt-2")
# State to maintain chat history
chat_state = gr.State([])
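    # chat_handler appends to this list in place and returns it to the Chatbot,
    # so the stored state and the rendered chat stay in sync.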
# Connect the button to the chat handler
generate_btn.click(
fn=chat_handler,
inputs=[prompt_input, image_input, chat_state],
outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
)
# Also allow Enter key to submit
prompt_input.submit(
fn=chat_handler,
inputs=[prompt_input, image_input, chat_state],
outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
)
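    # Both triggers return "" as the last output, which clears the prompt box after sending.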
if __name__ == "__main__":
demo.launch()