Spaces:
Runtime error
Runtime error
"""Streamlit front-end for the Aidan Bench generator.

Collects OpenRouter/OpenAI API keys, lets the user pick a language model
and a set of questions (predefined or user-defined), then for each question
repeatedly generates answers until the output becomes incoherent
(judge coherence <= 3) or redundant (novelty < 0.1), scoring every answer.
"""
import streamlit as st
from main import get_novelty_score
from models import chat_with_model, embed
from prompts import questions as predefined_questions, create_gen_prompt, create_judge_prompt
import requests
import numpy as np
import os

# Set the title in the browser tab
st.set_page_config(page_title="Aidan Bench - Generator")
st.title("Aidan Bench - Generator")

# API key inputs. Keys live only in st.session_state for this session;
# they are never written to disk or logged.
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
if "open_router_key" not in st.session_state:
    st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
    st.session_state.openai_api_key = ""

open_router_key = st.text_input("Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key)
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key)

if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")

# The rest of the UI only appears once both keys are confirmed.
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Fetch the available model list from the OpenRouter API.
    try:
        # FIX: requests has no default timeout; without one a hung
        # endpoint would freeze the app indefinitely.
        response = requests.get("https://openrouter.ai/api/v1/models", timeout=30)
        response.raise_for_status()  # Raise an exception for bad status codes
        models = response.json()["data"]
        # Sort models alphabetically by their ID
        models.sort(key=lambda model: model["id"])
        model_names = [model["id"] for model in models]
    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        model_names = []  # Provide an empty list if API call fails

    # Model selection; abort rendering if the model list is unavailable.
    if model_names:
        model_name = st.selectbox("Select a Language Model", model_names)
    else:
        st.error("No models available. Please check your API connection.")
        st.stop()  # Stop execution if no models are available

    # Session-state storage for user-defined questions.
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    # Workflow selection: curated question set vs. user-entered questions.
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions,  # Select all by default
        )
    elif workflow == "Use User-Defined Questions":
        st.header("Question Input")
        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question") and new_question:
            new_question = new_question.strip()  # Remove leading/trailing whitespace
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)
                st.success(f"Question '{new_question}' added successfully.")
            else:
                st.warning("Question already exists or is empty!")
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions,
        )

    st.write("Selected Questions:", selected_questions)

    # Benchmark execution.
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            progress_bar = st.progress(0)
            num_questions = len(selected_questions)
            results = []  # One entry per question

            for i, question in enumerate(selected_questions):
                st.write(f"Processing question {i+1}/{num_questions}: {question}")
                previous_answers = []
                # FIX: per-answer score lists replace the old single
                # coherence_score/novelty_score locals, which (a) raised
                # NameError in results.append when the loop broke before
                # any answer was scored, and (b) mislabeled every answer
                # of a question with the *last* answer's scores.
                coherence_scores = []
                novelty_scores = []
                question_novelty = 0
                try:
                    # Keep generating answers until the model goes
                    # incoherent, becomes redundant, or an API call fails.
                    while True:
                        gen_prompt = create_gen_prompt(question, previous_answers)
                        try:
                            new_answer = chat_with_model(
                                prompt=gen_prompt,
                                model=model_name,
                                open_router_key=st.session_state.open_router_key,
                                openai_api_key=st.session_state.openai_api_key,
                            )
                        except requests.exceptions.RequestException as e:
                            st.error(f"API Error: {e}")
                            break

                        judge_prompt = create_judge_prompt(question, new_answer)
                        judge = "openai/gpt-4o-mini"
                        try:
                            judge_response = chat_with_model(
                                prompt=judge_prompt,
                                model=judge,
                                open_router_key=st.session_state.open_router_key,
                                openai_api_key=st.session_state.openai_api_key,
                            )
                        except requests.exceptions.RequestException as e:
                            st.error(f"API Error (Judge): {e}")
                            break

                        # FIX: guard the tag extraction — a malformed judge
                        # reply previously raised IndexError/ValueError and
                        # aborted the question with a generic error.
                        try:
                            coherence_score = int(
                                judge_response.split("<coherence_score>")[1].split("</coherence_score>")[0]
                            )
                        except (IndexError, ValueError):
                            st.warning("Could not parse judge response. Moving to next question.")
                            break

                        if coherence_score <= 3:
                            st.warning("Output is incoherent. Moving to next question.")
                            break

                        novelty_score = get_novelty_score(new_answer, previous_answers, st.session_state.openai_api_key)
                        if novelty_score < 0.1:
                            st.warning("Output is redundant. Moving to next question.")
                            break

                        st.write(f"New Answer:\n{new_answer}")
                        st.write(f"Coherence Score: {coherence_score}")
                        st.write(f"Novelty Score: {novelty_score}")
                        previous_answers.append(new_answer)
                        coherence_scores.append(coherence_score)
                        novelty_scores.append(novelty_score)
                        question_novelty += novelty_score
                except Exception as e:
                    # Catch-all so one bad question cannot kill the run.
                    st.error(f"Error processing question: {e}")

                results.append({
                    "question": question,
                    "answers": previous_answers,
                    "coherence_scores": coherence_scores,
                    "novelty_scores": novelty_scores,
                    # Previously accumulated but never reported.
                    "total_novelty": question_novelty,
                })
                progress_bar.progress((i + 1) / num_questions)

            st.success("Benchmark completed!")

            # Display per-answer results, pairing each answer with its
            # own scores (lists are appended in lockstep above).
            st.write("Results:")
            results_table = []
            for result in results:
                for answer, c_score, n_score in zip(
                    result["answers"], result["coherence_scores"], result["novelty_scores"]
                ):
                    results_table.append({
                        "Question": result["question"],
                        "Answer": answer,
                        "Coherence Score": c_score,
                        "Novelty Score": n_score,
                    })
            st.table(results_table)
else:
    st.warning("Please confirm your API keys first.")