File size: 12,247 Bytes
6edda02
07334d0
 
 
a2f0cd0
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
 
 
07334d0
 
 
a2f0cd0
 
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
 
07334d0
a2f0cd0
 
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
 
 
 
07334d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f0cd0
 
 
 
 
 
 
 
 
 
 
 
 
07334d0
 
 
 
 
 
 
 
 
 
 
6edda02
07334d0
 
 
 
 
 
 
 
 
 
6edda02
07334d0
6edda02
 
 
07334d0
 
 
 
 
 
 
 
 
6edda02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import gradio as gr
import asyncio
import threading
import os
from io import BytesIO
from dotenv import load_dotenv
from google import genai
from google.genai.types import Part, FileData, Tool, GenerateContentConfig, GoogleSearch, Content

# Import Discord functionality
import discord
from discord import app_commands
from discord.ext import commands

load_dotenv()

# Credentials and channel configuration, read from the process environment
# (load_dotenv() has already been called above, so .env values are visible).
GOOGLE_KEY = os.environ.get("GOOGLE_KEY")
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN")
CHANNEL_ID = os.environ.get("CHANNEL_ID")
ADDITIONAL_CHANNELS = os.environ.get("ADDITIONAL_CHANNELS", "")

# Channels the Discord bot answers in: the primary channel first, followed by
# any comma-separated extras. A non-numeric ID raises ValueError at import time.
TARGET_CHANNEL_IDS = [int(CHANNEL_ID)] if CHANNEL_ID else []
if ADDITIONAL_CHANNELS:
    ADDITIONAL_IDS = [
        int(token)
        for token in map(str.strip, ADDITIONAL_CHANNELS.split(","))
        if token  # ignore blank entries such as a trailing comma
    ]
    TARGET_CHANNEL_IDS += ADDITIONAL_IDS

# Short aliases mapped to the full model identifiers used by this app.
MODEL_DEFINITIONS = dict(
    flash="gemini-2.0-flash",
    pro="gemini-2.5-pro-preview-05-06",
    image="imagen-3.0-generate-002",
)

# Models selected at startup. chat_model_id is a module global that the
# Discord /model command reassigns at runtime; flash is the initial choice.
chat_model_id = MODEL_DEFINITIONS["flash"]
image_model_id = MODEL_DEFINITIONS["image"]

# Shared Google GenAI client; stays None when no API key is configured.
google_client = genai.Client(api_key=GOOGLE_KEY) if GOOGLE_KEY else None

# Built-in system instruction, used only when the environment supplies none.
# It is empty, so any persona comes entirely from the SYSTEM_INSTRUCTION variable.
DEFAULT_SYSTEM_INSTRUCTION = ""

# System instruction attached to every chat request.
SYSTEM_INSTRUCTION = os.getenv("SYSTEM_INSTRUCTION", DEFAULT_SYSTEM_INSTRUCTION)

def _format_history(history):
    """Convert chat history into the list of ``Content`` turns for a Gemini chat.

    Two history shapes are accepted:

    * ``(user_msg, assistant_msg)`` pairs, and
    * "messages"-style dicts with ``"role"`` and ``"content"`` keys.

    Empty turns, non-text content and roles other than user/assistant are skipped.
    """
    formatted = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            user_msg, assistant_msg = entry
            turns = [("user", user_msg), ("assistant", assistant_msg)]

        for role, text in turns:
            if not (isinstance(text, str) and text):
                continue
            if role == "user":
                formatted.append(Content(role="user", parts=[Part(text=text)]))
            elif role == "assistant":
                # The Gemini API names the assistant side of the conversation "model".
                formatted.append(Content(role="model", parts=[Part(text=text)]))
    return formatted

def respond_with_gemini(message, history):
    """Generate a chat reply with the Google Gemini API.

    The reply is produced by the model in the module-level ``chat_model_id``,
    using ``SYSTEM_INSTRUCTION`` as the system prompt and the Google Search tool.

    Args:
        message: The new user message to send.
        history: Earlier turns of the conversation, either as
            ``(user_msg, assistant_msg)`` pairs or as role/content dicts.
            May be empty.

    Returns:
        The reply text. This is always a non-empty string: a notice when no
        API key is configured, or a canned fallback message when the API call
        fails or yields no text.
    """
    if not google_client:
        return "I need a Google API key to work! Set the GOOGLE_KEY environment variable."

    # Canned, in-character reply used whenever no usable answer is available.
    fallback = "What's the deal with API errors? I mean, you type something in, the computer thinks about it, and then... nothing! It's like asking your friend a question and they just stare at you. 'Hey, how are you?' *silence* 'Hello?' *more silence* It's the digital equivalent of being ignored at a party!"

    try:
        chat = google_client.chats.create(
            model=chat_model_id,
            history=_format_history(history),
            config=GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                tools=[Tool(google_search=GoogleSearch())],
                response_modalities=["TEXT"]
            )
        )

        response = chat.send_message(message)
        reply = response.text
        if reply:
            return reply

        # response.text can be empty/None (e.g. a response carrying no text
        # part). Callers call len() on and iterate over the result, so never
        # hand them None.
        print("Gemini API returned a response without text")
    except Exception as e:
        print(f"Error with Gemini API: {e}")

    return fallback

def respond_gradio(message, history: list[tuple[str, str]]):
    """Gradio chat callback that streams the reply one character at a time.

    The complete reply is fetched from ``respond_with_gemini`` first; the
    generator then yields successively longer prefixes of it so the UI shows
    the text growing.
    """
    reply = respond_with_gemini(message, history)

    # Each yielded value is the reply up to and including one more character.
    for end in range(1, len(reply) + 1):
        yield reply[:end]

async def keep_typing(channel):
    """Keep the channel's typing indicator active until this coroutine is cancelled.

    Intended to run as a background task: the caller cancels the task once the
    real response is ready. Cancellation and any other error are logged and
    absorbed, so the coroutine always finishes by returning None.
    """
    print(f"Starting typing indicator in channel {channel.id}")
    try:
        while True:
            # Re-enter the typing context in 5-second slices, forever.
            async with channel.typing():
                await asyncio.sleep(5)
    except asyncio.CancelledError:
        # Cancellation is the expected way to stop; it is not re-raised.
        print(f"Typing indicator cancelled for channel {channel.id}")
    except Exception as exc:
        print(f"Error in keep_typing: {type(exc).__name__}: {exc!s}")

async def generate_image_bytes(prompt, google_client, image_model_id):
    """Generate one image for ``prompt`` and return its raw bytes.

    Args:
        prompt: Text description of the image to generate.
        google_client: Google GenAI client used to make the request.
        image_model_id: Identifier of the image generation model.

    Returns:
        The bytes of the first generated image that actually carries data.

    Raises:
        RuntimeError: If the response contains no usable image.
        Exception: Any error raised by the API call is logged and re-raised.
    """
    def generate_image():
        # Blocking SDK call; executed in a worker thread (see below).
        return google_client.models.generate_images(
            model=image_model_id,
            prompt=prompt,
            config=genai.types.GenerateImagesConfig(
                number_of_images=1,
                aspect_ratio="16:9"
            )
        )

    try:
        # Run the API call in a separate thread to avoid blocking the event loop.
        response = await asyncio.to_thread(generate_image)

        # The image list, an image, or its bytes may be missing (presumably
        # when generation is refused or filtered). Treat all of these as
        # "no image" instead of failing with an opaque TypeError /
        # AttributeError or returning None to the caller.
        for generated_image in getattr(response, "generated_images", None) or []:
            image = getattr(generated_image, "image", None)
            image_bytes = getattr(image, "image_bytes", None)
            if image_bytes:
                return image_bytes

        print("ERROR: No images were generated in the response")
        raise RuntimeError("No image was generated in the response")
    except Exception as e:
        print(f"Exception in image generation: {type(e).__name__}: {str(e)}")
        raise

async def handle_image_request(message, query, google_client, image_model_id):
    """Serve a "generate image:" / "create image:" chat message, if it is one.

    Args:
        message: The Discord message being answered; replies go to it.
        query: The message text to inspect (matched case-insensitively).
        google_client: Google GenAI client passed through to image generation.
        image_model_id: Image model identifier passed through to image generation.

    Returns:
        True when ``query`` was an image request (whether or not generation
        succeeded), False when it was not and nothing was done.
    """
    if not query.lower().startswith(("generate image:", "create image:")):
        return False

    # Show the typing indicator in the background while the image is produced.
    typing_task = asyncio.create_task(keep_typing(message.channel))

    try:
        # Everything after the first colon is the prompt.
        prompt = query.split(":", 1)[1].strip()
        try:
            print(f"Generating image for prompt: {prompt[:30]}...")
            image_bytes = await generate_image_bytes(prompt, google_client, image_model_id)
            # Stop the typing indicator before the reply goes out.
            typing_task.cancel()
            # The image is uploaded straight from memory; nothing is written to disk.
            attachment = discord.File(BytesIO(image_bytes), filename="generated_image.png")
            await message.reply("Here's your image:", file=attachment)
        except Exception as err:
            print(f"Error generating image: {err}")
            typing_task.cancel()
            await message.reply("Sorry, I couldn't generate that image.")
    except Exception as err:
        # Reached when the prompt handling or the apology reply itself fails.
        typing_task.cancel()
        print(f"Exception during image generation: {err}")
        raise
    return True

# Discord Bot Setup
discord_bot = None

def _split_discord_message(text, limit=1900):
    """Split ``text`` into chunks of at most ``limit`` characters without losing content.

    Each chunk ends at the last sentence boundary (". ") that fits; if there is
    none, at the last space; and only as a last resort in the middle of a word.
    Whitespace at chunk edges is trimmed. Returns a list of non-empty strings
    (empty list for blank input).
    """
    chunks = []
    remaining = text.strip()
    while len(remaining) > limit:
        cut = remaining.rfind(". ", 0, limit)
        if cut != -1:
            cut += 1  # keep the period with the sentence it ends
        else:
            cut = remaining.rfind(" ", 0, limit)
            if cut <= 0:
                cut = limit  # no break point at all: hard split
        chunks.append(remaining[:cut].rstrip())
        remaining = remaining[cut:].lstrip()
    if remaining:
        chunks.append(remaining)
    return chunks

async def setup_discord_bot():
    """Create the Discord bot, register its handlers and commands, and run it.

    Does nothing (besides logging) when ``DISCORD_TOKEN`` or the target
    channel list is missing. Otherwise the bot is stored in the module-level
    ``discord_bot`` and this coroutine runs until the bot stops.

    Registered behaviour:
        * Messages in ``TARGET_CHANNEL_IDS`` are answered via Gemini, or with
          a generated image for "generate image:" / "create image:" requests.
        * ``/model`` switches the chat model (administrators only).
        * ``/image`` generates an image from a prompt.
    """
    global discord_bot

    if not DISCORD_TOKEN or not TARGET_CHANNEL_IDS:
        print("Discord bot disabled: Missing DISCORD_TOKEN or channel IDs")
        return

    # Initialize Discord bot; message_content is needed to read message text.
    intents = discord.Intents.default()
    intents.message_content = True
    discord_bot = commands.Bot(command_prefix="~", intents=intents)

    @discord_bot.event
    async def on_ready():
        print(f"Discord bot logged in as {discord_bot.user}")
        try:
            synced = await discord_bot.tree.sync()
            print(f"Synced {len(synced)} command(s)")
        except Exception as e:
            print(f"Failed to sync commands: {e}")

    @discord_bot.event
    async def on_message(message):
        await discord_bot.process_commands(message)

        # Only answer human, non-command, non-empty messages in target channels.
        if message.channel.id not in TARGET_CHANNEL_IDS:
            return
        if message.author == discord_bot.user:
            return
        if message.content.startswith(('!', '~')):
            return
        if not message.content.strip():
            return

        # Check if this is an image generation request first
        if await handle_image_request(message, message.content, google_client, image_model_id):
            return

        async with message.channel.typing():
            # respond_with_gemini is a blocking call; run it in a worker thread
            # so the event loop (heartbeats, other messages) keeps running.
            response = await asyncio.to_thread(respond_with_gemini, message.content, [])

        # Discord rejects empty messages, so never try to send one.
        if not response or not response.strip():
            response = "Sorry, I couldn't come up with a response."

        # Discord caps messages at 2000 characters; longer replies are split
        # into chunks with some headroom below that cap.
        if len(response) <= 2000:
            chunks = [response]
        else:
            chunks = _split_discord_message(response)

        # Reply to the user once, then continue in the channel so a long
        # answer does not produce a chain of separate reply references.
        first_chunk, *later_chunks = chunks
        await message.reply(first_chunk)
        for chunk in later_chunks:
            await message.channel.send(chunk)

    @discord_bot.tree.command(name="model")
    @app_commands.describe(new_model_id="New model ID to use for Gemini API or shorthand ('flash', 'pro')")
    async def change_model(interaction: discord.Interaction, new_model_id: str):
        """Changes the Gemini chat model being used."""
        # Outside a guild (e.g. in a DM) the user object has no
        # guild_permissions attribute; treat that as "not an administrator".
        permissions = getattr(interaction.user, "guild_permissions", None)
        if permissions is None or not permissions.administrator:
            await interaction.response.send_message("Only administrators can change the model.", ephemeral=True)
            return

        global chat_model_id

        # Resolve shorthands via the centralized model definitions; anything
        # else is taken as a literal model ID.
        actual_model_id = MODEL_DEFINITIONS.get(new_model_id.lower(), new_model_id)
        old_model = chat_model_id
        chat_model_id = actual_model_id

        await interaction.response.send_message(f"Chat model changed from `{old_model}` to `{actual_model_id}`", ephemeral=True)

    @discord_bot.tree.command(name="image")
    @app_commands.describe(prompt="Description of the image you want to generate")
    async def generate_image_command(interaction: discord.Interaction, prompt: str):
        """Generates an image using Gemini API based on the provided prompt."""
        # Acknowledge immediately; the result is sent later as a follow-up.
        await interaction.response.defer(thinking=True)

        try:
            image_bytes = await generate_image_bytes(prompt, google_client, image_model_id)
            await interaction.followup.send(f"Generated image based on: {prompt}", file=discord.File(BytesIO(image_bytes), filename="generated_image.png"))
        except Exception as e:
            print(f"Error generating image: {e}")
            await interaction.followup.send("Sorry, I couldn't generate that image.")

    # Run the bot
    await discord_bot.start(DISCORD_TOKEN)

def run_discord_bot():
    """Thread entry point: run the Discord bot on its own asyncio event loop.

    Any exception escaping the bot is logged rather than propagated, so a bot
    failure ends this thread without raising.
    """
    try:
        asyncio.run(setup_discord_bot())
    except Exception as exc:
        print(f"Discord bot error: {exc}")

# Gradio chat UI backed by respond_gradio. Each example is a one-element row.
# NOTE(review): with cache_examples=True Gradio presumably runs respond_gradio
# on every example to pre-populate its cache, which would call the Gemini API
# - confirm this is intended.
demo = gr.ChatInterface(
    fn=respond_gradio,
    title="🥨 Seinfeld Chatbot",
    description="Chat with Jerry Seinfeld! What's the deal with chatbots anyway?",
    examples=[
        [prompt]
        for prompt in (
            "What's the deal with airplane food?",
            "Why do people say 'after dark' when it's really after light?",
            "What's up with people who take forever to order at restaurants?",
            "Why do we park in driveways and drive on parkways?",
            "Tell me about the soup nazi",
            "What's your take on people who don't return shopping carts?",
        )
    ],
    cache_examples=True,
)

if __name__ == "__main__":
    # The Discord bot is optional: it needs both a token and at least one channel.
    if not (DISCORD_TOKEN and TARGET_CHANNEL_IDS):
        print("Discord bot disabled: Missing credentials or channel IDs")
    else:
        # Daemon thread, so it does not keep the process alive once the
        # main thread exits.
        discord_thread = threading.Thread(target=run_discord_bot, daemon=True)
        discord_thread.start()
        print("Discord bot starting in background...")

    # The Gradio interface is launched from the main thread.
    demo.launch()