# Sampler-Arena / app.py
# Header captured from the Hugging Face file viewer: last commit 88083e1
# "Update app.py" by rwitz (verified), file size 9.85 kB.
import gradio as gr
import requests
import os
import pandas as pd
import json
import random
from elo import update_elo_ratings # Custom function for ELO ratings
# Gradio update that makes a component interactive; user_ask returns it for
# both upvote buttons after the bots have answered, enabling voting.
enable_btn = gr.Button.update(interactive=True)
import sqlite3
def init_database(db_path='elo_ratings.db'):
    """Create the ``elo_ratings`` table if it does not exist yet.

    Safe to call repeatedly (``CREATE TABLE IF NOT EXISTS``).

    Args:
        db_path: SQLite database file; defaults to the app's ratings database.
    """
    conn = sqlite3.connect(db_path)
    try:
        # The connection context manager commits on success / rolls back on
        # error; the finally guarantees the connection is closed either way.
        with conn:
            conn.execute(
                "CREATE TABLE IF NOT EXISTS elo_ratings\n"
                "(bot_name TEXT PRIMARY KEY,\n"
                "elo_rating INTEGER,\n"
                "games_played INTEGER)"
            )
    finally:
        conn.close()
# Load the chatbot model adapter names (one per line) from a text file.
# Surrounding whitespace is stripped and blank lines are skipped so that a
# trailing newline or stray spaces cannot produce an invalid adapter id.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    chatbots = [line.strip() for line in file if line.strip()]
def clear_chat(state):
    """Reset a session: fresh state with a new random bot pair, and reset the UI.

    Args:
        state: the previous per-session state; it is discarded.

    Returns:
        Values for the reset handler's outputs, in order: new state, Model A
        pane, Model B pane (both cleared), upvote A and upvote B (disabled),
        textbox (enabled), Send button (enabled).
    """
    # Always start from an empty dict. Previously a ``None`` state was kept as
    # ``None`` and the item assignment below then raised TypeError.
    state = {}
    # ``chatbots`` is a list of adapter names (it has no .keys()); draw two
    # distinct ones without mutating the shared module-level list.
    state['last_bots'] = random.sample(chatbots, 2)
    return (
        state,
        None,
        None,
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(interactive=True),
        gr.Button.update(interactive=True),
    )
# NOTE(review): this name is not read or written anywhere else in the
# visible code; ratings are always re-read from SQLite instead.
global_elo_ratings = None
from datasets import load_dataset,DatasetDict,Dataset
import requests
import os
def get_user_elo_ratings(db_path='elo_ratings.db'):
    """Load every stored bot rating from the database.

    Args:
        db_path: SQLite database file; defaults to the app's ratings database.

    Returns:
        dict mapping bot name -> ``{'elo_rating': ..., 'games_played': ...}``.
        If the table has no rows, a single placeholder entry ``"default"``
        (rating 1200, 0 games) is returned instead.

    Raises:
        sqlite3.OperationalError: if the ``elo_ratings`` table does not exist.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Name the columns explicitly instead of relying on SELECT * ordering.
        rows = conn.execute(
            "SELECT bot_name, elo_rating, games_played FROM elo_ratings"
        ).fetchall()
    finally:
        conn.close()
    if not rows:
        return {"default": {'elo_rating': 1200, 'games_played': 0}}
    return {
        name: {'elo_rating': rating, 'games_played': games}
        for name, rating, games in rows
    }
def update_elo_rating(updated_ratings, winner, loser, db_path='elo_ratings.db'):
    """Persist the post-match ratings of ``winner`` and ``loser``.

    Both rows are upserted (INSERT OR REPLACE) in one transaction, so either
    both are written or neither is.

    Args:
        updated_ratings: dict mapping bot name ->
            ``{'elo_rating': ..., 'games_played': ...}``; must contain both bots.
        winner: name of the winning bot.
        loser: name of the losing bot.
        db_path: SQLite database file; defaults to the app's ratings database.

    Raises:
        KeyError: if either bot is missing from ``updated_ratings`` (raised
            before the database is opened).
    """
    rows = [
        (name, updated_ratings[name]['elo_rating'], updated_ratings[name]['games_played'])
        for name in (winner, loser)
    ]
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commit both upserts together, or roll back on error
            conn.executemany(
                "INSERT OR REPLACE INTO elo_ratings (bot_name, elo_rating, games_played) VALUES (?, ?, ?)",
                rows,
            )
    finally:
        conn.close()
# Build the Alpaca-format prompts (one per bot) from the chat history
def _format_alpaca_history(messages):
    """Render one bot's message list as a single Alpaca-style prompt string."""
    # Standard Alpaca layout: header, blank line, then "### Instruction:" /
    # "### Response:" sections, each followed by a blank line. The header used
    # to run straight into "### Instruction:" with no separator.
    parts = [
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
    ]
    for message in messages:
        tag = "### Instruction:\n" if message['role'] == 'user' else "### Response:\n"
        parts.append(tag + message['content'] + "\n\n")
    # Leave an open Response section for the model to complete.
    parts.append("### Response:\n")
    return "".join(parts)


def format_alpaca_prompt(state):
    """Build one Alpaca-style prompt per bot from the session's history.

    Args:
        state: session dict whose ``"history"`` is a list of per-bot message
            lists (two in the arena). Each message is a dict with ``"role"``
            and ``"content"``; role ``"user"`` becomes a ``### Instruction:``
            section, any other role a ``### Response:`` section.

    Returns:
        A list with one prompt string per history list, in the same order,
        each ending with an open ``"### Response:\\n"`` section.
    """
    return [_format_alpaca_history(messages) for messages in state["history"]]
import aiohttp
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
async def get_bot_response(adapter_id, prompt, state, bot_index):
    """Ask one adapter on the Predibase deployment to continue its conversation.

    Args:
        adapter_id: Hub id of the adapter to apply (a line of chatbots.txt).
        prompt: the latest user message. Unused: the full prompt is rebuilt
            from ``state["history"]``. Kept for interface stability.
        state: per-session dict; its history is rendered by
            format_alpaca_prompt.
        bot_index: which rendered prompt to send (0 = Model A, 1 = Model B).

    Returns:
        The generated text cut at the first ``'### Instruction'`` marker, or a
        fixed apology string if the request fails or returns a non-200 status.
    """
    fallback = "Sorry, I couldn't generate a response."
    alpaca_prompt = format_alpaca_prompt(state)
    print(alpaca_prompt)
    payload = {
        "inputs": alpaca_prompt[bot_index],
        "parameters": {
            "adapter_id": adapter_id,
            "adapter_source": "hub",
            "temperature": 1,
            "max_new_tokens": 100,
        },
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('PREDIBASE_TOKEN')}",
    }
    response_text = fallback
    async with aiohttp.ClientSession() as session:
        try:
            async with session.post(
                "https://serving.app.predibase.com/79957f/deployments/v2/llms/mistral-7b/generate",
                json=payload,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=30),
            ) as response:
                if response.status == 200:
                    response_data = await response.json()
                    # A missing or null 'generated_text' is treated as empty.
                    response_text = response_data.get('generated_text') or ''
                else:
                    print(f"Generation failed for {adapter_id}: HTTP {response.status}")
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as err:
            # Log the exception itself: ``response`` is unbound when the
            # request fails before a response arrives (connection error,
            # timeout), so the old ``print(response.text)`` raised
            # UnboundLocalError. ValueError covers a malformed JSON body.
            print(f"Generation request for {adapter_id} failed: {err!r}")
            response_text = fallback
    return response_text.split('### Instruction')[0]
async def chat_with_bots(user_input, state):
    """Get replies from the session's two bots concurrently.

    Args:
        user_input: the user's latest message (forwarded to get_bot_response).
        state: per-session dict. If ``'last_bots'`` is missing or empty, two
            distinct adapters are drawn at random from ``chatbots`` and stored
            there, so later turns and votes reuse the same pair.

    Returns:
        ``(bot1_response, bot2_response)`` for Model A and Model B.
    """
    if not state.get('last_bots'):
        # random.sample leaves the module-level ``chatbots`` list untouched;
        # shuffling it in place mutated data shared by every session.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0),
        get_bot_response(bot2_adapter, user_input, state, 1),
    )
    return bot1_response, bot2_response
def update_ratings(state, winner_index):
    """Apply one vote: recompute both bots' ELO ratings and persist them.

    Args:
        state: per-session dict whose ``'last_bots'`` holds the two adapter
            names shown as Model A (index 0) and Model B (index 1).
        winner_index: 0 if Model A won, 1 if Model B won.

    Returns:
        ``[('Winner: ', winner), ('Loser: ', loser)]``, the entries appended
        to the chat panes to reveal the result.

    NOTE(review): if the table is empty, get_user_elo_ratings() returns only
    a placeholder ``"default"`` entry. Whether update_elo_ratings (from the
    ``elo`` module) adds bots missing from the dict is not visible here —
    confirm.
    """
    current_ratings = get_user_elo_ratings()
    pair = state['last_bots']
    winner, loser = pair[winner_index], pair[1 - winner_index]
    new_ratings = update_elo_ratings(current_ratings, winner, loser)
    update_elo_rating(new_ratings, winner, loser)
    return [('Winner: ', winner), ('Loser: ', loser)]
def _record_vote(state, chatbot, chatbot2, winner_index):
    """Shared body of the vote handlers: store the vote, annotate panes, lock UI."""
    winner_msg, loser_msg = update_ratings(state, winner_index)
    panes = (chatbot, chatbot2)
    panes[winner_index].append(winner_msg)
    panes[1 - winner_index].append(loser_msg)
    # One update per output: both upvote buttons, the textbox and the Send
    # button are disabled until the chat is reset.
    return (
        chatbot,
        chatbot2,
        gr.Button.update(interactive=False),
        gr.Button.update(interactive=False),
        gr.Textbox.update(interactive=False),
        gr.Button.update(interactive=False),
    )


def vote_up_model(state, chatbot, chatbot2):
    """Handler for "Upvote A": Model A (index 0) wins."""
    return _record_vote(state, chatbot, chatbot2, 0)


def vote_down_model(state, chatbot, chatbot2):
    """Handler for "Upvote B": Model B (index 1) wins, despite the name."""
    return _record_vote(state, chatbot, chatbot2, 1)
async def user_ask(state, chatbot1, chatbot2, textbox):
    """Send the user's prompt to both bots and show their replies.

    Args:
        state: per-session dict; ``'history'`` (one message list per bot) is
            created if missing, and ``'elo_ratings'`` is refreshed.
        chatbot1: Model A pane value (list of (user, bot) pairs); appended to.
        chatbot2: Model B pane value; appended to.
        textbox: the submitted text.

    Returns:
        Values for the outputs: state, both panes, a cleared textbox, and
        "enable" updates for the two upvote buttons.
    """
    # Mirror the textbox's max_chars=200 limit on the server side.
    user_input = textbox[:200]
    state["elo_ratings"] = get_user_elo_ratings()
    history = state.setdefault("history", [[], []])
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})
    # The bots read the history (including the new user message) from state.
    bot1_response, bot2_response = await chat_with_bots(user_input, state)
    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    # Keep only the last 10 messages of EACH bot's history. Slicing the outer
    # two-element list (the previous code) was a no-op, so prompts grew
    # without bound. 10 is even, so each kept window starts at a user turn.
    state["history"] = [messages[-10:] for messages in history]
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
import pandas as pd
# Function to generate leaderboard data
def generate_leaderboard(db_path='elo_ratings.db'):
    """Return all bots ranked by ELO score, highest first.

    Args:
        db_path: SQLite database file; defaults to the app's ratings database.

    Returns:
        pandas DataFrame with columns 'Chatbot', 'ELO Score', 'Games Played'
        (no rows if the table is empty).

    Raises:
        sqlite3.OperationalError: if the ``elo_ratings`` table does not exist
            (the UI catches this on first launch). The connection is closed
            either way.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT bot_name, elo_rating, games_played FROM elo_ratings ORDER BY elo_rating DESC"
        ).fetchall()
    finally:
        conn.close()
    return pd.DataFrame(rows, columns=['Chatbot', 'ELO Score', 'Games Played'])
def refresh_leaderboard():
    """Handler for the "Refresh Leaderboard" button: rebuild the ranking table."""
    leaderboard_df = generate_leaderboard()
    return leaderboard_df
# Gradio interface setup
with gr.Blocks() as demo:
    # Per-session state dict; handlers store 'last_bots', 'history' and
    # 'elo_ratings' in it.
    state = gr.State({})
    with gr.Tab("Chatbot Arena"):
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='Model A').style(height=600)
                upvote_btn_a = gr.Button(value="👍 Upvote A", interactive=False)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='Model B').style(height=600)
                upvote_btn_b = gr.Button(value="👍 Upvote B", interactive=False)
        textbox = gr.Textbox(placeholder="Enter your prompt (up to 200 characters)", max_chars=200)
        with gr.Row():
            submit_btn = gr.Button(value="Send")
            reset_btn = gr.Button(value="Reset")
        # Reset: new bot pair, cleared panes, voting disabled, input enabled.
        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        # Pressing Enter and clicking Send both run user_ask.
        textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        # Voting reveals the models and locks the chat until reset.
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
    with gr.Tab("Leaderboard"):
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except sqlite3.Error:
            # The ratings table may not exist yet: init_database() only runs
            # under __main__, after this UI is built. Show an empty table with
            # the right headers (gr.Dataframe takes `headers`, not `columns`).
            leaderboard = gr.Dataframe(headers=['Chatbot', 'ELO Score', 'Games Played'])
        refresh_btn = gr.Button("Refresh Leaderboard")
        # Re-query the database when the refresh button is clicked.
        refresh_btn.click(refresh_leaderboard, inputs=[], outputs=[leaderboard])
# Launch the Gradio interface
def _main():
    """Ensure the ratings table exists, then launch the app with share=False."""
    init_database()
    demo.launch(share=False)


if __name__ == "__main__":
    _main()