PeterKruger committed · dfa358f · verified · 1 parent: 819adb7

Update app.py

Files changed (1):
  app.py (+35, -20)
app.py CHANGED
@@ -12,6 +12,14 @@ import sys
 import io
 from contextlib import redirect_stdout, redirect_stderr
 
+# Initialize session state variables
+if 'main_output' not in st.session_state:
+    st.session_state.main_output = []
+if 'debug_output' not in st.session_state:
+    st.session_state.debug_output = []
+if 'progress' not in st.session_state:
+    st.session_state.progress = 0
+
 # FILES
 iteration_output_file = "llm_benchmark_iteration_results.csv" # File to store iteration results, defined as global
 results_file = "llm_benchmark_results.csv" # all data
@@ -41,12 +49,6 @@ difficulty_probabilities = {
     "a very difficult": 0.6
 }
 
-# Create output displays for main log and debug log
-if 'main_output' not in st.session_state:
-    st.session_state.main_output = []
-if 'debug_output' not in st.session_state:
-    st.session_state.debug_output = []
-
 # Custom print function to capture output
 def custom_print(*args, **kwargs):
     # Convert args to string and join with spaces
@@ -57,11 +59,17 @@ def custom_print(*args, **kwargs):
 
     # Also print to standard output for console logging
     print(*args, **kwargs)
+
+    # Force an immediate update of the UI (when used inside a function)
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 # Custom function to capture warnings and errors
 def log_debug(message):
     st.session_state.debug_output.append(message)
     print(f"DEBUG: {message}", file=sys.stderr)
+
+    # Force an immediate update of the UI
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 def retry_api_request(max_retries=3, wait_time=10):
     """Decorator for retrying API requests with rate limit handling."""
@@ -724,8 +732,11 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
 
 
     for model_id in active_models:
-        answer = answers[model_id] # Retrieve pre-generated answer
-
+        answer = answers.get(model_id)
+        if not answer: # Add guard clause
+            log_debug(f"No answer found for model {model_id}. Skipping ranking.")
+            continue
+
         if answer == "Error answering": # Handle answer generation errors
             consecutive_failures[model_id] += 1
             if consecutive_failures[model_id] >= failure_threshold:
@@ -794,7 +805,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
         results["question_prompt"].append(question_prompt)
         results["question"].append(question)
         results["answer"].append(answer)
-        results["answer_generation_duration"].append(duration)
+        results["answer_generation_duration"].append(answer_durations.get(model_id, 0))
         results["average_rank"].append(average_rank)
         results["ranks"].append([ranks[m] for m in active_models if m in ranks]) # Store raw ranks including Nones, ensure order
         results["question_rank_average"].append(question_avg_rank) # Store question rank average
@@ -816,7 +827,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
     total_valid_rank = 0 # Keep track of the sum of valid (non-NaN) ranks
 
     for m_id in active_models:
-        if cumulative_avg_rank[m_id]:
+        if m_id in cumulative_avg_rank and not np.isnan(cumulative_avg_rank[m_id]):
            temp_weights[m_id] = cumulative_avg_rank[m_id]
            total_valid_rank += cumulative_avg_rank[m_id]
        else: # if cumulative is empty, keep original
@@ -884,10 +895,6 @@ def check_model_availability(models, token):
 # Streamlit UI
 st.title("LLM Benchmark")
 
-# Initialize session state variables for progress tracking
-if 'progress' not in st.session_state:
-    st.session_state.progress = 0
-
 # Setup sidebar for configuration
 st.sidebar.header("Configuration")
 
@@ -970,6 +977,7 @@ with tab1:
         # Clear previous outputs
         st.session_state.main_output = []
         st.session_state.debug_output = []
+        st.session_state.progress = 0
 
         if not hf_token:
             st.error("Please enter your Hugging Face API token")
@@ -1038,21 +1046,28 @@ with tab1:
 with tab2:
     # Display main output log
     st.subheader("Execution Log")
-    log_container = st.container()
 
     # Display logs
     log_text = "\n".join(st.session_state.main_output)
-    log_container.text_area("Progress Log", log_text, height=400)
+    st.text_area("Progress Log", log_text, height=400)
 
     # Add a refresh button for the log
-    if st.button("Refresh Log"):
-        st.experimental_rerun()
+    if st.button("Refresh Progress Log"):
+        pass # The rerun happens automatically at the end
 
 with tab3:
     # Display debug output
     st.subheader("Debug Log")
-    debug_container = st.container()
 
     # Display debug logs
     debug_text = "\n".join(st.session_state.debug_output)
-    debug_container.text_area("Debug Information", debug_text, height=400)
+    st.text_area("Debug Information", debug_text, height=400)
+
+    # Add a refresh button for the debug log
+    if st.button("Refresh Debug Log"):
+        pass # The rerun happens automatically at the end
+
+# Auto-refresh mechanism
+if st.session_state.get('update_counter', 0) > 0:
+    time.sleep(0.1) # Brief pause to allow UI to update
+    st.experimental_rerun()
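For reference, a minimal standalone sketch of the logging and auto-refresh pattern this commit adopts: session-state buffers initialized at module top, a print wrapper that appends to the buffer and bumps an update counter, a plain st.text_area for display, and a rerun at the end of the script. This is an illustrative reduction, not the full app.py; the counter reset and the st.rerun() call are assumptions added here to avoid rerun loops on newer Streamlit versions.

import time
import streamlit as st

# Initialize session state at module top so the log buffer exists on every rerun
if 'main_output' not in st.session_state:
    st.session_state.main_output = []
if 'update_counter' not in st.session_state:
    st.session_state.update_counter = 0

def custom_print(*args, **kwargs):
    # Capture the message in session state, mirror it to stdout, and bump the
    # counter so the auto-refresh block at the bottom of the script fires
    st.session_state.main_output.append(" ".join(str(a) for a in args))
    print(*args, **kwargs)
    st.session_state.update_counter += 1

if st.button("Log something"):
    custom_print(f"Button pressed at {time.strftime('%H:%M:%S')}")

st.text_area("Progress Log", "\n".join(st.session_state.main_output), height=200)

# Auto-refresh: if anything was logged during this run, rerun once so the
# text area reflects the latest buffer contents
if st.session_state.update_counter > 0:
    st.session_state.update_counter = 0  # reset (assumption) so the rerun does not loop
    time.sleep(0.1)
    st.rerun()  # older Streamlit versions use st.experimental_rerun(), as in the diff

The reset of update_counter before rerunning is a design choice of this sketch: without it, the script would trigger a fresh rerun on every pass and never settle.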