# Image Editing Chatbot — Gradio app that generates/edits images with Google's Gemini model.
import base64 | |
import os | |
import mimetypes | |
from google import genai | |
from google.genai import types | |
import gradio as gr | |
import io | |
from PIL import Image | |
from huggingface_hub import login | |
# Gemini API key is injected via the environment (e.g. a Spaces secret);
# None here means every generate_image call will fail to authenticate.
GEMINI_KEY = os.getenv("GEMINI_API_KEY")
def save_binary_file(file_name, data):
    """Write binary *data* to *file_name*, overwriting any existing file.

    Uses a context manager so the handle is closed even if the write
    raises (the original open/write/close left it open on error).
    """
    with open(file_name, "wb") as f:
        f.write(data)
def generate_image(prompt, image=None):
    """Generate (or edit) an image with Gemini from a text prompt.

    Args:
        prompt: Text description of the desired image.
        image: Optional PIL.Image to edit; serialized as PNG and sent
            alongside the prompt.

    Returns:
        (PIL.Image, status_message) when the model streamed back image
        bytes, or (None, accumulated_text) when it returned only text.
    """
    # Initialize client with the API key
    client = genai.Client(api_key=GEMINI_KEY)
    model = "gemini-2.0-flash-exp-image-generation"

    parts = [types.Part.from_text(text=prompt)]
    # If an image is provided, add it to the content as inline PNG bytes.
    if image is not None:
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        # Use the SDK part constructor, consistent with from_text above,
        # instead of a hand-built {"inline_data": ...} dict.
        parts.append(
            types.Part.from_bytes(data=png_buffer.getvalue(), mime_type="image/png")
        )

    contents = [
        types.Content(
            role="user",
            parts=parts,
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )

    # Stream the response so we can return as soon as image bytes arrive.
    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )

    full_text_response = ""
    for chunk in response:
        # Some chunks carry no content (e.g. usage metadata); skip them.
        if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
            continue
        part = chunk.candidates[0].content.parts[0]
        if part.inline_data:
            inline_data = part.inline_data
            # guess_extension returns None for unknown MIME types; fall
            # back to ".bin" rather than naming the file "...None".
            file_extension = mimetypes.guess_extension(inline_data.mime_type) or ".bin"
            filename = f"generated_image{file_extension}"
            save_binary_file(filename, inline_data.data)
            img = Image.open(io.BytesIO(inline_data.data))
            # Bug fix: the status string never interpolated the actual
            # filename before.
            return img, f"Image saved as {filename}"
        elif chunk.text:
            # Accumulate text chunks so a text-only reply is returned whole.
            full_text_response += chunk.text
    return None, full_text_response
def _image_to_html_thumbnail(pil_image, alt_text):
    """Render *pil_image* as a gr.HTML 100x100 thumbnail via a base64 PNG data URI."""
    buffered = io.BytesIO()
    pil_image.save(buffered, format="PNG")
    encoded = base64.b64encode(buffered.getvalue()).decode()
    data_uri = f"data:image/png;base64,{encoded}"
    return gr.HTML(
        f'<img src="{data_uri}" alt="{alt_text}" style="width:100px; height:100px; object-fit:contain;">'
    )

# Function to handle chat interaction
def chat_handler(prompt, user_image, chat_history):
    """Handle one chat turn: record the user's input, call Gemini, append the reply.

    Args:
        prompt: User's text prompt (may be empty).
        user_image: Optional PIL.Image uploaded by the user.
        chat_history: Mutable list of {"role", "content"} message dicts
            (mutated in place and also returned).

    Returns:
        (chat_history, user_image, generated_image_or_None, "") — the
        trailing empty string clears the prompt textbox component.
    """
    # Record only the text as the user's message; the upload gets its own entry.
    if prompt:
        chat_history.append({"role": "user", "content": prompt})
    if user_image is not None:
        # Show the upload inline as a small thumbnail.
        chat_history.append(
            {"role": "user", "content": _image_to_html_thumbnail(user_image, "Uploaded Image")}
        )

    # Nothing to do without any input at all.
    if not prompt and not user_image:
        chat_history.append({"role": "assistant", "content": "Please provide a prompt or an image."})
        return chat_history, user_image, None, ""

    # Generate an image from whatever input we have.
    img, status = generate_image(prompt or "Generate an image", user_image)
    if img:
        img = img.convert("RGB")  # Force RGB mode for consistency
        assistant_message_content = _image_to_html_thumbnail(img, "Generated Image")
    else:
        # No image came back; surface the model's text (or error status).
        assistant_message_content = status

    chat_history.append({"role": "assistant", "content": assistant_message_content})
    return chat_history, user_image, img, ""
# Create Gradio interface
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown("Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model")
    # Chatbot display area for text messages.
    # type="messages" expects the {"role": ..., "content": ...} dicts that
    # chat_handler appends to its history list.
    chatbot = gr.Chatbot(
        label="Chat",
        height=300,
        type="messages",
        avatar_images=(None, None)
    )
    # Separate image outputs: the raw upload and the model's result, side by side.
    with gr.Row():
        uploaded_image_output = gr.Image(label="Uploaded Image")
        generated_image_output = gr.Image(label="Generated Image")
    # Input area
    with gr.Row():
        with gr.Column(scale=2):  # Increased scale for better spacing
            image_input = gr.Image(
                label="Upload Image",
                type="pil",  # chat_handler expects a PIL.Image, not a filepath/ndarray
                scale=1,
                height=150,  # Increased height for better visibility
                container=True,  # Ensure the component has a container for padding
                elem_classes="p-4"  # Add padding via CSS class (4 units of padding)
            )
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your image description here...",
                lines=3,
                elem_classes="mt-2"  # Add margin-top for spacing between components
            )
            generate_btn = gr.Button("Generate Image", elem_classes="mt-2")  # Add margin-top for the button
    # State to maintain chat history across turns (list of message dicts).
    chat_state = gr.State([])
    # Connect the button to the chat handler; the fourth output ("") clears the textbox.
    generate_btn.click(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
    )
    # Also allow Enter key to submit
    prompt_input.submit(
        fn=chat_handler,
        inputs=[prompt_input, image_input, chat_state],
        outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input]
    )
if __name__ == "__main__":
    demo.launch()