import gradio as gr
from pathlib import Path

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

from sheet_manager.sheet_loader.sheet2df import sheet2df
from sheet_manager.sheet_convert.json2sheet import str2json

abs_path = Path(__file__).parent
def calculate_avg_metrics(df):
    """
    Compute each model's average performance metrics across categories.
    """
    metrics_data = []

    for _, row in df.iterrows():
        model_name = row['Model name']

        # Skip rows whose PIA column is empty or not a string
        if pd.isna(row['PIA']) or not isinstance(row['PIA'], str):
            print(f"Skipping model {model_name}: Invalid PIA data")
            continue

        try:
            metrics = str2json(row['PIA'])

            # Skip if the parsed payload is None or not a dict
            if not metrics or not isinstance(metrics, dict):
                print(f"Skipping model {model_name}: Invalid JSON format")
                continue

            # Make sure every required category is present
            required_categories = ['falldown', 'violence', 'fire']
            if not all(cat in metrics for cat in required_categories):
                print(f"Skipping model {model_name}: Missing required categories")
                continue

            # Average each required metric over the categories
            required_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                                'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']
            avg_metrics = {}
            for metric in required_metrics:
                try:
                    values = [metrics[cat][metric] for cat in required_categories
                              if metric in metrics[cat]]
                    if values:  # only average when at least one value exists
                        avg_metrics[metric] = sum(values) / len(values)
                    else:
                        avg_metrics[metric] = 0  # fallback default
                except (KeyError, TypeError) as e:
                    print(f"Error calculating {metric} for {model_name}: {str(e)}")
                    avg_metrics[metric] = 0  # default on error

            metrics_data.append({
                'model_name': model_name,
                **avg_metrics
            })
        except Exception as e:
            print(f"Error processing model {model_name}: {str(e)}")
            continue

    return pd.DataFrame(metrics_data)
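
# Illustrative sketch (made-up numbers, not from the sheet): the kind of JSON string
# the 'PIA' column is expected to hold — one object per category, each carrying the
# metric values that calculate_avg_metrics averages; metrics missing from a category
# simply fall back to 0 in the loop above.
_EXAMPLE_PIA = """
{
  "falldown": {"accuracy": 0.93, "precision": 0.90, "recall": 0.88, "f1": 0.89},
  "violence": {"accuracy": 0.91, "precision": 0.87, "recall": 0.90, "f1": 0.88},
  "fire":     {"accuracy": 0.95, "precision": 0.94, "recall": 0.92, "f1": 0.93}
}
"""
# e.g. calculate_avg_metrics(pd.DataFrame([{"Model name": "demo-model", "PIA": _EXAMPLE_PIA}]))
# would yield a one-row frame of averaged metrics (assuming str2json parses plain JSON).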
def create_performance_chart(df, selected_metrics):
    """
    Create a horizontal bar chart of the selected performance metrics per model.
    """
    fig = go.Figure()

    # Scale the left margin with the longest model name, capped at 500 px
    max_name_length = max([len(name) for name in df['model_name']])
    left_margin = min(max_name_length * 7, 500)

    for metric in selected_metrics:
        fig.add_trace(go.Bar(
            name=metric,
            y=df['model_name'],  # y-axis: model names
            x=df[metric],        # x-axis: metric values
            text=[f'{val:.3f}' for val in df[metric]],
            textposition='auto',
            orientation='h'      # horizontal bars
        ))

    fig.update_layout(
        title='Model Performance Comparison',
        yaxis_title='Model Name',
        xaxis_title='Performance',
        barmode='group',
        height=max(400, len(df) * 40),                 # scale height with the number of models
        margin=dict(l=left_margin, r=50, t=50, b=50),  # dynamic left margin
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        yaxis={'categoryorder': 'total ascending'}     # order bars by total score
    )

    # Shrink the y-axis tick labels so long model names fit
    fig.update_yaxes(tickfont=dict(size=10))

    return fig
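
# Usage sketch (invented values): create_performance_chart expects the frame produced
# by calculate_avg_metrics — a 'model_name' column plus one column per metric, e.g.
#   _demo = pd.DataFrame({"model_name": ["model-a", "model-b"],
#                         "accuracy": [0.91, 0.87], "f1": [0.88, 0.84]})
#   create_performance_chart(_demo, ["accuracy", "f1"])  # -> grouped horizontal bar figure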
def create_confusion_matrix(metrics_data, selected_category):
    """Create a confusion-matrix heatmap for the selected category."""
    # Confusion-matrix counts for the selected category
    tp = metrics_data[selected_category]['tp']
    tn = metrics_data[selected_category]['tn']
    fp = metrics_data[selected_category]['fp']
    fn = metrics_data[selected_category]['fn']

    # 2x2 confusion matrix: rows = actual, columns = predicted
    z = [[tn, fp], [fn, tp]]
    x = ['Negative', 'Positive']
    y = ['Negative', 'Positive']

    # Build the heatmap
    fig = go.Figure(data=go.Heatmap(
        z=z,
        x=x,
        y=y,
        colorscale=[[0, '#f7fbff'], [1, '#08306b']],
        showscale=False,
        text=[[str(val) for val in row] for row in z],
        texttemplate="%{text}",
        textfont={"color": "black", "size": 16},  # keep the cell labels black
    ))

    # Layout
    fig.update_layout(
        title={
            'text': f'Confusion Matrix - {selected_category}',
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        xaxis_title='Predicted',
        yaxis_title='Actual',
        width=600,
        height=600,
        margin=dict(l=80, r=80, t=100, b=80),
        paper_bgcolor='white',
        plot_bgcolor='white',
        font=dict(size=14)
    )

    # Axis settings
    fig.update_xaxes(side="bottom", tickfont=dict(size=14))
    fig.update_yaxes(side="left", tickfont=dict(size=14))

    return fig
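
# Minimal sketch (invented counts) of the per-category structure that
# create_confusion_matrix and create_category_metrics_chart read from: each category
# maps to a dict holding tp/tn/fp/fn plus the named metric values.
_EXAMPLE_CATEGORY_METRICS = {
    "falldown": {"tp": 42, "tn": 55, "fp": 3, "fn": 5, "accuracy": 0.92},
}
# e.g. create_confusion_matrix(_EXAMPLE_CATEGORY_METRICS, "falldown") returns a
# 2x2 Plotly heatmap figure for that category.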
def get_metrics_for_model(df, model_name, benchmark_name):
    """Extract the metrics payload for a specific model and benchmark."""
    row = df[(df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)]
    if not row.empty:
        metrics = str2json(row['PIA'].iloc[0])
        return metrics
    return None
def create_category_metrics_chart(metrics_data, selected_metrics):
    """
    Visualize the selected model's performance metrics for each category.
    """
    fig = go.Figure()
    categories = ['falldown', 'violence', 'fire']

    for metric in selected_metrics:
        values = []
        for category in categories:
            values.append(metrics_data[category][metric])

        fig.add_trace(go.Bar(
            name=metric,
            x=categories,
            y=values,
            text=[f'{val:.3f}' for val in values],
            textposition='auto',
        ))

    fig.update_layout(
        title='Performance Metrics by Category',
        xaxis_title='Category',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig
def metric_visual_tab():
    # Load the sheet data and compute per-model averages
    df = sheet2df(sheet_name="metric")
    avg_metrics_df = calculate_avg_metrics(df)

    # All metrics that can be plotted
    all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
                   'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']

    with gr.Tab("📊 Performance Visualization"):
        with gr.Row():
            metrics_multiselect = gr.CheckboxGroup(
                choices=all_metrics,
                value=[],  # no initial selection
                label="Select Performance Metrics",
                interactive=True
            )

        # Performance comparison chart (empty until metrics are selected)
        performance_plot = gr.Plot()

        def update_plot(selected_metrics):
            if not selected_metrics:  # nothing selected yet
                return None
            try:
                # Sort models by accuracy before plotting
                sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
                return create_performance_chart(sorted_df, selected_metrics)
            except Exception as e:
                print(f"Error in update_plot: {str(e)}")
                return None

        metrics_multiselect.change(
            fn=update_plot,
            inputs=[metrics_multiselect],
            outputs=[performance_plot]
        )

        # Second visualization section: per-model analysis
        gr.Markdown("## Detailed Model Analysis")

        with gr.Row():
            # Model selection
            model_dropdown = gr.Dropdown(
                choices=sorted(df['Model name'].unique().tolist()),
                label="Select Model",
                interactive=True
            )
            # Column selection (everything except Model name)
            column_dropdown = gr.Dropdown(
                choices=[col for col in df.columns if col != 'Model name'],
                label="Select Metric Column",
                interactive=True
            )
            # Category selection
            category_dropdown = gr.Dropdown(
                choices=['falldown', 'violence', 'fire'],
                label="Select Category",
                interactive=True
            )

        # Confusion matrix (left) and per-category metrics (right)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)
            with gr.Column(scale=1):
                gr.Markdown("")  # spacer
            with gr.Column(scale=2):
                # Metric selection for the per-category chart
                metrics_select = gr.CheckboxGroup(
                    choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
                             'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                    value=['accuracy'],  # default
                    label="Select Metrics to Display",
                    interactive=True
                )
                category_metrics_plot = gr.Plot()

        def update_visualizations(model, column, category, selected_metrics):
            if not all([model, column]):  # the category is only needed for the confusion matrix
                return None, None
            try:
                # Fetch the selected model's raw JSON payload from the chosen column
                selected_data = df[df['Model name'] == model][column].iloc[0]
                metrics = str2json(selected_data)
                if not metrics:
                    return None, None

                # Confusion matrix (left)
                confusion_fig = create_confusion_matrix(metrics, category) if category else None

                # Per-category performance metrics (right)
                if not selected_metrics:
                    selected_metrics = ['accuracy']
                category_fig = create_category_metrics_chart(metrics, selected_metrics)

                return confusion_fig, category_fig
            except Exception as e:
                print(f"Error updating visualizations: {str(e)}")
                return None, None

        # Wire every input to the same update callback
        for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
            input_component.change(
                fn=update_visualizations,
                inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
                outputs=[confusion_matrix_plot, category_metrics_plot]
            )
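
# Minimal launch sketch (assumption: this module is meant to be mounted inside a
# Gradio Blocks app, and sheet2df can reach the "metric" sheet at build time).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        metric_visual_tab()  # builds the "Performance Visualization" tab and its callbacks
    demo.launch()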