Yuxuan-Zhang-Dexter commited on
Commit
455f800
·
1 Parent(s): 9568660

update gradio app with radar chart

Browse files
Files changed (1) hide show
  1. data_visualization.py +207 -5
data_visualization.py CHANGED
@@ -75,7 +75,7 @@ def create_horizontal_bar_chart(df, game_name):
75
  # Set style
76
  plt.style.use('default')
77
  # Increase figure width to accommodate long model names
78
- fig, ax = plt.subplots(figsize=(20, 11))
79
 
80
  # Sort by score
81
  if game_name == "Super Mario Bros":
@@ -114,7 +114,7 @@ def create_horizontal_bar_chart(df, game_name):
114
  bars = ax.barh(range(len(df_sorted)), df_sorted[score_col], color=colors)
115
 
116
  # Add more space for labels on the left
117
- plt.subplots_adjust(left=0.3)
118
 
119
  # Customize the chart
120
  ax.set_yticks(range(len(df_sorted)))
@@ -145,6 +145,11 @@ def create_horizontal_bar_chart(df, game_name):
145
  else:
146
  score_text = f'{width:.0f}'
147
 
 
 
 
 
 
148
  ax.text(width, bar.get_y() + bar.get_height()/2,
149
  score_text,
150
  ha='left', va='center',
@@ -317,7 +322,7 @@ def create_radar_charts(df):
317
  fontweight='bold') # Bold title
318
 
319
  legend = ax.legend(loc='upper right',
320
- bbox_to_anchor=(1.3, 1.1),
321
  fontsize=7, # Slightly larger legend
322
  framealpha=0.9, # More opaque legend
323
  edgecolor='#404040', # Darker edge
@@ -407,7 +412,7 @@ def create_group_bar_chart(df):
407
 
408
  # Create figure and axis with better styling
409
  sns.set_style("whitegrid")
410
- fig = plt.figure(figsize=(20, 11))
411
 
412
  # Create subplot with specific spacing
413
  ax = plt.subplot(111)
@@ -415,7 +420,7 @@ def create_group_bar_chart(df):
415
  # Adjust the subplot parameters
416
  plt.subplots_adjust(top=0.90, # Add more space at the top
417
  bottom=0.15, # Add more space at the bottom
418
- right=0.85, # Add more space for legend
419
  left=0.05) # Add space on the left
420
 
421
  # Get unique models
@@ -543,6 +548,203 @@ def get_combined_leaderboard_with_group_bar(rank_data, selected_games):
543
  group_bar_fig = create_group_bar_chart(df)
544
  return df, group_bar_fig
545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  def save_visualization(fig, filename):
547
  """
548
  Save visualization to file
 
75
  # Set style
76
  plt.style.use('default')
77
  # Increase figure width to accommodate long model names
78
+ fig, ax = plt.subplots(figsize=(20, 7))
79
 
80
  # Sort by score
81
  if game_name == "Super Mario Bros":
 
114
  bars = ax.barh(range(len(df_sorted)), df_sorted[score_col], color=colors)
115
 
116
  # Add more space for labels on the left
117
+ plt.subplots_adjust(left=0.3, top=0.85, bottom=0.3)
118
 
119
  # Customize the chart
120
  ax.set_yticks(range(len(df_sorted)))
 
145
  else:
146
  score_text = f'{width:.0f}'
147
 
148
+ # Get color for model from MODEL_COLORS, use default if not found
149
+ model_name = df_sorted.iloc[i]['Player']
150
+ color = MODEL_COLORS.get(model_name, '#808080') # Default to gray if color not found
151
+ bar.set_color(color) # Set the bar color
152
+
153
  ax.text(width, bar.get_y() + bar.get_height()/2,
154
  score_text,
155
  ha='left', va='center',
 
322
  fontweight='bold') # Bold title
323
 
324
  legend = ax.legend(loc='upper right',
325
+ bbox_to_anchor=(0.9, 1.1),
326
  fontsize=7, # Slightly larger legend
327
  framealpha=0.9, # More opaque legend
328
  edgecolor='#404040', # Darker edge
 
412
 
413
  # Create figure and axis with better styling
414
  sns.set_style("whitegrid")
415
+ fig = plt.figure(figsize=(10, 7))
416
 
417
  # Create subplot with specific spacing
418
  ax = plt.subplot(111)
 
420
  # Adjust the subplot parameters
421
  plt.subplots_adjust(top=0.90, # Add more space at the top
422
  bottom=0.15, # Add more space at the bottom
423
+ right=0.70, # Reduced from 0.75 to 0.70 to make more space for legend
424
  left=0.05) # Add space on the left
425
 
426
  # Get unique models
 
548
  group_bar_fig = create_group_bar_chart(df)
549
  return df, group_bar_fig
550
 
551
+ def create_single_radar_chart(df, selected_games=None, highlight_models=None):
552
+ """
553
+ Create a single radar chart comparing AI model performance across selected games
554
+
555
+ Args:
556
+ df (pd.DataFrame): DataFrame containing the combined leaderboard data
557
+ selected_games (list, optional): List of game names to include in the radar chart
558
+ highlight_models (list, optional): List of model names to highlight in the chart
559
+
560
+ Returns:
561
+ matplotlib.figure.Figure: The generated radar chart figure
562
+ """
563
+ # Close any existing figures to prevent memory leaks
564
+ plt.close('all')
565
+
566
+ # Use provided selected_games or default to the four main games
567
+ if selected_games is None:
568
+ selected_games = ['Super Mario Bros', '2048', 'Candy Crash', 'Sokoban']
569
+
570
+ game_columns = [f"{game} Score" for game in selected_games]
571
+ categories = selected_games
572
+
573
+ # Create figure
574
+ fig, ax = plt.subplots(figsize=(8, 7), subplot_kw=dict(projection='polar'))
575
+ fig.patch.set_facecolor('white')
576
+ ax.set_facecolor('white')
577
+
578
+ # Compute number of variables
579
+ num_vars = len(categories)
580
+ angles = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
581
+ angles = np.concatenate((angles, [angles[0]])) # Complete the circle
582
+
583
+ # Set up the axes
584
+ ax.set_xticks(angles[:-1])
585
+
586
+ # Format categories with bold text
587
+ formatted_categories = []
588
+ for game in categories:
589
+ if game == "Super Mario Bros":
590
+ game = "Super\nMario"
591
+ elif game == "Candy Crash":
592
+ game = "Candy\nCrash"
593
+ elif game == "Tetris (planning only)":
594
+ game = "Tetris\n(planning)"
595
+ elif game == "Tetris (complete)":
596
+ game = "Tetris\n(complete)"
597
+ formatted_categories.append(game)
598
+
599
+ # Set bold labels for categories
600
+ ax.set_xticklabels(formatted_categories, fontsize=10, fontweight='bold')
601
+
602
+ # Draw grid lines
603
+ ax.set_rgrids([20, 40, 60, 80, 100],
604
+ labels=['20', '40', '60', '80', '100'],
605
+ angle=45,
606
+ fontsize=8)
607
+
608
+ # Calculate game statistics for normalization
609
+ def get_game_stats(df, game_col):
610
+ values = []
611
+ for val in df[game_col]:
612
+ if isinstance(val, str) and val == '_':
613
+ values.append(0)
614
+ else:
615
+ try:
616
+ values.append(float(val))
617
+ except:
618
+ values.append(0)
619
+ return np.mean(values), np.std(values)
620
+
621
+ game_stats = {col: get_game_stats(df, col) for col in game_columns}
622
+
623
+ # Split the dataframe into highlighted and non-highlighted models
624
+ if highlight_models:
625
+ highlighted_df = df[df['Player'].isin(highlight_models)]
626
+ non_highlighted_df = df[~df['Player'].isin(highlight_models)]
627
+ else:
628
+ highlighted_df = pd.DataFrame()
629
+ non_highlighted_df = df
630
+
631
+ # Plot non-highlighted models first
632
+ for _, row in non_highlighted_df.iterrows():
633
+ values = []
634
+ for col in game_columns:
635
+ val = row[col]
636
+ if isinstance(val, str) and val == '_':
637
+ values.append(0)
638
+ else:
639
+ try:
640
+ mean, std = game_stats[col]
641
+ if std == 0:
642
+ normalized = 50 if float(val) > 0 else 0
643
+ else:
644
+ z_score = (float(val) - mean) / std
645
+ normalized = max(0, min(100, (z_score * 30) + 50))
646
+ values.append(normalized)
647
+ except:
648
+ values.append(0)
649
+
650
+ # Complete the circular plot
651
+ values = np.concatenate((values, [values[0]]))
652
+
653
+ # Get color for model, use default if not found
654
+ model_name = row['Player']
655
+ color = MODEL_COLORS.get(model_name, '#808080') # Default to gray if color not found
656
+
657
+ # Plot with lines and markers
658
+ ax.plot(angles, values, 'o-', linewidth=2, label=model_name, color=color)
659
+ ax.fill(angles, values, alpha=0.25, color=color)
660
+
661
+ # Plot highlighted models last (so they appear on top)
662
+ for _, row in highlighted_df.iterrows():
663
+ values = []
664
+ for col in game_columns:
665
+ val = row[col]
666
+ if isinstance(val, str) and val == '_':
667
+ values.append(0)
668
+ else:
669
+ try:
670
+ mean, std = game_stats[col]
671
+ if std == 0:
672
+ normalized = 50 if float(val) > 0 else 0
673
+ else:
674
+ z_score = (float(val) - mean) / std
675
+ normalized = max(0, min(100, (z_score * 30) + 30))
676
+ values.append(normalized)
677
+ except:
678
+ values.append(0)
679
+
680
+ # Complete the circular plot
681
+ values = np.concatenate((values, [values[0]]))
682
+
683
+ # Plot with red color and thicker line
684
+ model_name = row['Player']
685
+ ax.plot(angles, values, 'o-', linewidth=6, label=model_name, color='red')
686
+ ax.fill(angles, values, alpha=0.25, color='red')
687
+
688
+ # Add title
689
+ plt.title('AI Models Performance Across Selected Games\n(Normalized Scores)',
690
+ pad=20, fontsize=14, fontweight='bold')
691
+
692
+ # Get handles and labels for legend
693
+ handles, labels = ax.get_legend_handles_labels()
694
+
695
+ # Reorder legend to put highlighted models first
696
+ if highlight_models:
697
+ highlighted_handles = []
698
+ highlighted_labels = []
699
+ non_highlighted_handles = []
700
+ non_highlighted_labels = []
701
+
702
+ for handle, label in zip(handles, labels):
703
+ if label in highlight_models:
704
+ highlighted_handles.append(handle)
705
+ highlighted_labels.append(label)
706
+ else:
707
+ non_highlighted_handles.append(handle)
708
+ non_highlighted_labels.append(label)
709
+
710
+ handles = highlighted_handles + non_highlighted_handles
711
+ labels = highlighted_labels + non_highlighted_labels
712
+
713
+ # Add legend with reordered handles and labels
714
+ legend = plt.legend(handles, labels,
715
+ loc='center left',
716
+ bbox_to_anchor=(0.95, 1), # Moved from (1.2, 0.5) to (1.1, 0.5) to shift left
717
+ fontsize=8,
718
+ title='AI Models',
719
+ title_fontsize=10)
720
+
721
+ # Make the legend title bold
722
+ legend.get_title().set_fontweight('bold')
723
+
724
+ # Adjust layout to prevent label cutoff
725
+ plt.subplots_adjust(right=0.8) # Added subplot adjustment to give more space on the right
726
+ plt.tight_layout()
727
+
728
+ return fig
729
+
730
+ def get_combined_leaderboard_with_single_radar(rank_data, selected_games, highlight_models=None):
731
+ """
732
+ Get combined leaderboard and create single radar chart
733
+
734
+ Args:
735
+ rank_data (dict): Dictionary containing rank data
736
+ selected_games (dict): Dictionary of game names and their selection status
737
+ highlight_models (list, optional): List of model names to highlight in the chart
738
+
739
+ Returns:
740
+ tuple: (DataFrame, matplotlib.figure.Figure) containing the leaderboard data and radar chart
741
+ """
742
+ df = get_combined_leaderboard(rank_data, selected_games)
743
+ # Convert selected_games dict to list of selected game names
744
+ selected_game_names = [game for game, selected in selected_games.items() if selected]
745
+ radar_fig = create_single_radar_chart(df, selected_games=selected_game_names, highlight_models=highlight_models)
746
+ return df, radar_fig
747
+
748
  def save_visualization(fig, filename):
749
  """
750
  Save visualization to file