File size: 12,247 Bytes
6edda02 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 a2f0cd0 07334d0 6edda02 07334d0 6edda02 07334d0 6edda02 07334d0 6edda02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 |
import gradio as gr
import asyncio
import threading
import os
from io import BytesIO
from dotenv import load_dotenv
from google import genai
from google.genai.types import Part, FileData, Tool, GenerateContentConfig, GoogleSearch, Content
# Import Discord functionality
import discord
from discord import app_commands
from discord.ext import commands
# Load settings from a .env file (if one exists) before reading them below.
load_dotenv()

# Environment variables
GOOGLE_KEY = os.getenv("GOOGLE_KEY")
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
CHANNEL_ID = os.getenv("CHANNEL_ID")
ADDITIONAL_CHANNELS = os.getenv("ADDITIONAL_CHANNELS", "")

# Channel IDs the Discord bot listens in: the primary channel first, followed
# by any extras given as a comma-separated list (blank entries are ignored).
TARGET_CHANNEL_IDS = [int(CHANNEL_ID)] if CHANNEL_ID else []
if ADDITIONAL_CHANNELS:
    ADDITIONAL_IDS = list(
        map(int, filter(None, map(str.strip, ADDITIONAL_CHANNELS.split(","))))
    )
    TARGET_CHANNEL_IDS += ADDITIONAL_IDS
# Model configuration - shorthand names mapped to full model identifiers
MODEL_DEFINITIONS = dict(
    flash="gemini-2.0-flash",
    pro="gemini-2.5-pro-preview-05-06",
    image="imagen-3.0-generate-002",
)

# Models used until changed at runtime (chat starts on the flash model)
chat_model_id = MODEL_DEFINITIONS["flash"]
image_model_id = MODEL_DEFINITIONS["image"]

# Google client; stays None when no API key is configured
google_client = genai.Client(api_key=GOOGLE_KEY) if GOOGLE_KEY else None

# System instruction: taken from the environment, falling back to the
# (empty) default when SYSTEM_INSTRUCTION is not set
DEFAULT_SYSTEM_INSTRUCTION = ""
SYSTEM_INSTRUCTION = os.getenv("SYSTEM_INSTRUCTION", DEFAULT_SYSTEM_INSTRUCTION)
def respond_with_gemini(message, history):
    """Generate a chat reply using the Google Gemini API.

    The chat is created with the module-level ``chat_model_id``,
    ``SYSTEM_INSTRUCTION`` and a Google Search tool.

    Args:
        message: The new user message to send.
        history: Earlier conversation turns. Each turn is either a
            ``(user_text, assistant_text)`` pair or a dict with ``role`` and
            ``content`` keys. Turns without string text are skipped. ``None``
            is treated as an empty history.

    Returns:
        The reply text. A canned message is returned instead when no API
        client is configured, when the API call raises, or when the response
        carries no text.
    """
    if not google_client:
        return "I need a Google API key to work! Set the GOOGLE_KEY environment variable."

    # Seinfeld-style fallback used whenever no usable reply is available
    fallback = "What's the deal with API errors? I mean, you type something in, the computer thinks about it, and then... nothing! It's like asking your friend a question and they just stare at you. 'Hey, how are you?' *silence* 'Hello?' *more silence* It's the digital equivalent of being ignored at a party!"

    try:
        # Normalise both supported history shapes into (role, text) pairs
        turns = []
        for entry in history or []:
            if isinstance(entry, dict):
                # Dict-style history says "assistant"; the Gemini API says "model"
                role = {"user": "user", "assistant": "model"}.get(entry.get("role"))
                if role is not None:
                    turns.append((role, entry.get("content")))
            else:
                user_msg, assistant_msg = entry
                turns.append(("user", user_msg))
                turns.append(("model", assistant_msg))

        # Format history for Gemini API, keeping only non-empty text turns
        formatted_history = [
            Content(role=role, parts=[Part(text=text)])
            for role, text in turns
            if isinstance(text, str) and text
        ]

        # Initialize Google Search tool
        google_search_tool = Tool(google_search=GoogleSearch())

        # Create chat
        chat = google_client.chats.create(
            model=chat_model_id,
            history=formatted_history,
            config=GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                tools=[google_search_tool],
                response_modalities=["TEXT"],
            ),
        )

        # Send message and get response
        response = chat.send_message(message)
        reply = response.text
        if not reply:
            # Callers iterate over / measure the reply, so never hand back None
            print("Gemini API returned a response without text")
            return fallback
        return reply
    except Exception as e:
        print(f"Error with Gemini API: {e}")
        return fallback
def respond_gradio(message, history: list[tuple[str, str]]):
    """Chat callback for the Gradio interface.

    Fetches the complete reply from ``respond_with_gemini`` and then yields
    it as a series of growing prefixes, one extra character per step, so the
    answer appears to be typed out.
    """
    full_reply = respond_with_gemini(message, history)
    for shown in range(1, len(full_reply) + 1):
        yield full_reply[:shown]
async def keep_typing(channel):
    """Show the typing indicator in ``channel`` until this task is cancelled.

    Cancellation is treated as the normal way to stop: it is logged and the
    coroutine returns. Any other error is logged and also ends the loop.
    """
    refresh_seconds = 5
    print(f"Starting typing indicator in channel {channel.id}")
    try:
        # Re-enter the typing context over and over so the indicator persists
        while True:
            async with channel.typing():
                await asyncio.sleep(refresh_seconds)
    except asyncio.CancelledError:
        print(f"Typing indicator cancelled for channel {channel.id}")
    except Exception as err:
        print(f"Error in keep_typing: {type(err).__name__}: {err!s}")
async def generate_image_bytes(prompt, google_client, image_model_id):
    """Generate one 16:9 image for ``prompt`` and return its raw bytes.

    Args:
        prompt: Text description of the image to generate.
        google_client: Client object exposing ``models.generate_images``.
        image_model_id: Identifier of the image model to use.

    Returns:
        The bytes of the first generated image that actually carries data.

    Raises:
        RuntimeError: If the response contains no usable image.
        Exception: Any error raised by the API call is logged and re-raised.
    """
    def generate_image():
        return google_client.models.generate_images(
            model=image_model_id,
            prompt=prompt,
            config=genai.types.GenerateImagesConfig(
                number_of_images=1,
                aspect_ratio="16:9",
            ),
        )

    try:
        # The SDK call is blocking; run it in a worker thread so the event
        # loop stays responsive.
        response = await asyncio.to_thread(generate_image)

        # Both the image list and an entry's payload may be absent, so skip
        # anything without real bytes instead of returning None.
        for generated_image in getattr(response, "generated_images", None) or []:
            image = getattr(generated_image, "image", None)
            image_bytes = getattr(image, "image_bytes", None)
            if image_bytes:
                return image_bytes

        # If we get here, no usable image was generated
        print("ERROR: No images were generated in the response")
        raise RuntimeError("No image was generated in the response")
    except Exception as e:
        print(f"Exception in image generation: {type(e).__name__}: {str(e)}")
        raise
async def handle_image_request(message, query, google_client, image_model_id):
    """Serve a text message that asks for an image.

    A request is any ``query`` starting (case-insensitively) with
    "generate image:" or "create image:"; the text after the first colon is
    the prompt. While the image is generated a background task keeps the
    typing indicator alive. The result, or an apology on failure, is sent as
    a reply to ``message``.

    Returns:
        True when the query was an image request (whether or not generation
        succeeded), False when it was not and nothing was done.
    """
    if not query.lower().startswith(("generate image:", "create image:")):
        return False

    # Keep the typing indicator running in the background meanwhile
    typing_task = asyncio.create_task(keep_typing(message.channel))
    try:
        prompt = query.split(":", 1)[1].strip()
        try:
            print(f"Generating image for prompt: {prompt[:30]}...")
            image_bytes = await generate_image_bytes(prompt, google_client, image_model_id)
            # Stop typing before the answer goes out
            typing_task.cancel()
            # Attach straight from memory; nothing is written to disk
            attachment = discord.File(BytesIO(image_bytes), filename="generated_image.png")
            await message.reply("Here's your image:", file=attachment)
        except Exception as gen_err:
            print(f"Error generating image: {gen_err}")
            # Stop typing before the answer goes out
            typing_task.cancel()
            await message.reply("Sorry, I couldn't generate that image.")
    except Exception as outer_err:
        # Reached when even the apology could not be delivered; make sure the
        # typing task is stopped before propagating.
        typing_task.cancel()
        print(f"Exception during image generation: {outer_err}")
        raise
    return True
# Discord Bot Setup
discord_bot = None

# Replies longer than the message limit are cut into chunks of at most the
# chunk size, leaving some headroom below the limit.
_DISCORD_MESSAGE_LIMIT = 2000
_DISCORD_CHUNK_SIZE = 1900


def _split_for_discord(text, limit=_DISCORD_CHUNK_SIZE):
    """Split ``text`` into non-empty chunks of at most ``limit`` characters.

    Each chunk ends right after the last ". " sentence break that fits; when
    no break fits, the text is cut at ``limit`` so no content is dropped.
    """
    chunks = []
    remaining = text.strip()
    while len(remaining) > limit:
        cut = remaining.rfind(". ", 0, limit)
        # Keep the period with its sentence; fall back to a hard cut
        cut = cut + 1 if cut > 0 else limit
        chunks.append(remaining[:cut].rstrip())
        remaining = remaining[cut:].lstrip()
    if remaining:
        chunks.append(remaining)
    return chunks


async def setup_discord_bot():
    """Create the Discord bot, register its handlers and run it.

    Returns immediately (after logging) when ``DISCORD_TOKEN`` or the target
    channel list is missing. Otherwise the bot is stored in the module-level
    ``discord_bot`` and this coroutine runs until the bot stops.
    """
    global discord_bot

    if not DISCORD_TOKEN or not TARGET_CHANNEL_IDS:
        print("Discord bot disabled: Missing DISCORD_TOKEN or channel IDs")
        return

    # Initialize Discord bot
    intents = discord.Intents.default()
    intents.message_content = True
    bot = commands.Bot(command_prefix="~", intents=intents)
    discord_bot = bot

    @bot.event
    async def on_ready():
        print(f"Discord bot logged in as {bot.user}")
        try:
            synced = await bot.tree.sync()
            print(f"Synced {len(synced)} command(s)")
        except Exception as e:
            print(f"Failed to sync commands: {e}")

    @bot.event
    async def on_message(message):
        await bot.process_commands(message)

        # Only chat in the configured channels, never with ourselves, and
        # ignore command-style or blank messages.
        if message.channel.id not in TARGET_CHANNEL_IDS:
            return
        if message.author == bot.user:
            return
        content = message.content
        if content.startswith(("!", "~")) or not content.strip():
            return

        # Check if this is an image generation request first
        if await handle_image_request(message, content, google_client, image_model_id):
            return

        # Show typing indicator while waiting for the model
        async with message.channel.typing():
            # respond_with_gemini is a blocking call; run it in a worker
            # thread so the bot's event loop keeps running meanwhile.
            response = await asyncio.to_thread(respond_with_gemini, content, [])

        if not response or not response.strip():
            print("No response text to send; skipping reply")
            return

        # Split long responses
        if len(response) > _DISCORD_MESSAGE_LIMIT:
            chunks = _split_for_discord(response)
        else:
            chunks = [response]
        for chunk in chunks:
            await message.reply(chunk)

    @bot.tree.command(name="model")
    @app_commands.describe(new_model_id="New model ID to use for Gemini API or shorthand ('flash', 'pro')")
    async def change_model(interaction: discord.Interaction, new_model_id: str):
        """Changes the Gemini chat model being used."""
        # Outside a guild the user object has no guild permissions at all
        permissions = getattr(interaction.user, "guild_permissions", None)
        if permissions is None or not permissions.administrator:
            await interaction.response.send_message("Only administrators can change the model.", ephemeral=True)
            return

        global chat_model_id
        # Use centralized model definitions; unknown names are used verbatim
        actual_model_id = MODEL_DEFINITIONS.get(new_model_id.lower(), new_model_id)
        old_model = chat_model_id
        chat_model_id = actual_model_id
        await interaction.response.send_message(f"Chat model changed from `{old_model}` to `{actual_model_id}`", ephemeral=True)

    @bot.tree.command(name="image")
    @app_commands.describe(prompt="Description of the image you want to generate")
    async def generate_image_command(interaction: discord.Interaction, prompt: str):
        """Generates an image using Gemini API based on the provided prompt."""
        # Defer first: generation can take longer than the initial response window
        await interaction.response.defer(thinking=True)
        try:
            image_bytes = await generate_image_bytes(prompt, google_client, image_model_id)
            await interaction.followup.send(f"Generated image based on: {prompt}", file=discord.File(BytesIO(image_bytes), filename="generated_image.png"))
        except Exception as e:
            print(f"Error generating image: {e}")
            await interaction.followup.send("Sorry, I couldn't generate that image.")

    # Run the bot
    await bot.start(DISCORD_TOKEN)
def run_discord_bot():
    """Thread entry point: run the Discord bot on its own event loop.

    Any exception raised while setting up or running the bot is logged
    rather than propagated.
    """
    try:
        bot_main = setup_discord_bot()
        asyncio.run(bot_main)
    except Exception as err:
        print(f"Discord bot error: {err}")
# Create Gradio interface
# Sample prompts offered in the chat UI; each becomes a one-element example row.
_EXAMPLE_PROMPTS = (
    "What's the deal with airplane food?",
    "Why do people say 'after dark' when it's really after light?",
    "What's up with people who take forever to order at restaurants?",
    "Why do we park in driveways and drive on parkways?",
    "Tell me about the soup nazi",
    "What's your take on people who don't return shopping carts?",
)

demo = gr.ChatInterface(
    respond_gradio,
    title="🥨 Seinfeld Chatbot",
    description="Chat with Jerry Seinfeld! What's the deal with chatbots anyway?",
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=True,
)
if __name__ == "__main__":
    # The Discord bot is optional: it only starts when a token and at least
    # one channel ID are configured, and then runs in a daemon thread.
    if not (DISCORD_TOKEN and TARGET_CHANNEL_IDS):
        print("Discord bot disabled: Missing credentials or channel IDs")
    else:
        discord_thread = threading.Thread(target=run_discord_bot, daemon=True)
        discord_thread.start()
        print("Discord bot starting in background...")

    # Launch Gradio interface
    demo.launch()
|