import gradio as gr
from pathlib import Path
abs_path = Path(__file__).parent
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sheet_manager.sheet_loader.sheet2df import sheet2df
from sheet_manager.sheet_convert.json2sheet import str2json
# Helpers for aggregating per-category metrics parsed from the sheet
def calculate_avg_metrics(df):
"""
๊ฐ ๋ชจ๋ธ์˜ ์นดํ…Œ๊ณ ๋ฆฌ๋ณ„ ํ‰๊ท  ์„ฑ๋Šฅ ์ง€ํ‘œ๋ฅผ ๊ณ„์‚ฐ
"""
metrics_data = []
for _, row in df.iterrows():
model_name = row['Model name']
        # Skip rows whose PIA cell is empty or not a string
if pd.isna(row['PIA']) or not isinstance(row['PIA'], str):
print(f"Skipping model {model_name}: Invalid PIA data")
continue
try:
metrics = str2json(row['PIA'])
            # Skip if the parsed metrics are missing or not a dict
if not metrics or not isinstance(metrics, dict):
print(f"Skipping model {model_name}: Invalid JSON format")
continue
            # Make sure every required category is present
required_categories = ['falldown', 'violence', 'fire']
if not all(cat in metrics for cat in required_categories):
print(f"Skipping model {model_name}: Missing required categories")
continue
            # Average each required metric across the categories that report it
required_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']
avg_metrics = {}
for metric in required_metrics:
try:
values = [metrics[cat][metric] for cat in required_categories
if metric in metrics[cat]]
                    if values:  # average only when at least one category reports this metric
avg_metrics[metric] = sum(values) / len(values)
else:
                        avg_metrics[metric] = 0  # default when the metric is missing everywhere
except (KeyError, TypeError) as e:
print(f"Error calculating {metric} for {model_name}: {str(e)}")
                    avg_metrics[metric] = 0  # fall back to a default on error
metrics_data.append({
'model_name': model_name,
**avg_metrics
})
except Exception as e:
print(f"Error processing model {model_name}: {str(e)}")
continue
return pd.DataFrame(metrics_data)
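# Sketch of the PIA payload calculate_avg_metrics expects (illustrative values
# only; the real schema comes from the "metric" sheet). Each category maps to a
# dict of metric scores plus confusion-matrix counts:
#   {"falldown": {"accuracy": 0.91, "precision": 0.88, "recall": 0.84, "f1": 0.86,
#                 "tp": 42, "tn": 880, "fp": 13, "fn": 7},
#    "violence": {...}, "fire": {...}}
# calculate_avg_metrics(sheet2df(sheet_name="metric")) then yields one row per
# model with each metric averaged over the three categories.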
def create_performance_chart(df, selected_metrics):
"""
๋ชจ๋ธ๋ณ„ ์„ ํƒ๋œ ์„ฑ๋Šฅ ์ง€ํ‘œ์˜ ์ˆ˜ํ‰ ๋ง‰๋Œ€ ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
"""
fig = go.Figure()
    # Size the left margin to fit the longest model name
max_name_length = max([len(name) for name in df['model_name']])
    left_margin = min(max_name_length * 7, 500)  # scale with name length, capped at 500
for metric in selected_metrics:
fig.add_trace(go.Bar(
name=metric,
            y=df['model_name'],  # model names on the y-axis
            x=df[metric],  # metric values on the x-axis
text=[f'{val:.3f}' for val in df[metric]],
textposition='auto',
            orientation='h'  # horizontal bars
))
fig.update_layout(
title='Model Performance Comparison',
yaxis_title='Model Name',
xaxis_title='Performance',
barmode='group',
        height=max(400, len(df) * 40),  # scale height with the number of models
        margin=dict(l=left_margin, r=50, t=50, b=50),  # dynamically sized left margin
showlegend=True,
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1
),
        yaxis={'categoryorder': 'total ascending'}  # order models by total score
)
    # Style the y-axis tick labels
    fig.update_yaxes(tickfont=dict(size=10))  # smaller font so long names fit
return fig
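# Minimal usage sketch (assumes `avg_metrics_df` came from calculate_avg_metrics):
#   fig = create_performance_chart(avg_metrics_df, ['accuracy', 'f1'])
#   fig.show()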
def create_confusion_matrix(metrics_data, selected_category):
"""ํ˜ผ๋™ ํ–‰๋ ฌ ์‹œ๊ฐํ™” ์ƒ์„ฑ"""
    # Confusion-matrix counts for the selected category
tp = metrics_data[selected_category]['tp']
tn = metrics_data[selected_category]['tn']
fp = metrics_data[selected_category]['fp']
fn = metrics_data[selected_category]['fn']
    # 2x2 matrix: rows are actual labels, columns are predicted labels
z = [[tn, fp], [fn, tp]]
x = ['Negative', 'Positive']
y = ['Negative', 'Positive']
    # Render the matrix as a heatmap
fig = go.Figure(data=go.Heatmap(
z=z,
x=x,
y=y,
colorscale=[[0, '#f7fbff'], [1, '#08306b']],
showscale=False,
text=[[str(val) for val in row] for row in z],
texttemplate="%{text}",
        textfont={"color": "black", "size": 16},  # keep annotation text black
))
    # Layout
fig.update_layout(
title={
'text': f'Confusion Matrix - {selected_category}',
'y':0.9,
'x':0.5,
'xanchor': 'center',
'yanchor': 'top'
},
xaxis_title='Predicted',
yaxis_title='Actual',
        width=600,  # wider canvas
        height=600,  # taller canvas
        margin=dict(l=80, r=80, t=100, b=80),  # padding around the heatmap
paper_bgcolor='white',
plot_bgcolor='white',
        font=dict(size=14)  # base font size
)
    # Axis styling
fig.update_xaxes(side="bottom", tickfont=dict(size=14))
fig.update_yaxes(side="left", tickfont=dict(size=14))
return fig
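# Minimal usage sketch (counts are made up for illustration; in the app they
# come from the parsed PIA JSON):
#   fig = create_confusion_matrix({"fire": {"tp": 42, "tn": 880, "fp": 13, "fn": 7}}, "fire")
#   fig.show()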
def get_metrics_for_model(df, model_name, benchmark_name):
"""ํŠน์ • ๋ชจ๋ธ๊ณผ ๋ฒค์น˜๋งˆํฌ์— ๋Œ€ํ•œ ๋ฉ”ํŠธ๋ฆญ์Šค ๋ฐ์ดํ„ฐ ์ถ”์ถœ"""
row = df[(df['Model name'] == model_name) & (df['Benchmark'] == benchmark_name)]
if not row.empty:
metrics = str2json(row['PIA'].iloc[0])
return metrics
return None
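# Example (hypothetical names): get_metrics_for_model(df, "model-A", "bench-1")
# returns the dict parsed from that row's PIA cell, or None if no row matches.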
def create_category_metrics_chart(metrics_data, selected_metrics):
"""
์„ ํƒ๋œ ๋ชจ๋ธ์˜ ๊ฐ ์นดํ…Œ๊ณ ๋ฆฌ๋ณ„ ์„ฑ๋Šฅ ์ง€ํ‘œ ์‹œ๊ฐํ™”
"""
fig = go.Figure()
categories = ['falldown', 'violence', 'fire']
for metric in selected_metrics:
values = []
for category in categories:
values.append(metrics_data[category][metric])
fig.add_trace(go.Bar(
name=metric,
x=categories,
y=values,
text=[f'{val:.3f}' for val in values],
textposition='auto',
))
fig.update_layout(
title='Performance Metrics by Category',
xaxis_title='Category',
yaxis_title='Score',
barmode='group',
height=500,
showlegend=True,
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1
)
)
return fig
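# Minimal usage sketch (`metrics` is the dict parsed from a model's PIA cell):
#   fig = create_category_metrics_chart(metrics, ['accuracy', 'recall'])
#   fig.show()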
def metric_visual_tab():
    # Load the sheet data and compute per-model averages for the comparison chart
df = sheet2df(sheet_name="metric")
avg_metrics_df = calculate_avg_metrics(df)
    # All metrics available for plotting
all_metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1',
'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far']
with gr.Tab("๐Ÿ“Š Performance Visualization"):
with gr.Row():
metrics_multiselect = gr.CheckboxGroup(
choices=all_metrics,
                value=[],  # no metrics selected initially
label="Select Performance Metrics",
interactive=True
)
performance_plot = gr.Plot()
def update_plot(selected_metrics):
if not selected_metrics:
return None
try:
sorted_df = avg_metrics_df.sort_values(by='accuracy', ascending=True)
return create_performance_chart(sorted_df, selected_metrics)
except Exception as e:
print(f"Error in update_plot: {str(e)}")
return None
metrics_multiselect.change(
fn=update_plot,
inputs=[metrics_multiselect],
outputs=[performance_plot]
)
        # Second visualization section: detailed per-model analysis
gr.Markdown("## Detailed Model Analysis")
with gr.Row():
            # Model selector
model_dropdown = gr.Dropdown(
choices=sorted(df['Model name'].unique().tolist()),
label="Select Model",
interactive=True
)
            # Column selector (every column except 'Model name')
column_dropdown = gr.Dropdown(
choices=[col for col in df.columns if col != 'Model name'],
label="Select Metric Column",
interactive=True
)
            # Category selector
category_dropdown = gr.Dropdown(
choices=['falldown', 'violence', 'fire'],
label="Select Category",
interactive=True
)
        # Confusion-matrix visualization
with gr.Row():
with gr.Column(scale=1):
                gr.Markdown("")  # spacer
with gr.Column(scale=2):
                confusion_matrix_plot = gr.Plot(container=True)  # container=True adds padding around the plot
with gr.Column(scale=1):
                gr.Markdown("")  # spacer
with gr.Column(scale=2):
                # Metric selector for the per-category chart
metrics_select = gr.CheckboxGroup(
choices=['accuracy', 'precision', 'recall', 'specificity', 'f1',
'balanced_accuracy', 'g_mean', 'mcc', 'npv', 'far'],
                    value=['accuracy'],  # default selection
label="Select Metrics to Display",
interactive=True
)
category_metrics_plot = gr.Plot()
def update_visualizations(model, column, category, selected_metrics):
            if not all([model, column]):  # category is only needed for the confusion matrix
return None, None
try:
                # Fetch the selected model's cell from the chosen column
selected_data = df[df['Model name'] == model][column].iloc[0]
metrics = str2json(selected_data)
if not metrics:
return None, None
                # Confusion matrix (left)
confusion_fig = create_confusion_matrix(metrics, category) if category else None
                # Per-category metrics chart (right)
if not selected_metrics:
selected_metrics = ['accuracy']
category_fig = create_category_metrics_chart(metrics, selected_metrics)
return confusion_fig, category_fig
except Exception as e:
print(f"Error updating visualizations: {str(e)}")
return None, None
        # Wire every control to the same update callback
for input_component in [model_dropdown, column_dropdown, category_dropdown, metrics_select]:
input_component.change(
fn=update_visualizations,
inputs=[model_dropdown, column_dropdown, category_dropdown, metrics_select],
outputs=[confusion_matrix_plot, category_metrics_plot]
)
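# Minimal standalone launch sketch (assumption: in the actual Space this tab is
# mounted by a separate app entry point; this block is illustrative only).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        metric_visual_tab()
    demo.launch()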