/** * LeaderboardTable.tsx * * This component displays a structured table with hierarchical data (groups, subgroups, metrics) * and provides two independent sorting mechanisms: * * 1. Row Sorting: Clicking on a column header (model name) sorts the rows * - Implemented manually in the tableData memo from rowSortState (the table instance is only given getCoreRowModel) * - Controls which rows appear first in the table * - Groups sort against other groups based on their values * - Subgroups stay with their parent group but sort within the group * * 2. Column Sorting: Clicking on a row header sorts the columns (models) * - Custom implementation using modelOrderByOverallMetric * - Controls the order of models (columns) for each metric * - Completely independent of row sorting * * Both sorting mechanisms operate independently and can be used simultaneously. */ import React, { useEffect, useState, useMemo, useCallback } from 'react' import { ArrowDownTrayIcon } from '@heroicons/react/24/solid' import QualityMetricsTable from './QualityMetricsTable' import MetricInfoIcon from './MetricInfoIcon' import { createColumnHelper, flexRender, getCoreRowModel, useReactTable, ColumnDef, } from '@tanstack/react-table' /** Props for LeaderboardTable. benchmarkData is read through its 'rows' and 'groups' keys; selectedModels is queried with has() and size. */ interface LeaderboardTableProps { benchmarkData: any selectedModels: Set } // Original Row interface - used for the raw data interface OriginalRow { metric: string [key: string]: string | number } // New TableRow interface for the structured hierarchical data /* Aggregated cells are stored on the row under dynamic metric-model keys as { avg, stdDev } or null (see the tableData memo). */ interface TableRow { id: string type: 'group' | 'subgroup' | 'metric' groupId?: string subgroupId?: string metricName?: string name: string visible: boolean depth: number isExpanded?: boolean [key: string]: any } /* Presumably the shape of the groupRows state: group name, then subgroup key, then metric ids. TODO confirm, the generic argument of that useState call is not visible here. */ interface Groups { [group: string]: { [subgroup: string]: string[] } } // For sorting rows (used when clicking column headers) interface RowSortState { columnId: string direction: 'asc' | 'desc' } // For sorting columns (used when clicking row headers) interface ColumnSortState { rowKey: string direction: 'asc' | 'desc' } /* Group name that the parsing effect sorts first and uses for the initial column sort key. */ const OVERALL_ROW = 'Overall' const 
DEFAULT_SELECTED_METRICS = new Set(['log10_p_value']) /** Metric filter control. toggleMetric copies the selected set, flips membership of the clicked metric, and hands the copy to setSelectedOverallMetrics. NOTE(review): the markup returned by this component is not present in this view. */ const OverallMetricFilter: React.FC<{ overallMetrics: string[] selectedOverallMetrics: Set setSelectedOverallMetrics: (metrics: Set) => void }> = ({ overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics }) => { const toggleMetric = (metric: string) => { const newSelected = new Set(selectedOverallMetrics) if (newSelected.has(metric)) { newSelected.delete(metric) } else { newSelected.add(metric) } setSelectedOverallMetrics(newSelected) } return (
Metrics ({selectedOverallMetrics.size}/{overallMetrics.length})
{overallMetrics.map((metric) => ( ))}
) } /** Main leaderboard component. It parses benchmarkData into grouped metric ids, aggregates avg and stdDev per group and subgroup for each selected overall metric and model, feeds the rows to a TanStack table instance, and offers a CSV export. NOTE(review): most of the JSX markup is not present in this view. */ const LeaderboardTable: React.FC = ({ benchmarkData, selectedModels }) => { const [rawRows, setRawRows] = useState([]) const [tableHeader, setTableHeader] = useState([]) const [error, setError] = useState(null) const [groupRows, setGroupRows] = useState({}) const [openGroupRows, setOpenGroupRows] = useState<{ [key: string]: boolean }>({}) const [selectedMetrics, setSelectedMetrics] = useState>(new Set()) const [overallMetrics, setOverallMetrics] = useState([]) const [selectedOverallMetrics, setSelectedOverallMetrics] = useState>(DEFAULT_SELECTED_METRICS) const [rowSortState, setRowSortState] = useState(null) const [columnSortState, setColumnSortState] = useState(null) const [modelOrderByOverallMetric, setModelOrderByOverallMetric] = useState<{ [key: string]: string[] }>({}) // Get filtered models based on selectedModels const models = useMemo(() => { return tableHeader.filter((model) => selectedModels.has(model)) }, [tableHeader, selectedModels]) // Parse benchmark data when it changes /* Resets every piece of derived state from benchmarkData; any exception is turned into the error message state. */ useEffect(() => { if (!benchmarkData) { return } try { const data = benchmarkData const rows: OriginalRow[] = data['rows'] const allGroups = data['groups'] as { [key: string]: string[] } /* NOTE(review): the rest binding 'groups' is not referenced again in the visible code. */ const { Overall: overallGroup, ...groups } = allGroups const uniqueMetrics = new Set() /* Overall metric names are the ids of the Overall group with their first '_'-delimited segment removed; ids without '_' are ignored. */ overallGroup?.forEach((metric) => { if (metric.includes('_')) { const metricName = metric.split('_').slice(1).join('_') uniqueMetrics.add(metricName) } }) setOverallMetrics(Array.from(uniqueMetrics).sort()) setSelectedOverallMetrics(new Set(DEFAULT_SELECTED_METRICS)) // setSelectedOverallMetrics(new Set(Array.from(uniqueMetrics))) /* Groups are ordered with OVERALL_ROW first, then by localeCompare; within a group the metric ids are bucketed by the first '_' segment of the id (subGroup is destructured but unused). */ const groupsData = Object.entries(allGroups) .sort(([groupA], [groupB]) => { if (groupA === OVERALL_ROW) return -1 if (groupB === OVERALL_ROW) return 1 return groupA.localeCompare(groupB) }) .reduce( (acc, [group, metrics]) => { const sortedMetrics = [...metrics].sort() acc[group] = sortedMetrics.reduce<{ [key: string]: string[] }>((subAcc, metric) => { const [mainGroup, subGroup] = 
metric.split('_') if (!subAcc[mainGroup]) { subAcc[mainGroup] = [] } subAcc[mainGroup].push(metric) return subAcc }, {}) acc[group] = Object.fromEntries( Object.entries(acc[group]).sort(([subGroupA], [subGroupB]) => subGroupA.localeCompare(subGroupB) ) ) return acc }, {} as { [key: string]: { [key: string]: string[] } } ) /* Model columns are every key seen on any row except 'metric'. */ const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row)))) const headers = allKeys.filter((key) => key !== 'metric') const initialOpenGroups: { [key: string]: boolean } = {} Object.keys(groupsData).forEach((group) => { initialOpenGroups[group] = false }) const allMetrics = Object.values(allGroups).flat() setSelectedMetrics(new Set(allMetrics)) setTableHeader(headers) setRawRows(rows) setGroupRows(groupsData) setOpenGroupRows(initialOpenGroups) // Initialize row sort state for Overall group /* NOTE(review): despite the wording above, this call sets columnSortState (the model-order sort), not rowSortState. getColumnSortRowKey is a function declaration further down in this component. */ setColumnSortState({ rowKey: getColumnSortRowKey(OVERALL_ROW, null, null), direction: 'asc', }) // Initialize model order by overall metric /* Each overall metric starts with a copy of the full header list; it is not filtered by selectedModels at this point. */ const metricOrders: { [key: string]: string[] } = {} Array.from(uniqueMetrics).forEach((metric) => { metricOrders[metric] = [...headers] }) // Store the original model order for resetting when sort is cleared setModelOrderByOverallMetric(metricOrders) setError(null) } catch (err: any) { setError('Failed to parse benchmark data, please try again: ' + err.message) } }, [benchmarkData]) /** Cycles the row sort for one metric/model column: unsorted, then ascending, then descending, then unsorted again. */ const handleRowSort = (overallMetric: string, model: string) => { // Create the column ID for this metric-model combination const columnId = `${overallMetric}-${model}` let nextDirection: 'asc' | 'desc' | null = null if (!rowSortState || rowSortState.columnId !== columnId) { nextDirection = 'asc' } else if (rowSortState.direction === 'asc') { nextDirection = 'desc' } else { nextDirection = null } setRowSortState(nextDirection ? 
{ columnId, direction: nextDirection } : null) } // Helper to generate a stable composite key for row-based column sorting /* Null parts are written as empty strings; the model-order effect splits the key on '||' and maps empty strings back to null. */ function getColumnSortRowKey( group: string | null, subGroup: string | null, metric: string | null ): string { return `${group ?? ''}||${subGroup ?? ''}||${metric ?? ''}` } // Update the column order when a row's sort icon is clicked /* Cycles the column sort for the clicked row: new row starts ascending, then descending, then cleared. The actual reordering is done by the effect that watches columnSortState. */ const handleColumnSort = ( group: string | null, subGroup: string | null, metric: string | null ) => { const rowKey = getColumnSortRowKey(group, subGroup, metric) console.log('Column sort clicked:', { group, subGroup, metric, rowKey }) // First determine the new sort direction let newDirection: 'asc' | 'desc' | null = null if (!columnSortState || columnSortState.rowKey !== rowKey) { // New sort, start with ascending newDirection = 'asc' } else if (columnSortState.direction === 'asc') { // Toggle from ascending to descending newDirection = 'desc' } else { // Toggle from descending to null (clear sort) newDirection = null } setColumnSortState(newDirection ? 
{ rowKey, direction: newDirection } : null) // If clearing the sort, reset to default column order for this metric only /* NOTE(review): the reset only runs when a metric argument is given and writes under that metric name; clearing the sort from a group or subgroup row (metric is null) leaves modelOrderByOverallMetric unchanged here. */ if (!newDirection && metric) { setModelOrderByOverallMetric((prev) => { const newOrder = { ...prev } newOrder[metric] = [...tableHeader.filter((model) => selectedModels.has(model))] return newOrder }) } } // Find all metrics matching a particular extracted metric name (like "log10_p_value") /* The match is a suffix test: the id minus its first '_' segment must end with metricName. NOTE(review): a plain endsWith also matches longer names that merely end with metricName, whereas findQualityMetrics below requires a '_' separator; confirm this difference is intended. */ const findAllMetricsForName = useCallback( (metricName: string): string[] => { return rawRows .filter((row) => { const metric = row.metric as string if (metric.includes('_')) { const extractedName = metric.split('_').slice(1).join('_') return extractedName.endsWith(metricName) } return false }) .map((row) => row.metric as string) }, [rawRows] ) // Identify metrics that don't belong to any overall metric group /* A metric is excluded when it equals an overall metric name or ends with '_' followed by that name. */ const findQualityMetrics = useCallback((): string[] => { const allMetrics = rawRows.map((row) => row.metric as string) return allMetrics.filter((metric: string) => { for (const overall of overallMetrics) { if (metric.endsWith(`_${overall}`) || metric === overall) { return false } } return true }) }, [rawRows, overallMetrics]) // Calculate average and standard deviation for a set of metrics for a specific column /* Looks each metric id up in rawRows and converts the cell for columnKey with Number(); ids with no row yield NaN. */ const calculateStats = useCallback( (metricNames: string[], columnKey: string): { avg: number; stdDev: number } => { const values = metricNames .map((metricName) => { const row = rawRows.find((row) => row.metric === metricName) return row ? 
Number(row[columnKey]) : NaN }) .filter((value) => !isNaN(value)) /* When no numeric value remains both fields are NaN; callers test avg with isNaN. */ if (values.length === 0) return { avg: NaN, stdDev: NaN } const avg = values.reduce((sum, val) => sum + val, 0) / values.length const squareDiffs = values.map((value) => { const diff = value - avg return diff * diff }) /* Population variance: the sum of squared differences is divided by the number of values. */ const variance = squareDiffs.reduce((sum, sqrDiff) => sum + sqrDiff, 0) / values.length const stdDev = Math.sqrt(variance) return { avg, stdDev } }, [rawRows] ) // Filter metrics by group and/or subgroup /* With no group the input list is returned as-is (selectedMetrics is not applied in that case). Otherwise the result keeps only ids that belong to the group, or to the subgroup when one is given and exists, and that are in selectedMetrics. */ const filterMetricsByGroupAndSubgroup = useCallback( ( metricNames: string[], group: string | null = null, subgroup: string | null = null ): string[] => { if (!group) return metricNames const groupMetrics = Object.values(groupRows[group] || {}).flat() as string[] if (subgroup && groupRows[group]?.[subgroup]) { return metricNames.filter( (metric) => groupRows[group][subgroup].includes(metric) && selectedMetrics.has(metric) ) } return metricNames.filter( (metric) => groupMetrics.includes(metric) && selectedMetrics.has(metric) ) }, [groupRows, selectedMetrics] ) // Compute visible metrics for rendering const visibleMetrics = overallMetrics.filter((metric) => selectedOverallMetrics.has(metric)) // Generate data for the table /* Produces one row per group with at least one selected metric, followed by its subgroup rows when the group is open. No rows of type 'metric' are pushed here. */ const tableData = useMemo(() => { const rows: TableRow[] = [] let groupEntries = Object.entries(groupRows) // --- Manual row sorting using rowSortState --- /* NOTE(review): columnId is split on '-' and only the first two parts are kept, so a model name that itself contains '-' would be cut short here. The sort key averages every metric of the group, while the displayed cell below averages only metrics in selectedMetrics. */ if (rowSortState) { const { columnId, direction } = rowSortState const [metric, model] = columnId.split('-') groupEntries = [...groupEntries].sort(([groupA, subGroupsA], [groupB, subGroupsB]) => { const allGroupMetricsA = Object.values(subGroupsA).flat() const allGroupMetricsB = Object.values(subGroupsB).flat() const allMetricsWithNameA = findAllMetricsForName(metric) const allMetricsWithNameB = allMetricsWithNameA const metricsInGroupA = allGroupMetricsA.filter((m) => allMetricsWithNameA.includes(m)) const metricsInGroupB = allGroupMetricsB.filter((m) => allMetricsWithNameB.includes(m)) const statsA = 
calculateStats(metricsInGroupA, model) const statsB = calculateStats(metricsInGroupB, model) /* Groups without a numeric average are compared as -Infinity. */ const valueA = !isNaN(statsA.avg) ? statsA.avg : -Infinity const valueB = !isNaN(statsB.avg) ? statsB.avg : -Infinity return direction === 'asc' ? valueA - valueB : valueB - valueA }) } groupEntries.forEach(([group, subGroups]) => { const allGroupMetrics = Object.values(subGroups).flat() const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(allGroupMetrics, group) /* Groups with no selected metric produce no row at all. */ if (visibleGroupMetrics.length === 0) return const groupRow: TableRow = { id: `group-${group}`, type: 'group', name: group, visible: true, depth: 0, isExpanded: openGroupRows[group], } /* One cell per selected overall metric and model, keyed metric-model, holding { avg, stdDev } or null when no numeric value exists. */ selectedOverallMetrics.forEach((metric) => { if (overallMetrics.includes(metric)) { models.forEach((model) => { const allMetricsWithName = findAllMetricsForName(metric) const metricsInGroupForThisMetric = visibleGroupMetrics.filter((m) => allMetricsWithName.includes(m) ) const stats = calculateStats(metricsInGroupForThisMetric, model) groupRow[`${metric}-${model}`] = !isNaN(stats.avg) ? { avg: stats.avg, stdDev: stats.stdDev } : null }) } }) rows.push(groupRow) /* Subgroup rows are emitted only while the parent group is open; they are ordered by name unless a row sort is active. */ if (openGroupRows[group]) { let subGroupEntries = Object.entries(subGroups).sort(([a], [b]) => a.localeCompare(b)) if (rowSortState) { const { columnId, direction } = rowSortState const [metric, model] = columnId.split('-') subGroupEntries = [...subGroupEntries].sort(([subA, metricsA], [subB, metricsB]) => { const allMetricsWithName = findAllMetricsForName(metric) const metricsInSubgroupA = metricsA.filter((m) => allMetricsWithName.includes(m)) const metricsInSubgroupB = metricsB.filter((m) => allMetricsWithName.includes(m)) const statsA = calculateStats(metricsInSubgroupA, model) const statsB = calculateStats(metricsInSubgroupB, model) const valueA = !isNaN(statsA.avg) ? statsA.avg : -Infinity const valueB = !isNaN(statsB.avg) ? statsB.avg : -Infinity return direction === 'asc' ? 
valueA - valueB : valueB - valueA }) } subGroupEntries.forEach(([subGroup, metrics]) => { const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(metrics, group, subGroup) if (visibleSubgroupMetrics.length === 0) return const subgroupRow: TableRow = { id: `group-${group}-subgroup-${subGroup}`, type: 'subgroup', groupId: group, name: subGroup, visible: true, depth: 1, isExpanded: false, } selectedOverallMetrics.forEach((metric) => { if (overallMetrics.includes(metric)) { models.forEach((model) => { const allMetricsWithName = findAllMetricsForName(metric) const metricsInSubgroupForThisMetric = visibleSubgroupMetrics.filter((m) => allMetricsWithName.includes(m) ) const stats = calculateStats(metricsInSubgroupForThisMetric, model) subgroupRow[`${metric}-${model}`] = !isNaN(stats.avg) ? { avg: stats.avg, stdDev: stats.stdDev } : null }) } }) rows.push(subgroupRow) }) } }) return rows }, /* NOTE(review): overallMetrics is read inside this memo but is not listed below, while columnSortState and modelOrderByOverallMetric are listed but not read inside it. */ [ rawRows, groupRows, openGroupRows, selectedOverallMetrics, selectedMetrics, models, columnSortState, modelOrderByOverallMetric, rowSortState, ]) // Effect: update model order when columnSortState or dependencies change /* Recomputes modelOrderByOverallMetric for every selected overall metric from the row named in columnSortState. Subgroup rows exist in tableData only while their group is open, so a subgroup key may find no row. */ useEffect(() => { console.log(columnSortState) if (!columnSortState) return // Parse out group, subGroup, metric from rowKey const [group, subGroup, metric] = columnSortState.rowKey.split('||').map((v) => v || null) const newDirection = columnSortState.direction console.log(newDirection, group, subGroup, metric) if (!newDirection) return // Only run if a sort direction is present // Update model order for all visible metrics const metricsToUpdate = Array.from(selectedOverallMetrics) // Find the row in tableData that was clicked for sorting let rowToSort: TableRow | undefined if (group && subGroup && !metric) { // Subgroup row rowToSort = tableData.find( (row) => row.type === 'subgroup' && row.groupId === group && row.name === subGroup ) } else if (group && !subGroup && !metric) { // Group row rowToSort = tableData.find((row) => row.type === 'group' && row.name === 
group) } else if (metric) { // Metric row - not currently in tableData, handled differently rowToSort = undefined } if (!rowToSort && !metric) { console.log('Row to sort not found', { group, subGroup, metric, rowKey: columnSortState.rowKey, }) // Try to proceed anyway with group/subgroup sorting /* NOTE(review): this fallback passes a freshly built object to the setter on every run, and modelOrderByOverallMetric is in this effect's dependency list; confirm this cannot re-trigger the effect indefinitely. */ if (group) { metricsToUpdate.forEach((metricName) => { // Get existing model order const currentOrder = modelOrderByOverallMetric[metricName] || [...models] // For group/subgroup with no row found, keep current model order but reverse it if changing direction if (newDirection === 'asc') { setModelOrderByOverallMetric((prev) => ({ ...prev, [metricName]: [...currentOrder], })) } else { setModelOrderByOverallMetric((prev) => ({ ...prev, [metricName]: [...currentOrder].reverse(), })) } }) } return } // Check if rowToSort has all the expected metrics /* Diagnostic only: this loop logs and never changes control flow. */ for (const metricName of metricsToUpdate) { if ( !rowToSort || !models.some((model) => rowToSort[`${metricName}-${model}`] !== undefined) ) { console.log(`Row does not have metric values for ${metricName}`, rowToSort) } } // Sort the models for each metric const newOrders: { [key: string]: string[] } = {} metricsToUpdate.forEach((metricName) => { // Sort models based on the values in the clicked row const modelScores: { model: string; score: number }[] = models.map((model: string) => { let score = -Infinity if (rowToSort) { // For group/subgroup rows, use the aggregated values in the row for each metric const value: { avg: number; stdDev: number } | null = rowToSort[`${metricName}-${model}`] ?? null score = value && !isNaN(value.avg) ? value.avg : -Infinity } else if (metric) { // For metric rows (which aren't in tableData), we need a different approach // Find metrics for this group that have this metric name /* In this branch the score depends only on the clicked metric and the model, so every entry of metricsToUpdate receives the same order. */ const allMetricsWithName = findAllMetricsForName(metric) if (allMetricsWithName.length > 0) { const values = allMetricsWithName .map((metricId) => { const row = rawRows.find((r) => r.metric === metricId) return row ? 
Number(row[model]) : NaN }) .filter((val) => !isNaN(val)) if (values.length > 0) { const avg = values.reduce((sum, val) => sum + val, 0) / values.length score = !isNaN(avg) ? avg : -Infinity } } } return { model, score } }) /* Models without a usable value keep the score -Infinity. */ modelScores.sort((a, b) => (newDirection === 'asc' ? a.score - b.score : b.score - a.score)) newOrders[metricName] = modelScores .map((item) => item.model) .filter((m) => selectedModels.has(m)) }) // Only update if any order actually changed /* The updater returns the previous object itself when every order is element-wise equal, so an unchanged result does not produce a new state object. */ setModelOrderByOverallMetric((prev) => { let changed = false const next = { ...prev } metricsToUpdate.forEach((metricName) => { const currentOrder = prev[metricName] || [] const newOrder = newOrders[metricName] || [] const arraysEqual = (a: string[], b: string[]) => a.length === b.length && a.every((v, i) => v === b[i]) if (!arraysEqual(currentOrder, newOrder)) { next[metricName] = [...newOrder] changed = true } }) return changed ? next : prev }) }, [ columnSortState, models, selectedModels, modelOrderByOverallMetric, tableData, rawRows, selectedOverallMetrics, ]) console.log(modelOrderByOverallMetric) // CSV export function /* Builds a CSV of the rows currently in tableData, using the stored model order per selected overall metric (falling back to models), and triggers a download named leaderboard_metrics.csv. */ const exportToCsv = () => { // Build header row const header = [ 'Attack Categories', ...overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .flatMap((metric) => { const metricModels = modelOrderByOverallMetric[metric] || models return metricModels.map((model) => `${metric} - ${model}`) }), ] // Build data rows /* Cells use the same text as the table cells: avg and stdDev to three decimals, or 'N/A' when the row has no value for that key. */ const rows: (string | number)[][] = [] tableData.forEach((row) => { const csvRow: (string | number)[] = [row.name] overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .forEach((metric) => { const metricModels = modelOrderByOverallMetric[metric] || models metricModels.forEach((model: string) => { const value = row[`${metric}-${model}`] as { avg: number; stdDev: number } | null if (!value) { csvRow.push('N/A') } else { csvRow.push(`${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}`) } }) }) rows.push(csvRow) }) // Generate CSV const csv = 
/* Every cell is wrapped in double quotes and embedded double quotes are doubled. */ [header, ...rows] .map((row) => row.map((cell) => `"${String(cell).replace(/"/g, '""')}"`).join(',')) .join('\n') // Download /* A temporary anchor element is added to the document, clicked, and removed; the object URL is revoked right after. */ const blob = new Blob([csv], { type: 'text/csv' }) const url = URL.createObjectURL(blob) const a = document.createElement('a') a.href = url a.download = 'leaderboard_metrics.csv' document.body.appendChild(a) a.click() document.body.removeChild(a) URL.revokeObjectURL(url) } // Toggle group expansion const toggleGroup = (group: string) => { setOpenGroupRows((prev) => ({ ...prev, [group]: !prev[group], })) } // Helper to get current column sort config for a row /* Returns columnSortState only when its rowKey matches the key built from the arguments; otherwise null. */ function getColumnSort(group: string | null, subGroup: string | null, metric: string | null) { const rowKey = getColumnSortRowKey(group, subGroup, metric) return columnSortState && columnSortState.rowKey === rowKey ? columnSortState : null } // Prepare columns for TanStack Table /* The first column (id 'category') renders the row name with expand and column-sort controls that differ by row type. NOTE(review): the element markup of those cells is not present in this view. */ const columns = useMemo(() => { const columnHelper = createColumnHelper() const cols: any[] = [] cols.push( columnHelper.accessor((row) => row.name, { id: 'category', header: () => 'Attack Categories', cell: (info) => { const row = info.row.original as TableRow const depth = row.depth || 0 if (row.type === 'group') { return (
toggleGroup(row.name)} > {row.isExpanded ? '▼ ' : '▶ '} {row.name}{' '} { e.stopPropagation() handleColumnSort(row.name, null, null) }} title={ getColumnSort(row.name, null, null) ? getColumnSort(row.name, null, null)?.direction === 'asc' ? 'Currently sorting models by this row in ascending order (low to high). Click for descending order.' : 'Currently sorting models by this row in descending order (high to low). Click to clear sort.' : 'Click to sort models by values in this row (independent of row sorting)' } > {getColumnSort(row.name, null, null) ? getColumnSort(row.name, null, null)?.direction === 'asc' ? '→' : '←' : '⇆'}
) } else if (row.type === 'subgroup') { return (
{row.name} { e.stopPropagation() handleColumnSort(row.groupId!, row.name, null) }} title={ getColumnSort(row.groupId!, row.name, null) ? getColumnSort(row.groupId!, row.name, null)?.direction === 'asc' ? 'Currently sorting models by this subgroup in ascending order (low to high). Click for descending order.' : 'Currently sorting models by this subgroup in descending order (high to low). Click to clear sort.' : 'Click to sort models by values in this subgroup (independent of row sorting)' } > {getColumnSort(row.groupId!, row.name, null) ? getColumnSort(row.groupId!, row.name, null)?.direction === 'asc' ? '→' : '←' : '⇆'}
) } else { // Metric row (add column sorting for model order) return (
{row.name} { e.stopPropagation() handleColumnSort( row.groupId ?? null, row.subgroupId ?? null, row.metricName ?? row.name ) }} title={ getColumnSort( row.groupId ?? null, row.subgroupId ?? null, row.metricName ?? row.name ) ? getColumnSort( row.groupId ?? null, row.subgroupId ?? null, row.metricName ?? row.name )?.direction === 'asc' ? 'Currently sorting models by this metric in ascending order (low to high). Click for descending order.' : 'Currently sorting models by this metric in descending order (high to low). Click to clear sort.' : 'Click to sort models by values in this metric (independent of row sorting)' } > {getColumnSort( row.groupId ?? null, row.subgroupId ?? null, row.metricName ?? row.name ) ? getColumnSort( row.groupId ?? null, row.subgroupId ?? null, row.metricName ?? row.name )?.direction === 'asc' ? '→' : '←' : '⇆'}
) } }, }) ) /* One value column per selected overall metric and model, in the order stored in modelOrderByOverallMetric (falling back to models). The column id matches the metric-model key written on each row by tableData. */ overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .forEach((metric) => { const metricModels = modelOrderByOverallMetric[metric] || models metricModels.forEach((model: string) => { cols.push( columnHelper.accessor((row) => row[`${metric}-${model}`], { id: `${metric}-${model}`, header: () => { const isSorted = rowSortState && rowSortState.columnId === `${metric}-${model}` const direction = rowSortState ? rowSortState.direction : 'desc' return (
handleRowSort(metric, model)} > {model} {isSorted ? (direction === 'asc' ? '↑' : '↓') : '⇅'}
) }, /* Cells show avg and stdDev to three decimals, or 'N/A' when the value is null or missing. */ cell: (info) => { const value = info.getValue() as { avg: number; stdDev: number } | null if (!value) return 'N/A' return `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}` }, }) ) }) }) return cols }, [ selectedOverallMetrics, overallMetrics, modelOrderByOverallMetric, rowSortState, columnSortState, models, ]) // Create the table instance /* Only the core row model is configured, so row order is exactly the order of tableData. */ const table = useReactTable({ data: tableData, columns, getCoreRowModel: getCoreRowModel(), }) return (
{error &&
{error}
} {!error && (
{/* */}
{selectedModels.size === 0 || selectedMetrics.size === 0 || visibleMetrics.length === 0 ? (
Please select at least one model and one metric to display the data
) : ( <> {/* Quality metrics table */} {/* Main metrics table */}
{/* Add metric group headers */} {overallMetrics .filter((metric) => selectedOverallMetrics.has(metric)) .map((metric) => ( ))} {/* Add model headers */} {table .getHeaderGroups()[0] .headers.slice(1) .map((header) => ( ))} {table.getRowModel().rows.map((row) => ( {row.getVisibleCells().map((cell) => ( ))} ))}
Attack Categories
{metric}
{header.isPlaceholder ? null : flexRender(header.column.columnDef.header, header.getContext())}
{flexRender(cell.column.columnDef.cell, cell.getContext())}
)}
)}
) } export default LeaderboardTable