Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import os | |
| import pandas as pd | |
| import json | |
| import ssl | |
| import random | |
| from elo import update_elo_ratings # Custom function for ELO ratings | |
# Gradio update payload that makes a button clickable; user_ask returns it for
# both upvote buttons once the two bots have answered.
enable_btn = gr.Button.update(interactive=True)
| import sqlite3 | |
| import requests | |
def classify_vote(user_input, timeout=30):
    """Classify a conversation into one arena category via zero-shot NLI.

    Posts *user_input* to the Hugging Face Inference API model
    facebook/bart-large-mnli with the candidate labels roleplay, small talk,
    mathematics, logic and creative.

    Args:
        user_input: text to classify (callers pass the formatted conversation).
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The top-ranked label (stripped), or None when the request fails, the
        service answers with a non-200 status, or the body lacks a usable
        ``labels`` list.
    """
    url = "https://api-inference.huggingface.co/models/facebook/bart-large-mnli"
    headers = {
        "accept": "*/*",
        "accept-language": "en-US,en;q=0.9",
        "content-type": "application/json",
    }
    payload = {
        "inputs": user_input,
        "parameters": {
            "candidate_labels": "roleplay,small talk,mathematics,logic,creative",
            "multi_class": True
        }
    }
    try:
        # Without a timeout a stalled inference endpoint would block the vote handler indefinitely.
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return None
    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return None
    try:
        # Error payloads such as {"error": ...} have no "labels" key; treat them as "no category".
        return response.json()["labels"][0].strip()
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as exc:
        print(f"Error: unexpected classifier response ({exc!r})")
        return None
| from pymongo.mongo_client import MongoClient | |
| from pymongo.server_api import ServerApi | |
async def direct_regenerate(model, user_input, chatbot, character_name, character_description, user_name):
    """Replace the model's reply to the latest message in the Direct Chat tab.

    Args:
        model: display name (``original_model``) selected in the dropdown.
        user_input: current textbox text, forwarded to get_bot_response.
        chatbot: transcript as a list of (user message, bot reply) pairs.
        character_name, character_description, user_name: persona fields.

    Returns:
        ("", chatbot). The empty string clears the textbox, and the last
        transcript entry carries the regenerated reply.
    """
    if not chatbot:
        # Nothing has been said yet, so there is no reply to regenerate.
        return "", chatbot
    adapter = next(entry['adapter'] for entry in chatbots_data if entry['original_model'] == model)
    # Rebuild the conversation from the visible transcript so the new reply
    # keeps the context of earlier turns, not just the last user message.
    # Role names match the ones direct_chat stores.
    history = []
    for user_msg, bot_msg in chatbot[:-1]:
        history.append({"role": "user", "content": user_msg})
        history.append({"role": "bot", "content": bot_msg})
    history.append({"role": "user", "content": chatbot[-1][0]})
    temp_state = {"history": [history]}
    response = await get_bot_response(adapter, user_input, temp_state, 0, character_name, character_description, user_name)
    # Only the assistant's side of the last turn changes.
    # NOTE(review): state["direct_history"] is not passed in here, so it still holds the old reply.
    chatbot[-1] = (chatbot[-1][0], response)
    return "", chatbot
import threading
from urllib.parse import quote_plus

# MongoDB password for the "new-user" account, read from the MONGODB env var.
password = os.environ.get("MONGODB")

# A single MongoClient is shared by every caller. Each client owns a connection
# pool and is meant to be created once and reused. The lock stops concurrent
# Gradio handlers from racing to build duplicate clients.
_mongo_client = None
_mongo_client_lock = threading.Lock()


def init_database():
    """Return the ``elo_ratings`` collection of the ``elo_ratings2`` database.

    The underlying MongoClient is created on first use and cached, so this
    function is cheap to call from every request handler.
    """
    global _mongo_client
    with _mongo_client_lock:
        if _mongo_client is None:
            # Credentials in a MongoDB URI must be percent-escaped, otherwise
            # characters such as '@', ':' or '/' in the password break parsing.
            uri = (
                f"mongodb+srv://new-user:{quote_plus(password or '')}"
                "@cluster0.xb2urf6.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
            )
            _mongo_client = MongoClient(uri)
        client = _mongo_client
    return client["elo_ratings2"]["elo_ratings"]
| import json | |
# Load the arena's bot registry once at import time. chatbots.txt holds a JSON
# list of entries with 'adapter', 'original_model' and 'fireworks_adapter_name'
# keys. JSON text is UTF-8, so decode it explicitly rather than relying on
# the platform's locale encoding.
with open('chatbots.txt', 'r', encoding='utf-8') as file:
    chatbots_data = json.load(file)
# Adapter ids of every registered bot; chat_with_bots draws arena pairs from this list.
chatbots = [entry['adapter'] for entry in chatbots_data]
def clear_chat(state):
    """Reset the Chatbot Arena tab and pick a new random pair of bots.

    Args:
        state: per-session dict, which may be None.

    Returns:
        (new state, empty transcript A, empty transcript B, disabled upvote A,
        disabled upvote B, cleared and editable textbox, enabled submit button),
        in the order wired to the Reset button.
    """
    previous = state or {}
    new_state = {}
    # The same gr.State also carries the Direct Chat tab's history. Keep it so
    # that resetting the arena does not silently wipe the other tab's context.
    if "direct_history" in previous:
        new_state["direct_history"] = previous["direct_history"]
    adapter_names = [entry['adapter'] for entry in chatbots_data]
    new_state['last_bots'] = random.sample(adapter_names, 2)
    return new_state, [], [], gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Textbox.update(value='', interactive=True), gr.Button.update(interactive=True)
| from datasets import load_dataset,DatasetDict,Dataset | |
| import requests | |
| import os | |
def format_prompt(state, bot_index, character_name, character_description, user_name, num_messages=20):
    """Render one bot's recent conversation as a plain-text completion prompt.

    The prompt is laid out as follows:
    1. The persona description, followed by a blank line.
    2. One "<speaker>: <content>" line per message, taking the most recent
       *num_messages* entries of ``state["history"][bot_index]``.
    3. A trailing "<character name>:" cue for the model to continue from.

    Messages whose role is 'user' are labelled with the user's name; every
    other role is labelled with the character's name. Blank or missing
    persona fields fall back to the built-in "Ryan" persona and the "You" user
    label. The prompt is also printed to stdout.
    """
    def _fallback(value, default):
        if value is None or not value.strip():
            return default
        return value

    speaker = _fallback(character_name, "Ryan")
    persona = _fallback(character_description, "Ryan is a college student who is always willing to help. He knows a lot about math and coding.")
    human = _fallback(user_name, "You")

    window = state["history"][bot_index][-num_messages:]
    transcript = "".join(
        f"{human if turn['role'] == 'user' else speaker}: {turn['content']}\n"
        for turn in window
    )
    prompt = f"{persona}\n\n{transcript}{speaker}:"
    print(prompt)
    return prompt
| import aiohttp | |
| import asyncio | |
| from tenacity import retry, stop_after_attempt, wait_exponential | |
async def get_bot_response(adapter_id, prompt, state, bot_index, character_name, character_description, user_name):
    """Request one bot's next reply from the Fireworks completions API.

    Args:
        adapter_id: the ``adapter`` id of an entry in ``chatbots_data``.
        prompt: unused. The prompt is always rebuilt from *state* with
            format_prompt; the parameter is kept for caller compatibility.
        state: dict whose ``state["history"][bot_index]`` is that bot's messages.
        bot_index: which history list in *state* to render.
        character_name, character_description, user_name: persona fields.

    Returns:
        The completion text stripped of surrounding whitespace, or
        "Sorry, I couldn't generate a response." when the adapter is unknown,
        the request fails or times out (30 s), the API answers non-200, or the
        response body is malformed.
    """
    fallback = "Sorry, I couldn't generate a response."
    prompt = format_prompt(state, bot_index, character_name, character_description, user_name)
    fireworks_adapter_name = next(
        (entry['fireworks_adapter_name'] for entry in chatbots_data if entry['adapter'] == adapter_id),
        None,
    )
    if fireworks_adapter_name is None:
        # A bare StopIteration inside a coroutine would surface as a RuntimeError.
        print(f"Unknown adapter: {adapter_id}")
        return fallback

    # Use the same fallbacks as format_prompt so the stop sequences match the
    # speaker labels actually written into the prompt. A blank name would
    # otherwise yield the stop string ":" and cut replies at the first colon.
    bot_label = character_name if character_name and character_name.strip() else "Ryan"
    human_label = user_name if user_name and user_name.strip() else "You"

    url = "https://api.fireworks.ai/inference/v1/completions"
    payload = {
        "model": f"accounts/gaingg19-432d9f/models/{fireworks_adapter_name}",
        "max_tokens": 250,
        "temperature": 0.7,
        "prompt": prompt,
        "stop": ["<|im_end|>", f"{bot_label}:", f"{human_label}:"]
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"
    }
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=headers,
                                    timeout=aiohttp.ClientTimeout(total=30)) as response:
                if response.status != 200:
                    print(await response.text())
                    return fallback
                response_data = await response.json()
                return response_data['choices'][0]['text'].strip()
    except (aiohttp.ClientError, asyncio.TimeoutError, KeyError, IndexError, TypeError, ValueError) as exc:
        # Network failures and unexpected response shapes both degrade to the fallback reply.
        print(f"Fireworks request failed: {exc!r}")
        return fallback
async def chat_with_bots(user_input, state, character_name, character_description, user_name):
    """Get one reply from each of the two arena bots concurrently.

    If the session has no bot pair yet (``state['last_bots']`` missing or
    empty), two distinct adapters are drawn at random and stored there, so
    later turns keep talking to the same pair.

    Returns:
        (bot1_response, bot2_response), each with literal "<|im_end|>" markers removed.
    """
    if not state.get('last_bots'):
        # random.sample draws without mutating the module-level `chatbots` list,
        # which is shared by every session.
        state['last_bots'] = random.sample(chatbots, 2)
    bot1_adapter, bot2_adapter = state['last_bots'][0], state['last_bots'][1]
    bot1_response, bot2_response = await asyncio.gather(
        get_bot_response(bot1_adapter, user_input, state, 0, character_name, character_description, user_name),
        get_bot_response(bot2_adapter, user_input, state, 1, character_name, character_description, user_name)
    )
    return bot1_response.replace("<|im_end|>", ""), bot2_response.replace("<|im_end|>", "")
def update_ratings(state, winner_index, collection, category):
    """Apply one arena result to the stored ELO ratings.

    Args:
        state: session dict whose 'last_bots' holds the two adapter ids in play.
        winner_index: position of the winning bot in state['last_bots']; the
            other position is the loser.
        collection: MongoDB collection holding the ELO documents.
        category: label under which the game is scored.

    Returns:
        [('Winner: ', winner display name), ('Loser: ', loser display name)]
    """
    current = get_user_elo_ratings(collection)
    pair = state['last_bots']
    winner_adapter, loser_adapter = pair[winner_index], pair[1 - winner_index]

    def display_name(adapter_id):
        return next(entry['original_model'] for entry in chatbots_data if entry['adapter'] == adapter_id)

    winner_label = display_name(winner_adapter)
    loser_label = display_name(loser_adapter)

    refreshed = update_elo_ratings(current, winner_adapter, loser_adapter, category)
    update_elo_rating(collection, refreshed, winner_adapter, loser_adapter, category)
    return [('Winner: ', winner_label), ('Loser: ', loser_label)]
def _record_vote(state, winner_index, chatbot, chatbot2, character_name, character_description, user_name):
    """Score one arena vote for the bot at *winner_index* and annotate both transcripts.

    The winning bot's formatted conversation is classified into a category.
    When classification succeeds, ELO ratings are updated under that category
    and the 'Winner: ' / 'Loser: ' entries are appended to the matching panels.

    Returns the six values wired to the upvote buttons' outputs:
    (chatbot, chatbot2, upvote-A update, upvote-B update, textbox update, submit update).
    """
    conversation = format_prompt(state, winner_index, character_name, character_description, user_name)
    collection = init_database()
    top_category = classify_vote(conversation)
    recorded = bool(top_category)
    if recorded:
        winner_msg, loser_msg = update_ratings(state, winner_index, collection, top_category)
        panels = (chatbot, chatbot2)
        panels[winner_index].append(winner_msg)
        panels[1 - winner_index].append(loser_msg)
    # After a recorded vote, keep both vote buttons disabled until user_ask
    # re-enables them with the next pair of replies, so one exchange cannot be
    # scored repeatedly. If classification failed, nothing was recorded and a
    # retry is allowed.
    return (
        chatbot,
        chatbot2,
        gr.Button.update(interactive=not recorded),
        gr.Button.update(interactive=not recorded),
        gr.Textbox.update(value='', interactive=True),
        gr.Button.update(interactive=True),
    )


def vote_up_model(state, chatbot, chatbot2, character_name, character_description, user_name):
    """Handle "Upvote A": Model A (bot index 0) wins against Model B."""
    return _record_vote(state, 0, chatbot, chatbot2, character_name, character_description, user_name)


def vote_down_model(state, chatbot, chatbot2, character_name, character_description, user_name):
    """Handle "Upvote B": Model B (bot index 1) wins against Model A."""
    return _record_vote(state, 1, chatbot, chatbot2, character_name, character_description, user_name)
async def user_ask(state, chatbot1, chatbot2, textbox, character_name, character_description, user_name):
    """Send the user's message to both arena bots and show their replies.

    Args:
        state: per-session dict; ``state["history"]`` holds one message list per bot.
        chatbot1, chatbot2: transcripts displayed for Model A and Model B.
        textbox: the raw prompt typed by the user (may be None or blank).
        character_name, character_description, user_name: persona fields from the UI.

    Returns:
        (state, chatbot1, chatbot2, textbox update, upvote-A update, upvote-B update),
        matching the outputs wired to the Submit button and the textbox.
    """
    # Clip user-controlled fields to the limits advertised in the UI placeholders.
    if character_name and len(character_name) > 20:
        character_name = character_name[:20]  # max 20 characters
    if character_description and len(character_description) > 500:
        character_description = character_description[:500]  # max 500 characters
    if user_name and len(user_name) > 20:
        user_name = user_name[:20]  # max 20 characters
    user_input = (textbox or "")[:500]  # max 500 characters
    if not user_input.strip():
        # Nothing to send: clear the textbox and leave chats and vote buttons unchanged.
        return state, chatbot1, chatbot2, gr.update(value=''), gr.update(), gr.update()

    collection = init_database()
    # NOTE(review): no code visible in this file reads state["elo_ratings"], and
    # this is a synchronous database read inside a coroutine. It is kept only
    # so the session state contents stay the same.
    state["elo_ratings"] = get_user_elo_ratings(collection)

    history = state.setdefault("history", [[], []])
    history[0].append({"role": "user", "content": user_input})
    history[1].append({"role": "user", "content": user_input})
    # Keep only the 20 most recent messages per bot.
    if len(history[0]) > 20:
        history[0] = history[0][-20:]
        history[1] = history[1][-20:]

    bot1_response, bot2_response = await chat_with_bots(user_input, state, character_name, character_description, user_name)
    history[0].append({"role": "bot1", "content": bot1_response})
    history[1].append({"role": "bot2", "content": bot2_response})
    chatbot1.append((user_input, bot1_response))
    chatbot2.append((user_input, bot2_response))
    # enable_btn (module level) re-enables both upvote buttons now that there is something to vote on.
    return state, chatbot1, chatbot2, gr.update(value=''), enable_btn, enable_btn
| import pandas as pd | |
# Model submission, ELO persistence and leaderboard helpers.
| import requests | |
def submit_model(model_name, timeout=10):
    """Announce a user-suggested model on the Discord webhook in DISCORD_URL.

    Args:
        model_name: free text typed by the user (untrusted).
        timeout: seconds to wait for the webhook request.

    Returns:
        A status message suitable for display in the UI.
    """
    discord_url = os.environ.get("DISCORD_URL")
    if not discord_url:
        return "Discord webhook URL not configured."
    name = (model_name or "").strip()
    if not name:
        return "Please enter a model name."
    payload = {
        "content": f"New model submitted: {name}",
        # The name is user input: stop it from triggering @everyone/@here or role/user pings.
        "allowed_mentions": {"parse": []},
    }
    try:
        response = requests.post(discord_url, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Model submission failed: {exc!r}")
        return "Failed to submit the model."
    # Discord answers 204 by default and 200 when ?wait=true is set on the webhook URL.
    if 200 <= response.status_code < 300:
        return "Model submitted successfully!"
    return "Failed to submit the model."
def get_user_elo_ratings(collection):
    """Load every stored ELO document into a nested ratings dict.

    Result shape: {bot_name: {category: {'elo_rating': ..., 'games_played': ...}}}.
    Documents sharing a bot_name are merged, and later categories overwrite
    earlier ones. Only the two rating fields are copied from each category.
    When the collection holds no documents, a placeholder
    {"default": {'overall': {'elo_rating': 1200, 'games_played': 0}}} is returned.
    """
    documents = list(collection.find())
    if not documents:
        return {"default": {'overall': {'elo_rating': 1200, 'games_played': 0}}}

    ratings = {}
    for doc in documents:
        per_category = ratings.setdefault(doc['bot_name'], {})
        for cat_name, cat_stats in doc['categories'].items():
            per_category[cat_name] = {
                'elo_rating': cat_stats['elo_rating'],
                'games_played': cat_stats['games_played'],
            }
    return ratings
def update_elo_rating(collection, updated_ratings, winner, loser, category):
    """Persist the winner's and then the loser's rating for *category*.

    Each bot's document (matched by bot_name, and created if absent via
    upsert) gets its ``categories.<category>.elo_rating`` and
    ``categories.<category>.games_played`` fields set from *updated_ratings*.
    """
    for bot in (winner, loser):
        stats = updated_ratings[bot][category]
        collection.update_one(
            {"bot_name": bot},
            {"$set": {
                f"categories.{category}.elo_rating": stats['elo_rating'],
                f"categories.{category}.games_played": stats['games_played'],
            }},
            upsert=True,
        )
def generate_leaderboard(collection):
    """Build the leaderboard table from the ELO documents in *collection*.

    Each document is read for ``bot_name`` and an optional ``categories``
    mapping of {category: {'elo_rating': ..., 'games_played': ...}}.

    Returns:
        pandas DataFrame with columns 'Chatbot' (display name), 'Avg ELO Score'
        (mean rating over the bot's categories, rounded to int, 0 when it has
        none) and 'Total Games Played', sorted by 'Avg ELO Score' descending.
    """
    # Build the adapter -> display-name map once instead of scanning per row.
    # setdefault keeps the first entry when an adapter is listed twice.
    display_names = {}
    for entry in chatbots_data:
        display_names.setdefault(entry['adapter'], entry['original_model'])

    leaderboard_rows = []
    for row in collection.find():
        bot_name = row['bot_name']
        categories = row.get('categories') or {}
        total_elo = sum(stats['elo_rating'] for stats in categories.values())
        total_games = sum(stats['games_played'] for stats in categories.values())
        avg_elo = total_elo / len(categories) if categories else 0
        # Bots removed from chatbots.txt can still have rows in the database;
        # show their raw adapter id instead of failing the whole leaderboard.
        leaderboard_rows.append([display_names.get(bot_name, bot_name), avg_elo, total_games])

    leaderboard_data = pd.DataFrame(leaderboard_rows, columns=['Chatbot', 'Avg ELO Score', 'Total Games Played'])
    # astype(float) keeps rounding valid for an empty (object-dtype) table.
    leaderboard_data['Avg ELO Score'] = leaderboard_data['Avg ELO Score'].astype(float).round().astype(int)
    return leaderboard_data.sort_values('Avg ELO Score', ascending=False)
def refresh_leaderboard():
    """Query the ELO collection and return a freshly built leaderboard DataFrame."""
    return generate_leaderboard(init_database())
async def direct_chat(model, user_input, state, chatbot, character_name, character_description, user_name):
    """Send one message to the selected model in the Direct Chat tab.

    Args:
        model: display name (``original_model``) selected in the dropdown.
        user_input: the message typed by the user.
        state: per-session dict; the conversation is kept in state["direct_history"].
        chatbot: transcript as a list of (user message, bot reply) pairs.
        character_name, character_description, user_name: persona fields.

    Returns:
        ("", chatbot, state). The empty string clears the textbox.
    """
    if not user_input or not user_input.strip():
        # Ignore blank submissions instead of sending an empty turn to the model.
        return "", chatbot, state
    adapter = next(entry['adapter'] for entry in chatbots_data if entry['original_model'] == model)
    history = state.setdefault("direct_history", [])
    history.append({"role": "user", "content": user_input})
    # get_bot_response reads history[bot_index]; both slots point at the same list.
    temp_state = {"history": [history, history]}
    response = await get_bot_response(adapter, user_input, temp_state, 0, character_name, character_description, user_name)
    chatbot.append((user_input, response))
    history.append({"role": "bot", "content": response})
    # Cap stored history after the turn completes; prompts are built from at
    # most the 20 most recent messages (format_prompt's default window).
    if len(history) > 20:
        state["direct_history"] = history[-20:]
    return "", chatbot, state
def reset_direct_chat(state):
    """Clear the Direct Chat tab.

    Empties the stored direct-chat history in *state* and returns
    (empty transcript, cleared textbox update, state) in the order wired to
    the "Reset Chat" button.
    """
    state.update({"direct_history": []})
    cleared_box = gr.Textbox.update(value='')
    return [], cleared_box, state
# Import-time warm-up query against the ELO database. A database outage must
# not stop the app from starting (the Leaderboard tab has its own fallback),
# so log the failure and continue.
try:
    refresh_leaderboard()
except Exception as exc:
    print(f"Leaderboard warm-up failed: {exc!r}")
# Gradio interface setup
with gr.Blocks() as demo:
    # Per-session dict shared by all tabs: arena history / bot pair and direct-chat history.
    state = gr.State({})

    with gr.Tab("π€ Chatbot Arena"):
        gr.Markdown("## π₯ Let's see which chatbot wins!")
        with gr.Row():
            with gr.Column():
                chatbot1 = gr.Chatbot(label='π€ Model A').style(height=350)
                upvote_btn_a = gr.Button(value="π Upvote A", interactive=False).style(full_width=True)
            with gr.Column():
                chatbot2 = gr.Chatbot(label='π€ Model B').style(height=350)
                upvote_btn_b = gr.Button(value="π Upvote B", interactive=False).style(full_width=True)
        with gr.Row():
            with gr.Column(scale=5):
                textbox = gr.Textbox(placeholder="π€ Enter your prompt (up to 500 characters)")
            submit_btn = gr.Button(value="Submit")
        with gr.Row():
            reset_btn = gr.Button(value="ποΈ Reset")
        with gr.Row():
            character_name = gr.Textbox(label="Character Name", value="Ryan", placeholder="Enter character name (max 20 chars)")
            character_description = gr.Textbox(label="Character Description", value="Ryan is a college student who is always willing to help. He knows a lot about math and coding.", placeholder="Enter character description (max 500 chars)")
        with gr.Row():
            user_name = gr.Textbox(label="Your Name", value="You", placeholder="Enter your name (max 20 chars)")

        reset_btn.click(clear_chat, inputs=[state], outputs=[state, chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        submit_btn.click(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        textbox.submit(user_ask, inputs=[state, chatbot1, chatbot2, textbox, character_name, character_description, user_name], outputs=[state, chatbot1, chatbot2, textbox, upvote_btn_a, upvote_btn_b], queue=True)
        upvote_btn_a.click(vote_up_model, inputs=[state, chatbot1, chatbot2, character_name, character_description, user_name], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])
        upvote_btn_b.click(vote_down_model, inputs=[state, chatbot1, chatbot2, character_name, character_description, user_name], outputs=[chatbot1, chatbot2, upvote_btn_a, upvote_btn_b, textbox, submit_btn])

    with gr.Tab("π¬ Direct Chat"):
        gr.Markdown("## π£οΈ Chat directly with a model!")
        with gr.Row():
            model_dropdown = gr.Dropdown(choices=[entry['original_model'] for entry in chatbots_data], value=chatbots_data[0]['original_model'], label="π€ Select a model")
        with gr.Row():
            direct_chatbot = gr.Chatbot(label="π¬ Direct Chat").style(height=500)
        with gr.Row():
            with gr.Column(scale=5):
                direct_textbox = gr.Textbox(placeholder="π Enter your message")
            direct_submit_btn = gr.Button(value="Submit")
        with gr.Row():
            direct_regenerate_btn = gr.Button(value="π Regenerate")
            direct_reset_btn = gr.Button(value="ποΈ Reset Chat")

        # The persona textboxes from the Arena tab also drive the direct chat.
        direct_regenerate_btn.click(direct_regenerate, inputs=[model_dropdown, direct_textbox, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot])
        direct_textbox.submit(direct_chat, inputs=[model_dropdown, direct_textbox, state, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot, state])
        direct_submit_btn.click(direct_chat, inputs=[model_dropdown, direct_textbox, state, direct_chatbot, character_name, character_description, user_name], outputs=[direct_textbox, direct_chatbot, state])
        direct_reset_btn.click(reset_direct_chat, inputs=[state], outputs=[direct_chatbot, direct_textbox, state])

    with gr.Tab("π Leaderboard"):
        gr.Markdown("## π Check out the top-performing models!")
        try:
            leaderboard = gr.Dataframe(refresh_leaderboard())
        except Exception as exc:
            # Database unavailable at startup: show an empty table; the Refresh button can retry.
            print(f"Could not load leaderboard: {exc!r}")
            leaderboard = gr.Dataframe()
        with gr.Row():
            refresh_btn = gr.Button("π Refresh Leaderboard")
        refresh_btn.click(refresh_leaderboard, outputs=[leaderboard])

    with gr.Tab("π¨ Submit Model"):
        gr.Markdown("## π¨ Submit a new model to be added to the chatbot arena!")
        with gr.Row():
            model_name_input = gr.Textbox(placeholder="Enter the model name")
            submit_model_btn = gr.Button(value="Submit Model")
        # Show the result in its own box; writing it into model_name_input would
        # make a second click submit the status text as a model name.
        submit_status = gr.Textbox(label="Status", interactive=False)
        submit_model_btn.click(submit_model, inputs=[model_name_input], outputs=[submit_status])

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=False)