Spaces:
Running
Running
/**
 * LeaderboardTable.tsx
 *
 * This component displays a structured table with hierarchical data (groups, subgroups, metrics)
 * and provides two independent sorting mechanisms:
 *
 * 1. Row Sorting: Clicking on a column header (model name) sorts the rows
 *    - Implemented manually while building the table data (driven by rowSortState);
 *      TanStack Table's built-in sorting is not used
 *    - Controls which rows appear first in the table
 *    - Groups sort against other groups based on their values
 *    - Subgroups stay with their parent group but sort within the group
 *
 * 2. Column Sorting: Clicking on a row header sorts the columns (models)
 *    - Custom implementation using modelOrderByOverallMetric
 *    - Controls the order of models (columns) for each metric
 *    - Completely independent of row sorting
 *
 * Both sorting mechanisms operate independently and can be used simultaneously.
 */
import React, { useEffect, useState, useMemo, useCallback } from 'react' | |
import { ArrowDownTrayIcon } from '@heroicons/react/24/solid' | |
import QualityMetricsTable from './QualityMetricsTable' | |
import MetricInfoIcon from './MetricInfoIcon' | |
import { | |
createColumnHelper, | |
flexRender, | |
getCoreRowModel, | |
useReactTable, | |
ColumnDef, | |
} from '@tanstack/react-table' | |
/** Props for the leaderboard table component. */
interface LeaderboardTableProps {
  /** Raw benchmark payload; the parser reads its `rows` and `groups` entries. */
  benchmarkData: any
  /** Model names (table header keys) the user has chosen to display. */
  selectedModels: Set<string>
}

/** One raw benchmark row: a metric name plus one value per model column. */
interface OriginalRow {
  metric: string
  [key: string]: string | number
}

/**
 * A rendered row of the hierarchical table. Besides the fixed fields below, each row carries
 * dynamic "<overallMetric>-<model>" keys holding its aggregated cell values.
 */
interface TableRow {
  id: string
  type: 'group' | 'subgroup' | 'metric'
  groupId?: string
  subgroupId?: string
  metricName?: string
  name: string
  visible: boolean
  depth: number
  isExpanded?: boolean
  [key: string]: any
}

/** Group name → subgroup name → full metric names belonging to that subgroup. */
type Groups = Record<string, Record<string, string[]>>

/** Direction shared by both sorting mechanisms. */
type SortDirection = 'asc' | 'desc'

/** Row sorting, triggered from a column header; `columnId` is "<overallMetric>-<model>". */
interface RowSortState {
  columnId: string
  direction: SortDirection
}

/** Column (model) sorting, triggered from a row; `rowKey` is "<group>||<subgroup>||<metric>". */
interface ColumnSortState {
  rowKey: string
  direction: SortDirection
}
/** Name of the aggregate group, which the parser orders ahead of all other groups. */
const OVERALL_ROW = 'Overall'
/** Overall metrics ticked by default whenever new benchmark data is loaded. */
const DEFAULT_SELECTED_METRICS: Set<string> = new Set<string>(['log10_p_value'])
/**
 * Checkbox grid for choosing which overall metrics get a column group.
 * Every toggle hands a fresh Set to `setSelectedOverallMetrics`; the incoming Set is never mutated.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set<string>
  setSelectedOverallMetrics: (metrics: Set<string>) => void
}> = ({ overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics }) => {
  const handleToggle = (metric: string): void => {
    const next = new Set(selectedOverallMetrics)
    // Set.delete reports whether the metric was present; only add it when it was not.
    if (!next.delete(metric)) {
      next.add(metric)
    }
    setSelectedOverallMetrics(next)
  }

  const selectedCount = selectedOverallMetrics.size
  const totalCount = overallMetrics.length

  return (
    <div className="w-full">
      <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
        <legend className="fieldset-legend font-semibold">
          Metrics ({selectedCount}/{totalCount})
        </legend>
        <div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-1 max-h-48 overflow-y-auto pr-2">
          {overallMetrics.map((name) => (
            <label key={name} className="flex items-center gap-2 text-sm">
              <input
                type="checkbox"
                className="form-checkbox h-4 w-4"
                checked={selectedOverallMetrics.has(name)}
                onChange={() => handleToggle(name)}
              />
              <div className="flex items-center truncate">
                <span className="truncate" title={name}>
                  {name}
                </span>
                <MetricInfoIcon metricName={name} />
              </div>
            </label>
          ))}
        </div>
      </fieldset>
    </div>
  )
}
/**
 * Hierarchical leaderboard. Each group row (expandable into subgroup rows) shows, for every
 * selected overall metric and every selected model, "mean ± std dev" of that model's values
 * over the row's matching metrics.
 */
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ benchmarkData, selectedModels }) => {
// Parsed benchmark data and parse errors.
const [rawRows, setRawRows] = useState<OriginalRow[]>([])
const [tableHeader, setTableHeader] = useState<string[]>([])
const [error, setError] = useState<string | null>(null)
const [groupRows, setGroupRows] = useState<Groups>({})
// Which groups are expanded, and which metrics are selected.
const [openGroupRows, setOpenGroupRows] = useState<Record<string, boolean>>({})
const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
const [overallMetrics, setOverallMetrics] = useState<string[]>([])
const [selectedOverallMetrics, setSelectedOverallMetrics] =
  useState<Set<string>>(DEFAULT_SELECTED_METRICS)
// Row sorting (by a metric/model column) and column sorting (by a row).
const [rowSortState, setRowSortState] = useState<RowSortState | null>(null)
const [columnSortState, setColumnSortState] = useState<ColumnSortState | null>(null)
// Per overall metric: the order in which model columns are displayed.
const [modelOrderByOverallMetric, setModelOrderByOverallMetric] = useState<
  Record<string, string[]>
>({})

// Header models the caller selected, kept in header order.
const models = useMemo(
  () => tableHeader.filter((model) => selectedModels.has(model)),
  [tableHeader, selectedModels]
)
// Whenever a new payload arrives, derive the overall metric names, the model headers and the
// group → subgroup → metrics tree, and reset selections, expansion, sorting and model order.
useEffect(() => {
  if (!benchmarkData) {
    return
  }
  try {
    const rows: OriginalRow[] = benchmarkData['rows']
    const allGroups = benchmarkData['groups'] as { [key: string]: string[] }
    const { Overall: overallGroup } = allGroups

    // Overall metric names look like "<prefix>_<metric>"; keep everything after the first "_".
    const uniqueMetrics = new Set<string>()
    overallGroup?.forEach((fullName) => {
      const cut = fullName.indexOf('_')
      if (cut >= 0) {
        uniqueMetrics.add(fullName.slice(cut + 1))
      }
    })
    setOverallMetrics(Array.from(uniqueMetrics).sort())
    setSelectedOverallMetrics(new Set(DEFAULT_SELECTED_METRICS))

    // Groups: "Overall" first, then alphabetical. Within a group, metrics are bucketed by the
    // token before their first "_" (the subgroup), and buckets are ordered alphabetically.
    const groupNames = Object.keys(allGroups).sort((a, b) => {
      if (a === OVERALL_ROW) return -1
      if (b === OVERALL_ROW) return 1
      return a.localeCompare(b)
    })
    const groupsData: { [key: string]: { [key: string]: string[] } } = {}
    for (const groupName of groupNames) {
      const buckets: { [key: string]: string[] } = {}
      for (const metricName of [...allGroups[groupName]].sort()) {
        const bucket = metricName.split('_')[0]
        if (!buckets[bucket]) {
          buckets[bucket] = []
        }
        buckets[bucket].push(metricName)
      }
      groupsData[groupName] = Object.fromEntries(
        Object.entries(buckets).sort(([a], [b]) => a.localeCompare(b))
      )
    }

    // Every row key except "metric" is a model column.
    const headers = Array.from(new Set(rows.flatMap((row) => Object.keys(row)))).filter(
      (key) => key !== 'metric'
    )
    const collapsedGroups: { [key: string]: boolean } = Object.fromEntries(
      Object.keys(groupsData).map((groupName) => [groupName, false])
    )
    const everyMetric = Object.values(allGroups).flat()

    setSelectedMetrics(new Set(everyMetric))
    setTableHeader(headers)
    setRawRows(rows)
    setGroupRows(groupsData)
    setOpenGroupRows(collapsedGroups)
    // Start with the models ordered by the Overall group row, ascending.
    setColumnSortState({
      rowKey: getColumnSortRowKey(OVERALL_ROW, null, null),
      direction: 'asc',
    })
    // Default model order for every overall metric: the full header order.
    const metricOrders: { [key: string]: string[] } = {}
    uniqueMetrics.forEach((metricName) => {
      metricOrders[metricName] = [...headers]
    })
    setModelOrderByOverallMetric(metricOrders)
    setError(null)
  } catch (err: any) {
    setError('Failed to parse benchmark data, please try again: ' + err.message)
  }
}, [benchmarkData])
// Cycle row sorting for one (metric, model) column: ascending → descending → off.
// Clicking a column other than the active one always starts again at ascending.
const handleRowSort = (overallMetric: string, model: string) => {
  const columnId = `${overallMetric}-${model}`
  const isActiveColumn = rowSortState?.columnId === columnId
  if (!isActiveColumn) {
    setRowSortState({ columnId, direction: 'asc' })
  } else if (rowSortState?.direction === 'asc') {
    setRowSortState({ columnId, direction: 'desc' })
  } else {
    setRowSortState(null)
  }
}
// Composite key naming the row that drives column (model) sorting: "<group>||<subgroup>||<metric>".
// Missing parts become empty strings, so the group row "A" maps to "A||||".
function getColumnSortRowKey(
  group: string | null,
  subGroup: string | null,
  metric: string | null
): string {
  return [group, subGroup, metric].map((part) => part ?? '').join('||')
}
// Cycle the column (model-order) sort driven by one row: ascending → descending → off.
// The model order itself is recomputed by the effect that watches columnSortState.
const handleColumnSort = (
  group: string | null,
  subGroup: string | null,
  metric: string | null
) => {
  const rowKey = getColumnSortRowKey(group, subGroup, metric)
  let newDirection: 'asc' | 'desc' | null
  if (!columnSortState || columnSortState.rowKey !== rowKey) {
    // A different row (or no sort yet): start ascending.
    newDirection = 'asc'
  } else if (columnSortState.direction === 'asc') {
    newDirection = 'desc'
  } else {
    // Descending → cleared.
    newDirection = null
  }
  setColumnSortState(newDirection ? { rowKey, direction: newDirection } : null)
  // Clearing a metric-row sort restores that metric's default order: the selected models in
  // header order (exactly what `models` holds).
  if (!newDirection && metric) {
    setModelOrderByOverallMetric((prev) => ({ ...prev, [metric]: [...models] }))
  }
}
// Full names of raw metrics whose suffix (the part after the first "_") ends with `metricName`,
// e.g. "jailbreak_log10_p_value" for "log10_p_value".
// NOTE(review): this is a suffix match, so "p_value" would also match "x_log10_p_value" —
// confirm that overall metric names are never suffixes of one another.
const findAllMetricsForName = useCallback(
  (metricName: string): string[] => {
    const matches: string[] = []
    for (const row of rawRows) {
      const fullName = row.metric as string
      const cut = fullName.indexOf('_')
      if (cut !== -1 && fullName.slice(cut + 1).endsWith(metricName)) {
        matches.push(fullName)
      }
    }
    return matches
  },
  [rawRows]
)
// Raw metrics that belong to no overall metric, i.e. neither equal to one nor ending in
// "_<overall>". These are rendered separately in the quality metrics table.
const findQualityMetrics = useCallback(
  (): string[] =>
    rawRows
      .map((row) => row.metric as string)
      .filter(
        (metric) =>
          !overallMetrics.some(
            (overall) => metric.endsWith(`_${overall}`) || metric === overall
          )
      ),
  [rawRows, overallMetrics]
)
// Index raw rows by metric name once per data load, so stats lookups are O(1) instead of a
// linear scan per metric. The first row wins on duplicate names, matching Array.prototype.find.
const rawRowByMetric = useMemo(() => {
  const index = new Map<string, OriginalRow>()
  for (const row of rawRows) {
    if (!index.has(row.metric)) {
      index.set(row.metric, row)
    }
  }
  return index
}, [rawRows])

/**
 * Mean and population standard deviation of column `columnKey` across the rows named in
 * `metricNames`.
 * Values that convert to NaN (e.g. a missing row or cell) are skipped. Returns NaN for both
 * fields when no numeric value remains.
 */
const calculateStats = useCallback(
  (metricNames: string[], columnKey: string): { avg: number; stdDev: number } => {
    const values: number[] = []
    for (const metricName of metricNames) {
      const row = rawRowByMetric.get(metricName)
      const value = row ? Number(row[columnKey]) : NaN
      if (!isNaN(value)) {
        values.push(value)
      }
    }
    if (values.length === 0) return { avg: NaN, stdDev: NaN }
    const avg = values.reduce((sum, val) => sum + val, 0) / values.length
    const variance =
      values.reduce((sum, val) => {
        const diff = val - avg
        return sum + diff * diff
      }, 0) / values.length
    return { avg, stdDev: Math.sqrt(variance) }
  },
  [rawRowByMetric]
)
// Keep the metrics that are selected and belong to `group`. When `subgroup` names an existing
// subgroup of that group, they must belong to it instead. With no group, input is returned as-is.
const filterMetricsByGroupAndSubgroup = useCallback(
  (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    if (!group) return metricNames
    const groupEntry = groupRows[group]
    const subgroupMetrics = subgroup ? groupEntry?.[subgroup] : undefined
    const allowed: string[] = subgroupMetrics
      ? subgroupMetrics
      : (Object.values(groupEntry || {}).flat() as string[])
    return metricNames.filter((name) => allowed.includes(name) && selectedMetrics.has(name))
  },
  [groupRows, selectedMetrics]
)
// Overall metrics that currently get columns: those ticked in the filter, in overallMetrics order.
const visibleMetrics = overallMetrics.filter((candidate) => selectedOverallMetrics.has(candidate))
/**
 * Rows for the main table.
 *
 * There is one row per group that has at least one selected metric. When a group is expanded,
 * it is followed by one row per subgroup that has at least one selected metric. Under the key
 * "<overallMetric>-<model>", every row stores the mean and std dev of that model's values over
 * the row's metrics matching the overall metric, or null when none of them is numeric.
 *
 * Row sorting (rowSortState) orders groups among themselves and subgroups within their parent,
 * by the value stored under the sorted column id. Rows without a value count as -Infinity.
 * The column id is used directly as the lookup key instead of being split on '-', so model
 * names such as "gpt-4o" sort correctly. If the sorted column is not currently displayed
 * (its metric or model was deselected), rows keep their default order.
 */
const tableData = useMemo(() => {
  // Resolve each displayed overall metric to its raw metric names once, not per model and row.
  const metricLookups = overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .map((metric) => ({ metric, namesForMetric: findAllMetricsForName(metric) }))

  // Store one aggregated cell per (overall metric, model) pair on `row`.
  const fillCells = (row: TableRow, rowMetrics: string[]) => {
    metricLookups.forEach(({ metric, namesForMetric }) => {
      const matching = rowMetrics.filter((m) => namesForMetric.includes(m))
      models.forEach((model) => {
        const stats = calculateStats(matching, model)
        row[`${metric}-${model}`] = !isNaN(stats.avg)
          ? { avg: stats.avg, stdDev: stats.stdDev }
          : null
      })
    })
  }

  const sortColumnId = rowSortState?.columnId
  const sortSign = rowSortState?.direction === 'desc' ? -1 : 1
  // The value a row is sorted by: the displayed average in the sorted column, or -Infinity.
  const sortValue = (row: TableRow): number => {
    const cell = sortColumnId
      ? (row[sortColumnId] as { avg: number; stdDev: number } | null | undefined)
      : null
    return cell ? cell.avg : -Infinity
  }
  // Explicit comparator, so two missing values (-Infinity) compare as equal rather than NaN.
  const compareRows = (a: TableRow, b: TableRow): number => {
    const va = sortValue(a)
    const vb = sortValue(b)
    if (va === vb) return 0
    return (va < vb ? -1 : 1) * sortSign
  }

  // Build the visible group rows first, so they can be sorted by their displayed values.
  const groupBlocks: {
    group: string
    row: TableRow
    subGroups: { [subgroup: string]: string[] }
  }[] = []
  Object.entries(groupRows).forEach(([group, subGroups]) => {
    const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(
      Object.values(subGroups).flat(),
      group
    )
    if (visibleGroupMetrics.length === 0) return
    const row: TableRow = {
      id: `group-${group}`,
      type: 'group',
      name: group,
      visible: true,
      depth: 0,
      isExpanded: openGroupRows[group],
    }
    fillCells(row, visibleGroupMetrics)
    groupBlocks.push({ group, row, subGroups })
  })
  if (rowSortState) {
    // Array.prototype.sort is stable, so ties keep the default (alphabetical) group order.
    groupBlocks.sort((a, b) => compareRows(a.row, b.row))
  }

  const rows: TableRow[] = []
  groupBlocks.forEach(({ group, row, subGroups }) => {
    rows.push(row)
    if (!openGroupRows[group]) return
    // Subgroups always stay directly under their parent; they are only sorted among themselves.
    const subgroupRows: TableRow[] = []
    Object.keys(subGroups)
      .sort((a, b) => a.localeCompare(b))
      .forEach((subGroup) => {
        const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(
          subGroups[subGroup],
          group,
          subGroup
        )
        if (visibleSubgroupMetrics.length === 0) return
        const subgroupRow: TableRow = {
          id: `group-${group}-subgroup-${subGroup}`,
          type: 'subgroup',
          groupId: group,
          name: subGroup,
          visible: true,
          depth: 1,
          isExpanded: false,
        }
        fillCells(subgroupRow, visibleSubgroupMetrics)
        subgroupRows.push(subgroupRow)
      })
    if (rowSortState) {
      subgroupRows.sort(compareRows)
    }
    rows.push(...subgroupRows)
  })
  return rows
}, [
  groupRows,
  openGroupRows,
  overallMetrics,
  selectedOverallMetrics,
  models,
  rowSortState,
  findAllMetricsForName,
  calculateStats,
  filterMetricsByGroupAndSubgroup,
])
/**
 * Keep `modelOrderByOverallMetric` in sync with the column sort, for every selected overall metric.
 *
 * - No column sort: the selected models, in their default header order.
 * - Group/subgroup row sort: models ordered by that row's aggregated cell for the metric
 *   (a missing value counts as -Infinity).
 * - Metric-key sort: models ordered by their mean over every raw metric matching that name.
 * - Sorted row not rendered right now (collapsed, filtered out, or absent from the data): keep
 *   the current order, drop deselected models, and append newly selected ones.
 *
 * All updates go through `applyOrders`, which returns the previous state object when nothing
 * changed. This effect re-runs whenever tableData changes, so an unconditional setState here
 * (as in the old "row not found" branch) would re-render forever.
 */
useEffect(() => {
  const metricsToUpdate = Array.from(selectedOverallMetrics)

  // Apply computeOrder to each selected metric; only touch state if some order changed.
  const applyOrders = (computeOrder: (metricName: string, current: string[]) => string[]) => {
    setModelOrderByOverallMetric((prev) => {
      let changed = false
      const next = { ...prev }
      for (const metricName of metricsToUpdate) {
        const current = prev[metricName] || []
        const order = computeOrder(metricName, current)
        const same =
          current.length === order.length && current.every((model, i) => model === order[i])
        if (!same) {
          next[metricName] = order
          changed = true
        }
      }
      return changed ? next : prev
    })
  }

  if (!columnSortState) {
    applyOrders(() => [...models])
    return
  }

  const [group, subGroup, metric] = columnSortState.rowKey.split('||').map((part) => part || null)
  const direction = columnSortState.direction

  // Locate the rendered group/subgroup row that drives the sort (metric keys have no row).
  const rowToSort: TableRow | undefined =
    group && !metric
      ? tableData.find((row) =>
          subGroup
            ? row.type === 'subgroup' && row.groupId === group && row.name === subGroup
            : row.type === 'group' && row.name === group
        )
      : undefined

  if (!rowToSort && !metric) {
    applyOrders((_metricName, current) => {
      const kept = current.filter((model) => models.includes(model))
      return [...kept, ...models.filter((model) => !kept.includes(model))]
    })
    return
  }

  // Raw metrics used when sorting by a metric key; computed once rather than per model.
  const namesForMetricKey = metric ? findAllMetricsForName(metric) : []
  const scoreFor = (metricName: string, model: string): number => {
    if (rowToSort) {
      const cell = rowToSort[`${metricName}-${model}`] as
        | { avg: number; stdDev: number }
        | null
        | undefined
      return cell && !isNaN(cell.avg) ? cell.avg : -Infinity
    }
    const { avg } = calculateStats(namesForMetricKey, model)
    return isNaN(avg) ? -Infinity : avg
  }
  // Explicit comparator: ties (including two -Infinity scores) compare as equal instead of NaN.
  const compare = (a: number, b: number) => (a === b ? 0 : a < b ? -1 : 1)

  applyOrders((metricName) =>
    models
      .map((model) => ({ model, score: scoreFor(metricName, model) }))
      .sort((a, b) => (direction === 'asc' ? compare(a.score, b.score) : compare(b.score, a.score)))
      .map(({ model }) => model)
  )
}, [
  columnSortState,
  models,
  tableData,
  selectedOverallMetrics,
  findAllMetricsForName,
  calculateStats,
])
// Download the main table as "leaderboard_metrics.csv". There is one column per displayed
// (metric, model) pair in the current model order, and each cell is "avg ± stdDev" to 3
// decimals or "N/A". All cells are double-quoted, with embedded quotes doubled.
const exportToCsv = () => {
  const shownMetrics = overallMetrics.filter((metric) => selectedOverallMetrics.has(metric))
  const columnPairs = shownMetrics.flatMap((metric) =>
    (modelOrderByOverallMetric[metric] || models).map((model: string) => ({ metric, model }))
  )
  const formatCell = (value: { avg: number; stdDev: number } | null) =>
    value ? `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}` : 'N/A'

  const headerRow: (string | number)[] = [
    'Attack Categories',
    ...columnPairs.map(({ metric, model }) => `${metric} - ${model}`),
  ]
  const bodyRows: (string | number)[][] = tableData.map((row) => [
    row.name,
    ...columnPairs.map(({ metric, model }) =>
      formatCell(row[`${metric}-${model}`] as { avg: number; stdDev: number } | null)
    ),
  ])
  const quote = (cell: string | number) => `"${String(cell).replace(/"/g, '""')}"`
  const csv = [headerRow, ...bodyRows].map((cells) => cells.map(quote).join(',')).join('\n')

  const url = URL.createObjectURL(new Blob([csv], { type: 'text/csv' }))
  const link = document.createElement('a')
  link.href = url
  link.download = 'leaderboard_metrics.csv'
  document.body.appendChild(link)
  link.click()
  document.body.removeChild(link)
  URL.revokeObjectURL(url)
}
// Expand or collapse the subgroup rows beneath a group.
const toggleGroup = (group: string) => {
  setOpenGroupRows((prev) => {
    const next = { ...prev }
    next[group] = !prev[group]
    return next
  })
}
// The active column sort if the given row is the one driving it, otherwise null.
function getColumnSort(group: string | null, subGroup: string | null, metric: string | null) {
  if (!columnSortState) return null
  const isThisRow = columnSortState.rowKey === getColumnSortRowKey(group, subGroup, metric)
  return isThisRow ? columnSortState : null
}
// TanStack column definitions. The first is the "category" column: row labels plus each row's
// model-sort toggle. It is followed by one column per (selected overall metric, model), in the
// current model order; clicking a data column header toggles row sorting.
const columns = useMemo<any[]>(() => {
  const columnHelper = createColumnHelper<TableRow>()

  // Tooltips for the model-sort toggle, keyed by the kind of row it sits on.
  const modelSortTitles = {
    group: {
      asc: 'Currently sorting models by this row in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this row in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this row (independent of row sorting)',
    },
    subgroup: {
      asc: 'Currently sorting models by this subgroup in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this subgroup in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this subgroup (independent of row sorting)',
    },
    metric: {
      asc: 'Currently sorting models by this metric in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this metric in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this metric (independent of row sorting)',
    },
  }
  // Tooltips for the row-sort indicator in each data column header.
  const rowSortTitles = {
    asc: 'Currently sorting rows by this column in ascending order (low to high). Click for descending order.',
    desc: 'Currently sorting rows by this column in descending order (high to low). Click to clear sort.',
    none: 'Click to sort rows by values in this column (subgroups always stay with their parent group)',
  }

  // Arrow that cycles the model (column) order driven by one row.
  const renderModelSortToggle = (
    group: string | null,
    subGroup: string | null,
    metric: string | null,
    titles: { asc: string; desc: string; none: string }
  ) => {
    const active = getColumnSort(group, subGroup, metric)
    const state = active ? active.direction : 'none'
    return (
      <span
        className="ml-1 cursor-pointer font-bold"
        onClick={(e) => {
          e.stopPropagation()
          handleColumnSort(group, subGroup, metric)
        }}
        title={titles[state]}
      >
        {state === 'asc' ? '→' : state === 'desc' ? '←' : '⇆'}
      </span>
    )
  }

  const categoryColumn = columnHelper.accessor((row) => row.name, {
    id: 'category',
    header: () => 'Attack Categories',
    cell: (info) => {
      const row = info.row.original as TableRow
      if (row.type === 'group') {
        // Clicking the label expands/collapses; the toggle stops propagation so it only sorts.
        return (
          <div
            className="sticky left-0 font-medium cursor-pointer select-none flex items-center"
            onClick={() => toggleGroup(row.name)}
          >
            <span>{row.isExpanded ? '▼ ' : '▶ '}</span>
            <span className="flex-1">{row.name}</span>{' '}
            {renderModelSortToggle(row.name, null, null, modelSortTitles.group)}
          </div>
        )
      }
      if (row.type === 'subgroup') {
        return (
          <div className="sticky left-0 pl-6 font-medium flex items-center gap-1">
            <span className="flex-1">{row.name}</span>
            {renderModelSortToggle(row.groupId ?? null, row.name, null, modelSortTitles.subgroup)}
          </div>
        )
      }
      // Metric rows sort the models by that metric's values.
      return (
        <div className="sticky left-0 pl-12 font-medium flex items-center gap-1">
          <span className="flex-1">{row.name}</span>
          {renderModelSortToggle(
            row.groupId ?? null,
            row.subgroupId ?? null,
            row.metricName ?? row.name,
            modelSortTitles.metric
          )}
        </div>
      )
    },
  })

  const dataColumns: any[] = []
  overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .forEach((metric) => {
      for (const model of modelOrderByOverallMetric[metric] || models) {
        const columnId = `${metric}-${model}`
        dataColumns.push(
          columnHelper.accessor((row) => row[columnId], {
            id: columnId,
            header: () => {
              const isSorted = rowSortState !== null && rowSortState.columnId === columnId
              const direction = rowSortState ? rowSortState.direction : 'desc'
              const state = isSorted ? direction : 'none'
              return (
                <div
                  className="cursor-pointer select-none"
                  onClick={() => handleRowSort(metric, model)}
                >
                  {model}
                  <span className="ml-1 font-bold" title={rowSortTitles[state]}>
                    {state === 'asc' ? '↑' : state === 'desc' ? '↓' : '⇅'}
                  </span>
                </div>
              )
            },
            cell: (info) => {
              const value = info.getValue() as { avg: number; stdDev: number } | null
              return value ? `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}` : 'N/A'
            },
          })
        )
      }
    })

  return [categoryColumn, ...dataColumns]
}, [
  selectedOverallMetrics,
  overallMetrics,
  modelOrderByOverallMetric,
  rowSortState,
  columnSortState,
  models,
])
// TanStack table instance. Rows arrive already ordered from tableData, so only the core row
// model is required.
const table = useReactTable({
  columns,
  data: tableData,
  getCoreRowModel: getCoreRowModel(),
})
// True when there is nothing to show because no model or no metric is selected.
const nothingSelected =
  selectedModels.size === 0 || selectedMetrics.size === 0 || visibleMetrics.length === 0

return (
  <div className="rounded">
    {error && <div className="text-red-500">{error}</div>}
    {!error && (
      <div className="flex flex-col gap-4">
        <div className="flex flex-col gap-4">
          <OverallMetricFilter
            overallMetrics={overallMetrics}
            selectedOverallMetrics={selectedOverallMetrics}
            setSelectedOverallMetrics={setSelectedOverallMetrics}
          />
          {/* Per-metric filter, currently disabled:
          <LeaderboardFilter
            groups={groupRows}
            selectedMetrics={selectedMetrics}
            setSelectedMetrics={setSelectedMetrics}
          /> */}
        </div>
        {nothingSelected ? (
          <div className="text-center p-4 text-lg">
            Please select at least one model and one metric to display the data
          </div>
        ) : (
          <>
            {/* Raw metrics that belong to no overall metric */}
            <QualityMetricsTable
              qualityMetrics={findQualityMetrics()}
              tableHeader={tableHeader}
              selectedModels={selectedModels}
              tableRows={rawRows}
            />
            {/* CSV export of the main table */}
            <div className="relative flex justify-end mb-6">
              <button
                className="absolute top-0 right-0 btn btn-ghost btn-circle"
                title="Export CSV"
                onClick={exportToCsv}
              >
                <ArrowDownTrayIcon className="h-6 w-6" />
              </button>
            </div>
            <div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
              <table className="table w-full min-w-max border-separate border-spacing-0 border-gray-700 border">
                <thead>
                  {/* Header row 1: one cell per visible overall metric, spanning its model columns */}
                  <tr>
                    <th className="sticky left-0 top-0 bg-base-100 z-20 border border-gray-700">
                      Attack Categories
                    </th>
                    {visibleMetrics.map((metric) => (
                      <th
                        key={`header-metric-${metric}`}
                        className="sticky top-0 bg-base-100 z-10 text-center text-xs border border-gray-700 select-none"
                        colSpan={(modelOrderByOverallMetric[metric] || models).length}
                      >
                        <div className="flex items-center justify-center">
                          <span>{metric}</span>
                          <MetricInfoIcon metricName={metric} />
                        </div>
                      </th>
                    ))}
                  </tr>
                  {/* Header row 2: model headers from TanStack, skipping the category column */}
                  <tr>
                    <th className="sticky left-0 top-12 bg-base-100 z-30 border border-gray-700"></th>
                    {table
                      .getHeaderGroups()[0]
                      .headers.slice(1)
                      .map((header) => (
                        <th
                          key={header.id}
                          className="sticky top-12 bg-base-100 z-10 text-center text-xs border border-gray-700"
                        >
                          {header.isPlaceholder
                            ? null
                            : flexRender(header.column.columnDef.header, header.getContext())}
                        </th>
                      ))}
                  </tr>
                </thead>
                <tbody>
                  {table.getRowModel().rows.map((row) => {
                    const isGroup = row.original.type === 'group'
                    return (
                      <tr
                        key={row.id}
                        className={
                          isGroup ? 'bg-base-200 hover:bg-base-300' : 'bg-base-100 hover:bg-base-200'
                        }
                      >
                        {row.getVisibleCells().map((cell) => {
                          // The category column stays pinned left with the row's background colour.
                          const placement =
                            cell.column.id === 'category'
                              ? `sticky left-0 ${isGroup ? 'bg-base-200' : 'bg-base-100'} z-10`
                              : 'font-medium text-center'
                          return (
                            <td key={cell.id} className={`${placement} border-gray-700 border`}>
                              {flexRender(cell.column.columnDef.cell, cell.getContext())}
                            </td>
                          )
                        })}
                      </tr>
                    )
                  })}
                </tbody>
              </table>
            </div>
          </>
        )}
      </div>
    )}
  </div>
)
}

export default LeaderboardTable