Spaces:
Running
Running
Yuxuan-Zhang-Dexter
commited on
Commit
·
455f800
1
Parent(s):
9568660
update gradio app with radar chart
Browse files- data_visualization.py +207 -5
data_visualization.py
CHANGED
@@ -75,7 +75,7 @@ def create_horizontal_bar_chart(df, game_name):
|
|
75 |
# Set style
|
76 |
plt.style.use('default')
|
77 |
# Increase figure width to accommodate long model names
|
78 |
-
fig, ax = plt.subplots(figsize=(20,
|
79 |
|
80 |
# Sort by score
|
81 |
if game_name == "Super Mario Bros":
|
@@ -114,7 +114,7 @@ def create_horizontal_bar_chart(df, game_name):
|
|
114 |
bars = ax.barh(range(len(df_sorted)), df_sorted[score_col], color=colors)
|
115 |
|
116 |
# Add more space for labels on the left
|
117 |
-
plt.subplots_adjust(left=0.3)
|
118 |
|
119 |
# Customize the chart
|
120 |
ax.set_yticks(range(len(df_sorted)))
|
@@ -145,6 +145,11 @@ def create_horizontal_bar_chart(df, game_name):
|
|
145 |
else:
|
146 |
score_text = f'{width:.0f}'
|
147 |
|
|
|
|
|
|
|
|
|
|
|
148 |
ax.text(width, bar.get_y() + bar.get_height()/2,
|
149 |
score_text,
|
150 |
ha='left', va='center',
|
@@ -317,7 +322,7 @@ def create_radar_charts(df):
|
|
317 |
fontweight='bold') # Bold title
|
318 |
|
319 |
legend = ax.legend(loc='upper right',
|
320 |
-
bbox_to_anchor=(
|
321 |
fontsize=7, # Slightly larger legend
|
322 |
framealpha=0.9, # More opaque legend
|
323 |
edgecolor='#404040', # Darker edge
|
@@ -407,7 +412,7 @@ def create_group_bar_chart(df):
|
|
407 |
|
408 |
# Create figure and axis with better styling
|
409 |
sns.set_style("whitegrid")
|
410 |
-
fig = plt.figure(figsize=(
|
411 |
|
412 |
# Create subplot with specific spacing
|
413 |
ax = plt.subplot(111)
|
@@ -415,7 +420,7 @@ def create_group_bar_chart(df):
|
|
415 |
# Adjust the subplot parameters
|
416 |
plt.subplots_adjust(top=0.90, # Add more space at the top
|
417 |
bottom=0.15, # Add more space at the bottom
|
418 |
-
right=0.
|
419 |
left=0.05) # Add space on the left
|
420 |
|
421 |
# Get unique models
|
@@ -543,6 +548,203 @@ def get_combined_leaderboard_with_group_bar(rank_data, selected_games):
|
|
543 |
group_bar_fig = create_group_bar_chart(df)
|
544 |
return df, group_bar_fig
|
545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
546 |
def save_visualization(fig, filename):
|
547 |
"""
|
548 |
Save visualization to file
|
|
|
75 |
# Set style
|
76 |
plt.style.use('default')
|
77 |
# Increase figure width to accommodate long model names
|
78 |
+
fig, ax = plt.subplots(figsize=(20, 7))
|
79 |
|
80 |
# Sort by score
|
81 |
if game_name == "Super Mario Bros":
|
|
|
114 |
bars = ax.barh(range(len(df_sorted)), df_sorted[score_col], color=colors)
|
115 |
|
116 |
# Add more space for labels on the left
|
117 |
+
plt.subplots_adjust(left=0.3, top=0.85, bottom=0.3)
|
118 |
|
119 |
# Customize the chart
|
120 |
ax.set_yticks(range(len(df_sorted)))
|
|
|
145 |
else:
|
146 |
score_text = f'{width:.0f}'
|
147 |
|
148 |
+
# Get color for model from MODEL_COLORS, use default if not found
|
149 |
+
model_name = df_sorted.iloc[i]['Player']
|
150 |
+
color = MODEL_COLORS.get(model_name, '#808080') # Default to gray if color not found
|
151 |
+
bar.set_color(color) # Set the bar color
|
152 |
+
|
153 |
ax.text(width, bar.get_y() + bar.get_height()/2,
|
154 |
score_text,
|
155 |
ha='left', va='center',
|
|
|
322 |
fontweight='bold') # Bold title
|
323 |
|
324 |
legend = ax.legend(loc='upper right',
|
325 |
+
bbox_to_anchor=(0.9, 1.1),
|
326 |
fontsize=7, # Slightly larger legend
|
327 |
framealpha=0.9, # More opaque legend
|
328 |
edgecolor='#404040', # Darker edge
|
|
|
412 |
|
413 |
# Create figure and axis with better styling
|
414 |
sns.set_style("whitegrid")
|
415 |
+
fig = plt.figure(figsize=(10, 7))
|
416 |
|
417 |
# Create subplot with specific spacing
|
418 |
ax = plt.subplot(111)
|
|
|
420 |
# Adjust the subplot parameters
|
421 |
plt.subplots_adjust(top=0.90, # Add more space at the top
|
422 |
bottom=0.15, # Add more space at the bottom
|
423 |
+
right=0.70, # Reduced from 0.75 to 0.70 to make more space for legend
|
424 |
left=0.05) # Add space on the left
|
425 |
|
426 |
# Get unique models
|
|
|
548 |
group_bar_fig = create_group_bar_chart(df)
|
549 |
return df, group_bar_fig
|
550 |
|
551 |
+
def create_single_radar_chart(df, selected_games=None, highlight_models=None):
|
552 |
+
"""
|
553 |
+
Create a single radar chart comparing AI model performance across selected games
|
554 |
+
|
555 |
+
Args:
|
556 |
+
df (pd.DataFrame): DataFrame containing the combined leaderboard data
|
557 |
+
selected_games (list, optional): List of game names to include in the radar chart
|
558 |
+
highlight_models (list, optional): List of model names to highlight in the chart
|
559 |
+
|
560 |
+
Returns:
|
561 |
+
matplotlib.figure.Figure: The generated radar chart figure
|
562 |
+
"""
|
563 |
+
# Close any existing figures to prevent memory leaks
|
564 |
+
plt.close('all')
|
565 |
+
|
566 |
+
# Use provided selected_games or default to the four main games
|
567 |
+
if selected_games is None:
|
568 |
+
selected_games = ['Super Mario Bros', '2048', 'Candy Crash', 'Sokoban']
|
569 |
+
|
570 |
+
game_columns = [f"{game} Score" for game in selected_games]
|
571 |
+
categories = selected_games
|
572 |
+
|
573 |
+
# Create figure
|
574 |
+
fig, ax = plt.subplots(figsize=(8, 7), subplot_kw=dict(projection='polar'))
|
575 |
+
fig.patch.set_facecolor('white')
|
576 |
+
ax.set_facecolor('white')
|
577 |
+
|
578 |
+
# Compute number of variables
|
579 |
+
num_vars = len(categories)
|
580 |
+
angles = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
|
581 |
+
angles = np.concatenate((angles, [angles[0]])) # Complete the circle
|
582 |
+
|
583 |
+
# Set up the axes
|
584 |
+
ax.set_xticks(angles[:-1])
|
585 |
+
|
586 |
+
# Format categories with bold text
|
587 |
+
formatted_categories = []
|
588 |
+
for game in categories:
|
589 |
+
if game == "Super Mario Bros":
|
590 |
+
game = "Super\nMario"
|
591 |
+
elif game == "Candy Crash":
|
592 |
+
game = "Candy\nCrash"
|
593 |
+
elif game == "Tetris (planning only)":
|
594 |
+
game = "Tetris\n(planning)"
|
595 |
+
elif game == "Tetris (complete)":
|
596 |
+
game = "Tetris\n(complete)"
|
597 |
+
formatted_categories.append(game)
|
598 |
+
|
599 |
+
# Set bold labels for categories
|
600 |
+
ax.set_xticklabels(formatted_categories, fontsize=10, fontweight='bold')
|
601 |
+
|
602 |
+
# Draw grid lines
|
603 |
+
ax.set_rgrids([20, 40, 60, 80, 100],
|
604 |
+
labels=['20', '40', '60', '80', '100'],
|
605 |
+
angle=45,
|
606 |
+
fontsize=8)
|
607 |
+
|
608 |
+
# Calculate game statistics for normalization
|
609 |
+
def get_game_stats(df, game_col):
|
610 |
+
values = []
|
611 |
+
for val in df[game_col]:
|
612 |
+
if isinstance(val, str) and val == '_':
|
613 |
+
values.append(0)
|
614 |
+
else:
|
615 |
+
try:
|
616 |
+
values.append(float(val))
|
617 |
+
except:
|
618 |
+
values.append(0)
|
619 |
+
return np.mean(values), np.std(values)
|
620 |
+
|
621 |
+
game_stats = {col: get_game_stats(df, col) for col in game_columns}
|
622 |
+
|
623 |
+
# Split the dataframe into highlighted and non-highlighted models
|
624 |
+
if highlight_models:
|
625 |
+
highlighted_df = df[df['Player'].isin(highlight_models)]
|
626 |
+
non_highlighted_df = df[~df['Player'].isin(highlight_models)]
|
627 |
+
else:
|
628 |
+
highlighted_df = pd.DataFrame()
|
629 |
+
non_highlighted_df = df
|
630 |
+
|
631 |
+
# Plot non-highlighted models first
|
632 |
+
for _, row in non_highlighted_df.iterrows():
|
633 |
+
values = []
|
634 |
+
for col in game_columns:
|
635 |
+
val = row[col]
|
636 |
+
if isinstance(val, str) and val == '_':
|
637 |
+
values.append(0)
|
638 |
+
else:
|
639 |
+
try:
|
640 |
+
mean, std = game_stats[col]
|
641 |
+
if std == 0:
|
642 |
+
normalized = 50 if float(val) > 0 else 0
|
643 |
+
else:
|
644 |
+
z_score = (float(val) - mean) / std
|
645 |
+
normalized = max(0, min(100, (z_score * 30) + 50))
|
646 |
+
values.append(normalized)
|
647 |
+
except:
|
648 |
+
values.append(0)
|
649 |
+
|
650 |
+
# Complete the circular plot
|
651 |
+
values = np.concatenate((values, [values[0]]))
|
652 |
+
|
653 |
+
# Get color for model, use default if not found
|
654 |
+
model_name = row['Player']
|
655 |
+
color = MODEL_COLORS.get(model_name, '#808080') # Default to gray if color not found
|
656 |
+
|
657 |
+
# Plot with lines and markers
|
658 |
+
ax.plot(angles, values, 'o-', linewidth=2, label=model_name, color=color)
|
659 |
+
ax.fill(angles, values, alpha=0.25, color=color)
|
660 |
+
|
661 |
+
# Plot highlighted models last (so they appear on top)
|
662 |
+
for _, row in highlighted_df.iterrows():
|
663 |
+
values = []
|
664 |
+
for col in game_columns:
|
665 |
+
val = row[col]
|
666 |
+
if isinstance(val, str) and val == '_':
|
667 |
+
values.append(0)
|
668 |
+
else:
|
669 |
+
try:
|
670 |
+
mean, std = game_stats[col]
|
671 |
+
if std == 0:
|
672 |
+
normalized = 50 if float(val) > 0 else 0
|
673 |
+
else:
|
674 |
+
z_score = (float(val) - mean) / std
|
675 |
+
normalized = max(0, min(100, (z_score * 30) + 30))
|
676 |
+
values.append(normalized)
|
677 |
+
except:
|
678 |
+
values.append(0)
|
679 |
+
|
680 |
+
# Complete the circular plot
|
681 |
+
values = np.concatenate((values, [values[0]]))
|
682 |
+
|
683 |
+
# Plot with red color and thicker line
|
684 |
+
model_name = row['Player']
|
685 |
+
ax.plot(angles, values, 'o-', linewidth=6, label=model_name, color='red')
|
686 |
+
ax.fill(angles, values, alpha=0.25, color='red')
|
687 |
+
|
688 |
+
# Add title
|
689 |
+
plt.title('AI Models Performance Across Selected Games\n(Normalized Scores)',
|
690 |
+
pad=20, fontsize=14, fontweight='bold')
|
691 |
+
|
692 |
+
# Get handles and labels for legend
|
693 |
+
handles, labels = ax.get_legend_handles_labels()
|
694 |
+
|
695 |
+
# Reorder legend to put highlighted models first
|
696 |
+
if highlight_models:
|
697 |
+
highlighted_handles = []
|
698 |
+
highlighted_labels = []
|
699 |
+
non_highlighted_handles = []
|
700 |
+
non_highlighted_labels = []
|
701 |
+
|
702 |
+
for handle, label in zip(handles, labels):
|
703 |
+
if label in highlight_models:
|
704 |
+
highlighted_handles.append(handle)
|
705 |
+
highlighted_labels.append(label)
|
706 |
+
else:
|
707 |
+
non_highlighted_handles.append(handle)
|
708 |
+
non_highlighted_labels.append(label)
|
709 |
+
|
710 |
+
handles = highlighted_handles + non_highlighted_handles
|
711 |
+
labels = highlighted_labels + non_highlighted_labels
|
712 |
+
|
713 |
+
# Add legend with reordered handles and labels
|
714 |
+
legend = plt.legend(handles, labels,
|
715 |
+
loc='center left',
|
716 |
+
bbox_to_anchor=(0.95, 1), # Moved from (1.2, 0.5) to (1.1, 0.5) to shift left
|
717 |
+
fontsize=8,
|
718 |
+
title='AI Models',
|
719 |
+
title_fontsize=10)
|
720 |
+
|
721 |
+
# Make the legend title bold
|
722 |
+
legend.get_title().set_fontweight('bold')
|
723 |
+
|
724 |
+
# Adjust layout to prevent label cutoff
|
725 |
+
plt.subplots_adjust(right=0.8) # Added subplot adjustment to give more space on the right
|
726 |
+
plt.tight_layout()
|
727 |
+
|
728 |
+
return fig
|
729 |
+
|
730 |
+
def get_combined_leaderboard_with_single_radar(rank_data, selected_games, highlight_models=None):
|
731 |
+
"""
|
732 |
+
Get combined leaderboard and create single radar chart
|
733 |
+
|
734 |
+
Args:
|
735 |
+
rank_data (dict): Dictionary containing rank data
|
736 |
+
selected_games (dict): Dictionary of game names and their selection status
|
737 |
+
highlight_models (list, optional): List of model names to highlight in the chart
|
738 |
+
|
739 |
+
Returns:
|
740 |
+
tuple: (DataFrame, matplotlib.figure.Figure) containing the leaderboard data and radar chart
|
741 |
+
"""
|
742 |
+
df = get_combined_leaderboard(rank_data, selected_games)
|
743 |
+
# Convert selected_games dict to list of selected game names
|
744 |
+
selected_game_names = [game for game, selected in selected_games.items() if selected]
|
745 |
+
radar_fig = create_single_radar_chart(df, selected_games=selected_game_names, highlight_models=highlight_models)
|
746 |
+
return df, radar_fig
|
747 |
+
|
748 |
def save_visualization(fig, filename):
|
749 |
"""
|
750 |
Save visualization to file
|