File size: 6,560 Bytes
1a2d550
 
 
 
 
e8d1c65
1a2d550
 
 
 
1b6c50f
 
 
1a2d550
 
1b6c50f
 
 
 
 
1a2d550
 
1b6c50f
 
1a2d550
1b6c50f
1a2d550
 
 
1b6c50f
1a2d550
 
 
 
 
 
1b6c50f
 
 
 
 
 
 
1a2d550
 
 
 
 
1b6c50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a2d550
1b6c50f
 
1a2d550
 
 
 
 
 
 
 
1b6c50f
 
1a2d550
 
 
 
1b6c50f
1a2d550
 
1b6c50f
 
 
 
1a2d550
1b6c50f
6c4c5f5
 
 
 
 
 
 
1b6c50f
 
 
 
1a2d550
1b6c50f
 
 
 
1a2d550
1b6c50f
6c4c5f5
4c18034
1b6c50f
 
 
 
 
 
 
 
 
 
6c4c5f5
 
 
1b6c50f
 
6c4c5f5
1a2d550
1b6c50f
 
 
ed4625b
 
662515a
1b6c50f
 
 
 
 
 
 
 
 
 
844ed4d
 
 
1b6c50f
 
19ffd31
ff6bd0c
1b6c50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff6bd0c
1b6c50f
 
ed4625b
1b6c50f
 
ff6bd0c
ed4625b
1b6c50f
 
ed4625b
1b6c50f
 
ed4625b
 
1b6c50f
 
19ffd31
ac3fd77
 
1b6c50f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import base64
import os
import mimetypes
from google import genai
from google.genai import types
import gradio as gr
import io
from PIL import Image

def save_binary_file(file_name, data):
    """Write the bytes-like *data* to *file_name*, replacing any existing file.

    A context manager is used so the file handle is closed even if the
    write raises (the previous open/write/close sequence leaked it).
    """
    with open(file_name, "wb") as f:
        f.write(data)

def _guess_extension(mime_type, default=".png"):
    """Return the file extension for *mime_type*, or *default* when unknown.

    ``mimetypes.guess_extension`` returns ``None`` for missing or unregistered
    MIME types; without a fallback the saved file would be named e.g.
    ``generated_imageNone``.
    """
    return mimetypes.guess_extension(mime_type or "") or default


def _iter_parts(response):
    """Yield every content part of the first candidate of each streamed chunk.

    Chunks with no candidates, no content, or no parts are skipped.
    """
    for chunk in response:
        if not chunk.candidates:
            continue
        content = chunk.candidates[0].content
        if content is None or not content.parts:
            continue
        yield from content.parts


def generate_image(prompt, image=None, output_filename="generated_image", api_key=None):
    """Generate (or edit) an image with Gemini and save the result to disk.

    Args:
        prompt: Text instruction sent to the model.
        image: Optional PIL image sent alongside the prompt (encoded as PNG).
        output_filename: Output path without extension; the extension is
            derived from the MIME type of the returned image data.
        api_key: Gemini API key. When omitted, it is read from the
            ``GEMINI_API_KEY`` (then ``GOOGLE_API_KEY``) environment variable.

    Returns:
        ``(PIL.Image, "Image saved as <file>")`` for the first image part the
        model streams back; otherwise ``(None, text)`` where *text* is all
        text the model produced, or ``"No image generated"`` if it produced
        none.
    """
    # Credentials must not live in source: the key previously hard-coded here
    # has been exposed and should be revoked.
    if api_key is None:
        api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    client = genai.Client(api_key=api_key)

    model = "gemini-2.0-flash-exp-image-generation"
    parts = [types.Part.from_text(text=prompt)]

    # Attach the optional input image as inline PNG bytes.
    if image is not None:
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        parts.append(types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png"))

    contents = [
        types.Content(
            role="user",
            parts=parts,
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_modalities=[
            "image",
            "text",
        ],
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_CIVIC_INTEGRITY",
                threshold="OFF",
            ),
        ],
        response_mime_type="text/plain",
    )

    response = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )

    # Scan every streamed part: the model may emit commentary before the
    # image, so returning on the first text chunk would drop the image.
    texts = []
    for part in _iter_parts(response):
        inline_data = part.inline_data
        if inline_data is not None and inline_data.data:
            filename = f"{output_filename}{_guess_extension(inline_data.mime_type)}"
            save_binary_file(filename, inline_data.data)
            img = Image.open(io.BytesIO(inline_data.data))
            return img, f"Image saved as {filename}"
        if part.text:
            texts.append(part.text)

    return None, "".join(texts) or "No image generated"

def _image_to_data_uri(image, max_size=None):
    """Encode a PIL image as a ``data:image/png;base64,...`` URI string.

    When *max_size* is given, a copy is shrunk with ``thumbnail`` (which
    resizes in place), so the caller's image is left untouched.
    """
    if max_size is not None:
        image = image.copy()
        image.thumbnail(max_size)
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()


def chat_handler(prompt, user_image, chat_history, output_filename="generated_image"):
    """Handle one chat turn: record the input, generate an image, record the reply.

    Args:
        prompt: User text; may be empty or ``None``.
        user_image: Optional PIL image uploaded by the user.
        chat_history: List of ``{"role", "content"}`` message dicts. It is
            appended to in place and also returned. ``None`` is treated as an
            empty history.
        output_filename: Base filename (without extension) passed to
            ``generate_image``.

    Returns:
        ``(chat_history, user_image, generated_image_or_None, "")``. The
        trailing empty string clears the prompt textbox.
    """
    if chat_history is None:
        chat_history = []

    # Nothing to send: reply with a hint and skip the model call.
    if not prompt and user_image is None:
        chat_history.append({"role": "assistant", "content": "Please provide a prompt or an image."})
        return chat_history, user_image, None, ""

    # One message per item. A messages-format Chatbot expects a single
    # string/file/component as content, not a list, so text and image are
    # recorded as separate user messages.
    if prompt:
        chat_history.append({"role": "user", "content": prompt})
    if user_image is not None:
        chat_history.append({"role": "user", "content": _image_to_data_uri(user_image)})

    img, status = generate_image(prompt or "Generate an image", user_image, output_filename)

    if img is not None:
        # NOTE(review): a bare data URI string may render as text rather than
        # an image in the Chatbot; confirm against the installed Gradio version.
        reply = _image_to_data_uri(img, max_size=(100, 100))
    else:
        reply = status  # model text or "No image generated"
    chat_history.append({"role": "assistant", "content": reply})

    return chat_history, user_image, img, ""

# --- Gradio UI ---------------------------------------------------------------
# Layout: chat log, side-by-side image previews, then an input column.
with gr.Blocks(title="Image Editing Chatbot") as demo:
    gr.Markdown("# Image Editing Chatbot")
    gr.Markdown(
        "Upload an image and/or type a prompt to generate or edit an image using Google's Gemini model"
    )

    # Conversation log; chat_handler fills it with {"role", "content"} dicts.
    chatbot = gr.Chatbot(label="Chat", height=300, type="messages", avatar_images=(None, None))

    # Previews: what the user sent vs. what the model returned.
    with gr.Row():
        uploaded_image_output = gr.Image(label="Uploaded Image")
        generated_image_output = gr.Image(label="Generated Image")

    # Input controls stacked in a single column inside one row.
    with gr.Row(), gr.Column():
        image_input = gr.Image(label="Upload Image", type="pil", scale=1, height=100)
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="Enter your image description here...",
            lines=3,
        )
        filename_input = gr.Textbox(
            label="Output Filename",
            value="generated_image",
            placeholder="Enter desired filename (without extension)",
        )
        generate_btn = gr.Button("Generate Image")

    # Chat history passed to chat_handler, starting empty.
    # NOTE(review): chat_state is not listed in outputs, so history carries over
    # only if chat_handler's in-place list mutation is retained — confirm.
    chat_state = gr.State([])

    # The button click and Enter in the prompt box run the same handler
    # (registered in that order).
    for _register_event in (generate_btn.click, prompt_input.submit):
        _register_event(
            fn=chat_handler,
            inputs=[prompt_input, image_input, chat_state, filename_input],
            outputs=[chatbot, uploaded_image_output, generated_image_output, prompt_input],
        )

if __name__ == "__main__":
    demo.launch()