Update app.py
app.py (CHANGED)
@@ -12,6 +12,14 @@ import sys
 import io
 from contextlib import redirect_stdout, redirect_stderr
 
+# Initialize session state variables
+if 'main_output' not in st.session_state:
+    st.session_state.main_output = []
+if 'debug_output' not in st.session_state:
+    st.session_state.debug_output = []
+if 'progress' not in st.session_state:
+    st.session_state.progress = 0
+
 # FILES
 iteration_output_file = "llm_benchmark_iteration_results.csv" # File to store iteration results, defined as global
 results_file = "llm_benchmark_results.csv" # all data
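Note: Streamlit re-executes the whole script on every widget interaction, and st.session_state is the dictionary that survives those reruns, so each key is created behind an "not in st.session_state" guard to keep the initialization idempotent. A minimal standalone sketch of the pattern (illustrative only, not taken from this app):

import streamlit as st

# Guarded initialization: only runs on the very first execution of the script
if 'main_output' not in st.session_state:
    st.session_state.main_output = []

# Later reruns append to the same list instead of starting from scratch
st.session_state.main_output.append("one entry per rerun")
st.write(st.session_state.main_output)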
@@ -41,12 +49,6 @@ difficulty_probabilities = {
     "a very difficult": 0.6
 }
 
-# Create output displays for main log and debug log
-if 'main_output' not in st.session_state:
-    st.session_state.main_output = []
-if 'debug_output' not in st.session_state:
-    st.session_state.debug_output = []
-
 # Custom print function to capture output
 def custom_print(*args, **kwargs):
     # Convert args to string and join with spaces
@@ -57,11 +59,17 @@ def custom_print(*args, **kwargs):
 
     # Also print to standard output for console logging
     print(*args, **kwargs)
+
+    # Force an immediate update of the UI (when used inside a function)
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 # Custom function to capture warnings and errors
 def log_debug(message):
     st.session_state.debug_output.append(message)
     print(f"DEBUG: {message}", file=sys.stderr)
+
+    # Force an immediate update of the UI
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 def retry_api_request(max_retries=3, wait_time=10):
     """Decorator for retrying API requests with rate limit handling."""
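Note: only the signature and docstring of retry_api_request appear here as unchanged context; its body is not part of this diff. A decorator with this shape typically wraps the call in a bounded retry loop and sleeps between attempts. A rough, hypothetical sketch (parameter names taken from the signature above, everything else assumed; the real implementation may differ):

import time
import functools

def retry_api_request(max_retries=3, wait_time=10):
    """Decorator for retrying API requests with rate limit handling."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_retries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:  # the real app likely catches specific HTTP / rate-limit errors
                    if attempt == max_retries:
                        raise
                    time.sleep(wait_time)  # back off before the next attempt
        return wrapper
    return decorator

# Hypothetical usage:
@retry_api_request(max_retries=3, wait_time=2)
def fetch_answer(prompt):
    ...  # would call the model API here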
@@ -724,8 +732,11 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
 
 
     for model_id in active_models:
-        answer = answers
-
+        answer = answers.get(model_id)
+        if not answer: # Add guard clause
+            log_debug(f"No answer found for model {model_id}. Skipping ranking.")
+            continue
+
         if answer == "Error answering": # Handle answer generation errors
             consecutive_failures[model_id] += 1
             if consecutive_failures[model_id] >= failure_threshold:
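Note: answers is keyed by model id, so answers.get(model_id) returns None instead of raising KeyError when a model produced nothing, and the guard also skips empty-string answers because "if not answer" is a truthiness test. A small illustration with hypothetical data:

answers = {"model-a": "Paris", "model-b": ""}   # hypothetical answers keyed by model id

for model_id in ["model-a", "model-b", "model-c"]:
    answer = answers.get(model_id)   # None for the missing "model-c", no KeyError
    if not answer:                   # catches both None and ""
        print(f"No answer found for model {model_id}. Skipping ranking.")
        continue
    print(f"Ranking {model_id}: {answer}")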
@@ -794,7 +805,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
         results["question_prompt"].append(question_prompt)
         results["question"].append(question)
         results["answer"].append(answer)
-        results["answer_generation_duration"].append(
+        results["answer_generation_duration"].append(answer_durations.get(model_id, 0))
         results["average_rank"].append(average_rank)
         results["ranks"].append([ranks[m] for m in active_models if m in ranks]) # Store raw ranks including Nones, ensure order
         results["question_rank_average"].append(question_avg_rank) # Store question rank average
@@ -816,7 +827,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
     total_valid_rank = 0 # Keep track of the sum of valid (non-NaN) ranks
 
     for m_id in active_models:
-        if cumulative_avg_rank[m_id]:
+        if m_id in cumulative_avg_rank and not np.isnan(cumulative_avg_rank[m_id]):
             temp_weights[m_id] = cumulative_avg_rank[m_id]
             total_valid_rank += cumulative_avg_rank[m_id]
         else: # if cumulative is empty, keep original
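Note: the old "if cumulative_avg_rank[m_id]:" was a plain truthiness test, which raises KeyError for a model with no entry yet, would treat a rank of 0.0 as missing, and does not filter NaN at all (bool(float('nan')) is True). The membership check plus np.isnan covers all three cases. A small illustration with hypothetical values:

import numpy as np

cumulative_avg_rank = {"model-a": 2.5, "model-b": float("nan")}   # hypothetical ranks

for m_id in ["model-a", "model-b", "model-c"]:
    if m_id in cumulative_avg_rank and not np.isnan(cumulative_avg_rank[m_id]):
        print(m_id, "has a usable rank:", cumulative_avg_rank[m_id])
    else:
        print(m_id, "is missing or NaN, original weight kept")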
@@ -884,10 +895,6 @@ def check_model_availability(models, token):
 # Streamlit UI
 st.title("LLM Benchmark")
 
-# Initialize session state variables for progress tracking
-if 'progress' not in st.session_state:
-    st.session_state.progress = 0
-
 # Setup sidebar for configuration
 st.sidebar.header("Configuration")
 
@@ -970,6 +977,7 @@ with tab1:
         # Clear previous outputs
         st.session_state.main_output = []
         st.session_state.debug_output = []
+        st.session_state.progress = 0
 
         if not hf_token:
             st.error("Please enter your Hugging Face API token")
@@ -1038,21 +1046,28 @@ with tab1:
 with tab2:
     # Display main output log
     st.subheader("Execution Log")
-    log_container = st.container()
 
     # Display logs
     log_text = "\n".join(st.session_state.main_output)
-
+    st.text_area("Progress Log", log_text, height=400)
 
     # Add a refresh button for the log
-    if st.button("Refresh Log"):
-
+    if st.button("Refresh Progress Log"):
+        pass # The rerun happens automatically at the end
 
 with tab3:
     # Display debug output
     st.subheader("Debug Log")
-    debug_container = st.container()
 
     # Display debug logs
     debug_text = "\n".join(st.session_state.debug_output)
-
+    st.text_area("Debug Information", debug_text, height=400)
+
+    # Add a refresh button for the debug log
+    if st.button("Refresh Debug Log"):
+        pass # The rerun happens automatically at the end
+
+# Auto-refresh mechanism
+if st.session_state.get('update_counter', 0) > 0:
+    time.sleep(0.1) # Brief pause to allow UI to update
+    st.experimental_rerun()
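Note: the new auto-refresh block reruns the script whenever custom_print() or log_debug() has bumped update_counter. As far as this diff shows, the counter is never reset, so once it becomes positive every pass through the script would trigger another rerun; resetting it before rerunning avoids that. Also, st.experimental_rerun() has been superseded by st.rerun() in current Streamlit releases. A hypothetical variant combining both points (assumes the same update_counter convention as this commit, not the code actually shipped here):

import time
import streamlit as st

if st.session_state.get('update_counter', 0) > 0:
    st.session_state.update_counter = 0   # consume the pending update so reruns stop
    time.sleep(0.1)                       # brief pause to let the current widgets render
    st.rerun()                            # st.experimental_rerun() on older Streamlit versions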