omnisealbench / frontend /src /components /LeaderboardTable.tsx
Mark Duppenthaler
Add description tooltips
44072a9
/**
* LeaderboardTable.tsx
*
* This component displays a structured table with hierarchical data (groups, subgroups, metrics)
* and provides two independent sorting mechanisms:
*
* 1. Row Sorting: Clicking on a column header (model name) sorts the rows
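* - Implemented manually inside the tableData memo via rowSortState
*   (only TanStack Table's core row model is registered; its sorting row model is not used)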
* - Controls which rows appear first in the table
* - Groups sort against other groups based on their values
* - Subgroups stay with their parent group but sort within the group
*
* 2. Column Sorting: Clicking on a row header sorts the columns (models)
* - Custom implementation using modelOrderByOverallMetric
* - Controls the order of models (columns) for each metric
* - Completely independent of row sorting
*
* Both sorting mechanisms operate independently and can be used simultaneously.
*/
import React, { useEffect, useState, useMemo, useCallback } from 'react'
import { ArrowDownTrayIcon } from '@heroicons/react/24/solid'
import QualityMetricsTable from './QualityMetricsTable'
import MetricInfoIcon from './MetricInfoIcon'
import {
createColumnHelper,
flexRender,
getCoreRowModel,
useReactTable,
ColumnDef,
} from '@tanstack/react-table'
/** Props accepted by LeaderboardTable. */
interface LeaderboardTableProps {
// Benchmark payload; this file reads its `rows` and `groups` entries. Left untyped (`any`).
benchmarkData: any
// Model names to show; table headers that are not in this set are filtered out.
selectedModels: Set<string>
}
// Original Row interface - used for the raw data.
// Each row carries its metric name plus further string/number values keyed by column name.
interface OriginalRow {
metric: string
[key: string]: string | number
}
// New TableRow interface for the structured hierarchical data.
// NOTE(review): only 'group' and 'subgroup' rows are built in this file; 'metric' rows
// appear to be reserved for a finer level of detail - confirm before relying on them.
interface TableRow {
// Row identifier, built as `group-<name>` or `group-<group>-subgroup-<name>`
id: string
type: 'group' | 'subgroup' | 'metric'
// Name of the parent group (set on subgroup rows)
groupId?: string
subgroupId?: string
metricName?: string
// Label shown in the first column
name: string
visible: boolean
// Nesting level: 0 for group rows, 1 for subgroup rows
depth: number
isExpanded?: boolean
// Aggregated cells are stored under `${metric}-${model}` keys as { avg, stdDev } or null
[key: string]: any
}
// group name -> subgroup name -> full metric names belonging to that subgroup
interface Groups {
[group: string]: { [subgroup: string]: string[] }
}
// For sorting rows (used when clicking column headers)
interface RowSortState {
// `${metric}-${model}` id of the column the rows are ordered by
columnId: string
direction: 'asc' | 'desc'
}
// For sorting columns (used when clicking row headers)
interface ColumnSortState {
// Composite `group||subGroup||metric` key of the row the models are ordered by
rowKey: string
direction: 'asc' | 'desc'
}
// Name of the group that is sorted ahead of all other groups and drives the initial column sort
const OVERALL_ROW = 'Overall'
// Overall metrics that are ticked in the metric filter whenever benchmark data is (re)loaded
const DEFAULT_SELECTED_METRICS = new Set(['log10_p_value'])
/**
 * Checkbox panel listing every overall metric.
 *
 * Ticking or unticking a box hands a *new* Set to `setSelectedOverallMetrics`;
 * the Set passed in through props is never mutated. The legend shows how many
 * metrics are currently selected out of the total.
 */
const OverallMetricFilter: React.FC<{
  overallMetrics: string[]
  selectedOverallMetrics: Set<string>
  setSelectedOverallMetrics: (metrics: Set<string>) => void
}> = (props) => {
  const { overallMetrics, selectedOverallMetrics, setSelectedOverallMetrics } = props

  // Flip membership of one metric on a copy of the current selection.
  const handleToggle = (name: string) => {
    const updated = new Set(selectedOverallMetrics)
    // Set.delete reports whether the entry existed; if it did not, this is a "check" action.
    if (!updated.delete(name)) {
      updated.add(name)
    }
    setSelectedOverallMetrics(updated)
  }

  const checkboxes = overallMetrics.map((name) => (
    <label key={name} className="flex items-center gap-2 text-sm">
      <input
        type="checkbox"
        className="form-checkbox h-4 w-4"
        checked={selectedOverallMetrics.has(name)}
        onChange={() => handleToggle(name)}
      />
      <div className="flex items-center truncate">
        <span className="truncate" title={name}>
          {name}
        </span>
        <MetricInfoIcon metricName={name} />
      </div>
    </label>
  ))

  return (
    <div className="w-full">
      <fieldset className="fieldset w-full p-4 rounded border border-gray-700 bg-base-200">
        <legend className="fieldset-legend font-semibold">
          Metrics ({selectedOverallMetrics.size}/{overallMetrics.length})
        </legend>
        <div className="grid grid-cols-2 md:grid-cols-4 lg:grid-cols-6 gap-1 max-h-48 overflow-y-auto pr-2">
          {checkboxes}
        </div>
      </fieldset>
    </div>
  )
}
const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ benchmarkData, selectedModels }) => {
// Raw benchmark rows as delivered in benchmarkData['rows']
const [rawRows, setRawRows] = useState<OriginalRow[]>([])
// Every key found in the raw rows except 'metric' (these are the candidate column names)
const [tableHeader, setTableHeader] = useState<string[]>([])
// Message shown instead of the table when parsing the benchmark data failed
const [error, setError] = useState<string | null>(null)
// group -> subgroup -> metric names, derived from benchmarkData['groups']
const [groupRows, setGroupRows] = useState<Groups>({})
// Expansion flag per group; all groups start collapsed after data is parsed
const [openGroupRows, setOpenGroupRows] = useState<{ [key: string]: boolean }>({})
// Full metric names eligible for aggregation; initialised to every metric of every group
const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
// Sorted list of overall metric names extracted from the 'Overall' group
const [overallMetrics, setOverallMetrics] = useState<string[]>([])
// Overall metrics currently ticked in the filter panel
const [selectedOverallMetrics, setSelectedOverallMetrics] =
useState<Set<string>>(DEFAULT_SELECTED_METRICS)
// Active row sort (set by clicking a model column header), or null for no row sort
const [rowSortState, setRowSortState] = useState<RowSortState | null>(null)
// Active column sort (set by clicking a row's sort glyph), or null for no column sort
const [columnSortState, setColumnSortState] = useState<ColumnSortState | null>(null)
// Per overall metric: the order in which model columns are rendered
const [modelOrderByOverallMetric, setModelOrderByOverallMetric] = useState<{
[key: string]: string[]
}>({})
// Get filtered models based on selectedModels (keeps the tableHeader order)
const models = useMemo(() => {
return tableHeader.filter((model) => selectedModels.has(model))
}, [tableHeader, selectedModels])
// Parse benchmark data when it changes.
//
// Everything is derived first and only committed to state once the whole payload
// has been processed, so a malformed payload leaves the previous state untouched
// and only sets the error message.
useEffect(() => {
  if (!benchmarkData) {
    return
  }
  try {
    const rows: OriginalRow[] = benchmarkData['rows']
    const allGroups = benchmarkData['groups'] as { [key: string]: string[] }

    // Overall metric names: entries of the Overall group with their leading
    // "<prefix>_" removed. Entries without an underscore are ignored.
    const uniqueMetrics = new Set<string>()
    allGroups[OVERALL_ROW]?.forEach((metric) => {
      if (metric.includes('_')) {
        uniqueMetrics.add(metric.split('_').slice(1).join('_'))
      }
    })

    // group -> subgroup -> metrics. Overall is pinned first, remaining groups are
    // alphabetical; the subgroup key is the part of the metric name before the first "_".
    const groupsData = Object.entries(allGroups)
      .sort(([groupA], [groupB]) => {
        if (groupA === OVERALL_ROW) return -1
        if (groupB === OVERALL_ROW) return 1
        return groupA.localeCompare(groupB)
      })
      .reduce(
        (acc, [group, metrics]) => {
          const bySubgroup = [...metrics].sort().reduce<{ [key: string]: string[] }>(
            (subAcc, metric) => {
              const [mainGroup] = metric.split('_')
              if (!subAcc[mainGroup]) {
                subAcc[mainGroup] = []
              }
              subAcc[mainGroup].push(metric)
              return subAcc
            },
            {}
          )
          acc[group] = Object.fromEntries(
            Object.entries(bySubgroup).sort(([subGroupA], [subGroupB]) =>
              subGroupA.localeCompare(subGroupB)
            )
          )
          return acc
        },
        {} as { [key: string]: { [key: string]: string[] } }
      )

    // Column headers: every key seen in any row, except the metric name itself.
    const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
    const headers = allKeys.filter((key) => key !== 'metric')

    // Every group starts collapsed.
    const initialOpenGroups: { [key: string]: boolean } = {}
    Object.keys(groupsData).forEach((group) => {
      initialOpenGroups[group] = false
    })

    const allMetrics = Object.values(allGroups).flat()

    // Initial model order per overall metric: the header order.
    const metricOrders: { [key: string]: string[] } = {}
    uniqueMetrics.forEach((metric) => {
      metricOrders[metric] = [...headers]
    })

    // --- Commit: nothing above touched state, so a throw leaves it unchanged ---
    setOverallMetrics(Array.from(uniqueMetrics).sort())
    setSelectedOverallMetrics(new Set(DEFAULT_SELECTED_METRICS))
    // setSelectedOverallMetrics(new Set(Array.from(uniqueMetrics)))
    setSelectedMetrics(new Set(allMetrics))
    setTableHeader(headers)
    setRawRows(rows)
    setGroupRows(groupsData)
    setOpenGroupRows(initialOpenGroups)
    // Initialize the column sort: order models by the Overall group row, ascending
    setColumnSortState({
      rowKey: getColumnSortRowKey(OVERALL_ROW, null, null),
      direction: 'asc',
    })
    setModelOrderByOverallMetric(metricOrders)
    setError(null)
  } catch (err: unknown) {
    // Anything can be thrown; only Error instances are guaranteed to carry a message.
    const reason = err instanceof Error ? err.message : String(err)
    setError('Failed to parse benchmark data, please try again: ' + reason)
  }
}, [benchmarkData])
// Cycle the row sort for one metric/model column header:
// unsorted (or sorted by another column) -> ascending -> descending -> unsorted.
const handleRowSort = (overallMetric: string, model: string) => {
  // Column id of this metric-model combination
  const columnId = `${overallMetric}-${model}`

  // A different (or no) column is active: start a fresh ascending sort here.
  if (!rowSortState || rowSortState.columnId !== columnId) {
    setRowSortState({ columnId, direction: 'asc' })
    return
  }

  // Same column clicked again: flip asc -> desc, or clear after desc.
  if (rowSortState.direction === 'asc') {
    setRowSortState({ columnId, direction: 'desc' })
  } else {
    setRowSortState(null)
  }
}
// Builds the composite key that identifies a row for column sorting.
// The three parts are joined with "||"; a null part contributes an empty segment,
// e.g. ('Overall', null, null) -> 'Overall||||'.
function getColumnSortRowKey(
  group: string | null,
  subGroup: string | null,
  metric: string | null
): string {
  const segments = [group, subGroup, metric].map((part) => part ?? '')
  return segments.join('||')
}
// Update the column sort when a row's sort icon is clicked.
//
// Clicking cycles: unsorted (or sorted by another row) -> ascending -> descending
// -> unsorted. The actual re-ordering of the models for an active sort happens in
// the effect that watches columnSortState; this handler only records the new sort
// and, when the sort is cleared, restores the default model order.
//
// group / subGroup / metric identify the clicked row; group and subgroup rows pass
// null for the parts they do not have.
const handleColumnSort = (
  group: string | null,
  subGroup: string | null,
  metric: string | null
) => {
  const rowKey = getColumnSortRowKey(group, subGroup, metric)

  // Determine the next sort direction for this row
  let newDirection: 'asc' | 'desc' | null
  if (!columnSortState || columnSortState.rowKey !== rowKey) {
    // New sort, start with ascending
    newDirection = 'asc'
  } else if (columnSortState.direction === 'asc') {
    // Toggle from ascending to descending
    newDirection = 'desc'
  } else {
    // Toggle from descending to null (clear sort)
    newDirection = null
  }
  setColumnSortState(newDirection ? { rowKey, direction: newDirection } : null)

  // Clearing the sort restores the default order (header order, selected models only).
  // An active sort re-orders the models of every metric, so every stored order is reset.
  // This must not depend on `metric`: group and subgroup rows pass null there, and
  // their "clear sort" click would otherwise leave the columns sorted.
  if (!newDirection) {
    const defaultOrder = tableHeader.filter((model) => selectedModels.has(model))
    setModelOrderByOverallMetric((prev) => {
      const reset: { [key: string]: string[] } = {}
      Object.keys(prev).forEach((key) => {
        reset[key] = [...defaultOrder]
      })
      // Keep the previous behaviour of also creating an entry for the clicked metric.
      if (metric) {
        reset[metric] = [...defaultOrder]
      }
      return reset
    })
  }
}
// Collect the full names of all raw-row metrics that belong to an extracted metric
// name (like "log10_p_value").
//
// A row matches when the part of its metric name after the first "_" ends with
// `metricName`; rows whose metric name has no "_" never match.
// NOTE(review): this is a plain suffix test, so a metric name that is itself the tail
// of another one would match both - confirm the naming scheme rules that out.
const findAllMetricsForName = useCallback(
  (metricName: string): string[] => {
    const matches: string[] = []
    for (const row of rawRows) {
      const fullName = row.metric as string
      const separatorAt = fullName.indexOf('_')
      if (separatorAt === -1) {
        continue
      }
      // Everything after the first underscore
      const extractedName = fullName.slice(separatorAt + 1)
      if (extractedName.endsWith(metricName)) {
        matches.push(fullName)
      }
    }
    return matches
  },
  [rawRows]
)
// Return the raw-row metrics that do not belong to any overall metric.
// A metric belongs to an overall metric when it equals that name or ends with "_<name>".
const findQualityMetrics = useCallback((): string[] => {
  const belongsToOverallMetric = (metric: string): boolean =>
    overallMetrics.some((overall) => metric.endsWith(`_${overall}`) || metric === overall)

  return rawRows
    .map((row) => row.metric as string)
    .filter((metric: string) => !belongsToOverallMetric(metric))
}, [rawRows, overallMetrics])
// Calculate the average and standard deviation of one column over a set of metrics.
//
// metricNames: full metric names whose rows should be aggregated.
// columnKey:   the column (model) to read from each row.
//
// Returns { avg, stdDev }, where stdDev is the population standard deviation
// (variance divided by the number of values). Both are NaN when no usable value exists.
//
// Only real measurements take part: numbers, and non-blank strings that parse as
// numbers. Missing cells (undefined, null, '' or whitespace) are skipped; plain
// Number() coercion would turn null and '' into 0 and drag the average towards zero.
const calculateStats = useCallback(
  (metricNames: string[], columnKey: string): { avg: number; stdDev: number } => {
    const values: number[] = []
    for (const metricName of metricNames) {
      const row = rawRows.find((candidate) => candidate.metric === metricName)
      if (!row) continue

      const raw: unknown = row[columnKey]
      let value = NaN
      if (typeof raw === 'number') {
        value = raw
      } else if (typeof raw === 'string' && raw.trim() !== '') {
        value = Number(raw)
      }
      if (!Number.isNaN(value)) {
        values.push(value)
      }
    }

    if (values.length === 0) return { avg: NaN, stdDev: NaN }

    const avg = values.reduce((sum, val) => sum + val, 0) / values.length
    const variance =
      values.reduce((sum, val) => sum + (val - avg) * (val - avg), 0) / values.length
    return { avg, stdDev: Math.sqrt(variance) }
  },
  [rawRows]
)
// Narrow a list of metric names to those that are selected and belong to the given
// group, or to one of its subgroups when that subgroup exists in groupRows.
// Without a group the input list is returned unchanged.
const filterMetricsByGroupAndSubgroup = useCallback(
  (
    metricNames: string[],
    group: string | null = null,
    subgroup: string | null = null
  ): string[] => {
    if (!group) return metricNames

    // Prefer the subgroup's own metric list; fall back to every metric of the group
    // when no subgroup was requested or the subgroup is unknown.
    const subgroupMetrics = subgroup ? groupRows[group]?.[subgroup] : undefined
    const allowed: string[] = subgroupMetrics
      ? subgroupMetrics
      : (Object.values(groupRows[group] || {}).flat() as string[])

    return metricNames.filter((name) => allowed.includes(name) && selectedMetrics.has(name))
  },
  [groupRows, selectedMetrics]
)
// Compute visible metrics for rendering: the overall metrics that are currently selected,
// kept in overallMetrics order. Used below to decide whether there is anything to display.
const visibleMetrics = overallMetrics.filter((metric) => selectedOverallMetrics.has(metric))
// Generate data for the table.
//
// Produces one row per group that has visible metrics and, for expanded groups, one
// row per subgroup with visible metrics. Every row stores, under `${metric}-${model}`,
// either { avg, stdDev } aggregated over the row's metrics or null when there is no value.
//
// Row sorting (rowSortState) is applied here by hand: groups are ordered by their
// average in the sorted column, and subgroups are ordered the same way inside their
// parent group. Rows without a value sort as -Infinity.
const tableData = useMemo(() => {
  const rows: TableRow[] = []

  // findAllMetricsForName scans every raw row, so resolve each metric name only once.
  const namedMetricsCache = new Map<string, string[]>()
  const metricsForName = (metricName: string): string[] => {
    let found = namedMetricsCache.get(metricName)
    if (!found) {
      found = findAllMetricsForName(metricName)
      namedMetricsCache.set(metricName, found)
    }
    return found
  }

  // Split a `${metric}-${model}` column id back into its parts. A plain split('-')
  // breaks as soon as a model (or metric) name contains a hyphen, so the id is matched
  // against the known overall metrics first, preferring a split whose remainder is a
  // currently shown model. The first hyphen is only a last resort.
  const splitColumnId = (columnId: string): [string, string] => {
    const candidates = overallMetrics.filter((name) => columnId.startsWith(`${name}-`))
    const knownMetric =
      candidates.find((name) => models.includes(columnId.slice(name.length + 1))) ??
      candidates[0]
    const cut = knownMetric !== undefined ? knownMetric.length : columnId.indexOf('-')
    if (cut < 0) return [columnId, '']
    return [columnId.slice(0, cut), columnId.slice(cut + 1)]
  }

  // Comparator for the active row sort; compares two metric lists by their average
  // in the sorted column.
  let compareBySortColumn: ((metricsA: string[], metricsB: string[]) => number) | null = null
  if (rowSortState) {
    const [sortMetric, sortModel] = splitColumnId(rowSortState.columnId)
    const ascending = rowSortState.direction === 'asc'
    const sortValue = (candidates: string[]): number => {
      const named = metricsForName(sortMetric)
      const stats = calculateStats(
        candidates.filter((m) => named.includes(m)),
        sortModel
      )
      return !isNaN(stats.avg) ? stats.avg : -Infinity
    }
    compareBySortColumn = (metricsA, metricsB) => {
      const valueA = sortValue(metricsA)
      const valueB = sortValue(metricsB)
      // Equal values (including two missing ones) keep their relative order.
      if (valueA === valueB) return 0
      return ascending ? valueA - valueB : valueB - valueA
    }
  }

  // Write the aggregated cells for every selected metric and shown model into a row.
  const fillAggregates = (row: TableRow, visibleRowMetrics: string[]) => {
    selectedOverallMetrics.forEach((metric) => {
      if (!overallMetrics.includes(metric)) return
      const named = metricsForName(metric)
      const metricsForThisMetric = visibleRowMetrics.filter((m) => named.includes(m))
      models.forEach((model) => {
        const stats = calculateStats(metricsForThisMetric, model)
        row[`${metric}-${model}`] = !isNaN(stats.avg)
          ? { avg: stats.avg, stdDev: stats.stdDev }
          : null
      })
    })
  }

  // --- Manual row sorting using rowSortState ---
  let groupEntries = Object.entries(groupRows)
  if (compareBySortColumn) {
    const compare = compareBySortColumn
    groupEntries = [...groupEntries].sort(([, subGroupsA], [, subGroupsB]) =>
      compare(Object.values(subGroupsA).flat(), Object.values(subGroupsB).flat())
    )
  }

  groupEntries.forEach(([group, subGroups]) => {
    const allGroupMetrics = Object.values(subGroups).flat()
    const visibleGroupMetrics = filterMetricsByGroupAndSubgroup(allGroupMetrics, group)
    // Groups without any visible metric are left out entirely.
    if (visibleGroupMetrics.length === 0) return

    const groupRow: TableRow = {
      id: `group-${group}`,
      type: 'group',
      name: group,
      visible: true,
      depth: 0,
      isExpanded: openGroupRows[group],
    }
    fillAggregates(groupRow, visibleGroupMetrics)
    rows.push(groupRow)

    if (!openGroupRows[group]) return

    // Subgroups are alphabetical by default and re-ordered by the row sort if active.
    let subGroupEntries = Object.entries(subGroups).sort(([a], [b]) => a.localeCompare(b))
    if (compareBySortColumn) {
      const compare = compareBySortColumn
      subGroupEntries = [...subGroupEntries].sort(([, metricsA], [, metricsB]) =>
        compare(metricsA, metricsB)
      )
    }

    subGroupEntries.forEach(([subGroup, metrics]) => {
      const visibleSubgroupMetrics = filterMetricsByGroupAndSubgroup(metrics, group, subGroup)
      if (visibleSubgroupMetrics.length === 0) return

      const subgroupRow: TableRow = {
        id: `group-${group}-subgroup-${subGroup}`,
        type: 'subgroup',
        groupId: group,
        name: subGroup,
        visible: true,
        depth: 1,
        isExpanded: false,
      }
      fillAggregates(subgroupRow, visibleSubgroupMetrics)
      rows.push(subgroupRow)
    })
  })

  return rows
}, [
  // Only values that are actually read above; the column sort and the model order do
  // not influence the rows, so they must not invalidate this memo.
  groupRows,
  openGroupRows,
  overallMetrics,
  selectedOverallMetrics,
  models,
  rowSortState,
  findAllMetricsForName,
  calculateStats,
  filterMetricsByGroupAndSubgroup,
])
// Effect: keep the per-metric model (column) order in sync with the column sort.
//
// - No active sort: every selected metric shows the selected models in header order.
// - Active sort on a group/subgroup row: models are ordered by that row's averages.
// - Active sort on a row that is currently not in tableData (e.g. a subgroup whose
//   group was collapsed): the existing order is kept and only its membership is synced
//   with the selected models.
//
// State is only replaced when an order really changes, and the effect does not depend
// on modelOrderByOverallMetric, so it cannot re-trigger itself in a loop.
useEffect(() => {
  const metricsToUpdate = Array.from(selectedOverallMetrics)

  const sameOrder = (a: string[], b: string[]) =>
    a.length === b.length && a.every((v, i) => v === b[i])

  // Apply a new order per metric; returns the previous state object untouched when
  // nothing changed so React can skip the re-render.
  const applyOrders = (orderFor: (metricName: string, previous: string[]) => string[]) => {
    setModelOrderByOverallMetric((prev) => {
      let changed = false
      const next = { ...prev }
      metricsToUpdate.forEach((metricName) => {
        const previous = prev[metricName] || []
        const updated = orderFor(metricName, previous)
        if (!sameOrder(previous, updated)) {
          next[metricName] = [...updated]
          changed = true
        }
      })
      return changed ? next : prev
    })
  }

  if (!columnSortState) {
    // Without a sort the stored order must still follow the model selection, otherwise
    // newly selected models never show up and deselected ones never disappear.
    applyOrders(() => models)
    return
  }

  // Parse out group, subGroup, metric from rowKey (empty segments become null)
  const [group, subGroup, metric] = columnSortState.rowKey.split('||').map((v) => v || null)
  const newDirection = columnSortState.direction

  // Find the row in tableData that was clicked for sorting
  let rowToSort: TableRow | undefined
  if (group && subGroup && !metric) {
    // Subgroup row
    rowToSort = tableData.find(
      (row) => row.type === 'subgroup' && row.groupId === group && row.name === subGroup
    )
  } else if (group && !subGroup && !metric) {
    // Group row
    rowToSort = tableData.find((row) => row.type === 'group' && row.name === group)
  }
  // Metric rows are not part of tableData; they are scored from the raw rows below.

  if (!rowToSort && !metric) {
    // The sorted row is not visible right now. Keep the relative order the user last
    // saw, drop models that are no longer selected and append newly selected ones.
    // This update is idempotent, so running it repeatedly does not change state again.
    applyOrders((_metricName, previous) => [
      ...previous.filter((model) => models.includes(model)),
      ...models.filter((model) => !previous.includes(model)),
    ])
    return
  }

  // Sort the models for each metric
  const newOrders: { [key: string]: string[] } = {}
  metricsToUpdate.forEach((metricName) => {
    // Score every model from the values in the clicked row; missing values rank lowest.
    const modelScores: { model: string; score: number }[] = models.map((model: string) => {
      let score = -Infinity
      if (rowToSort) {
        // For group/subgroup rows, use the aggregated values in the row for each metric
        const value: { avg: number; stdDev: number } | null =
          rowToSort[`${metricName}-${model}`] ?? null
        score = value && !isNaN(value.avg) ? value.avg : -Infinity
      } else if (metric) {
        // For metric rows (which aren't in tableData), average the raw values of all
        // metrics that carry this metric name
        const values = findAllMetricsForName(metric)
          .map((metricId) => {
            const row = rawRows.find((r) => r.metric === metricId)
            return row ? Number(row[model]) : NaN
          })
          .filter((val) => !isNaN(val))
        if (values.length > 0) {
          const avg = values.reduce((sum, val) => sum + val, 0) / values.length
          score = !isNaN(avg) ? avg : -Infinity
        }
      }
      return { model, score }
    })
    modelScores.sort((a, b) => {
      // Equal scores (including two missing ones) keep their relative order.
      if (a.score === b.score) return 0
      return newDirection === 'asc' ? a.score - b.score : b.score - a.score
    })
    newOrders[metricName] = modelScores
      .map((item) => item.model)
      .filter((m) => selectedModels.has(m))
  })

  applyOrders((metricName) => newOrders[metricName] || [])
}, [
  columnSortState,
  models,
  selectedModels,
  tableData,
  rawRows,
  selectedOverallMetrics,
  findAllMetricsForName,
])
// CSV export function.
// Writes the rows currently in tableData, with the columns in their on-screen order
// (selected metrics in overallMetrics order, models in each metric's stored order),
// and triggers a browser download named "leaderboard_metrics.csv".
const exportToCsv = () => {
  // One descriptor per exported data column
  const exportedColumns = overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .flatMap((metric) =>
      (modelOrderByOverallMetric[metric] || models).map((model: string) => ({ metric, model }))
    )

  // Header row
  const header = [
    'Attack Categories',
    ...exportedColumns.map(({ metric, model }) => `${metric} - ${model}`),
  ]

  // Data rows: row label followed by "avg ± stdDev" (3 decimals) or N/A per column
  const rows: (string | number)[][] = tableData.map((row) => {
    const cells = exportedColumns.map(({ metric, model }) => {
      const value = row[`${metric}-${model}`] as { avg: number; stdDev: number } | null
      return value ? `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}` : 'N/A'
    })
    return [row.name, ...cells]
  })

  // Every cell is quoted; embedded quotes are doubled
  const quoteCell = (cell: string | number) => `"${String(cell).replace(/"/g, '""')}"`
  const csv = [header, ...rows].map((line) => line.map(quoteCell).join(',')).join('\n')

  // Download through a temporary anchor element
  const url = URL.createObjectURL(new Blob([csv], { type: 'text/csv' }))
  const anchor = document.createElement('a')
  anchor.href = url
  anchor.download = 'leaderboard_metrics.csv'
  document.body.appendChild(anchor)
  anchor.click()
  document.body.removeChild(anchor)
  URL.revokeObjectURL(url)
}
// Expand a collapsed group or collapse an expanded one; other groups are unaffected.
const toggleGroup = (group: string) => {
  setOpenGroupRows((current) => {
    const isOpen = current[group]
    return { ...current, [group]: !isOpen }
  })
}
// Returns the active column sort if it targets the given row, otherwise null.
function getColumnSort(group: string | null, subGroup: string | null, metric: string | null) {
  if (!columnSortState) {
    return null
  }
  const targetsThisRow = columnSortState.rowKey === getColumnSortRowKey(group, subGroup, metric)
  return targetsThisRow ? columnSortState : null
}
// Prepare columns for TanStack Table.
//
// The first column ("category") renders the row label together with the glyph that
// sorts the models by that row. It is followed by one column per selected metric and
// model, in the stored per-metric model order, whose header sorts the rows.
const columns = useMemo<any[]>(() => {
  const columnHelper = createColumnHelper<TableRow>()

  // Tooltip texts of the column-sort glyph, per row type and sort state
  const columnSortTitles = {
    group: {
      asc: 'Currently sorting models by this row in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this row in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this row (independent of row sorting)',
    },
    subgroup: {
      asc: 'Currently sorting models by this subgroup in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this subgroup in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this subgroup (independent of row sorting)',
    },
    metric: {
      asc: 'Currently sorting models by this metric in ascending order (low to high). Click for descending order.',
      desc: 'Currently sorting models by this metric in descending order (high to low). Click to clear sort.',
      none: 'Click to sort models by values in this metric (independent of row sorting)',
    },
  }

  // Glyph that cycles the column sort of one row; the click does not bubble to the row.
  const renderColumnSortToggle = (
    rowType: 'group' | 'subgroup' | 'metric',
    group: string | null,
    subGroup: string | null,
    metric: string | null
  ) => {
    const activeSort = getColumnSort(group, subGroup, metric)
    const titles = columnSortTitles[rowType]
    let title = titles.none
    let glyph = '⇆'
    if (activeSort) {
      const isAscending = activeSort.direction === 'asc'
      title = isAscending ? titles.asc : titles.desc
      glyph = isAscending ? '→' : '←'
    }
    return (
      <span
        className="ml-1 cursor-pointer font-bold"
        onClick={(e) => {
          e.stopPropagation()
          handleColumnSort(group, subGroup, metric)
        }}
        title={title}
      >
        {glyph}
      </span>
    )
  }

  // First column: row label plus column-sort glyph
  const categoryColumn = columnHelper.accessor((row) => row.name, {
    id: 'category',
    header: () => 'Attack Categories',
    cell: (info) => {
      const row = info.row.original as TableRow
      if (row.type === 'group') {
        // Clicking anywhere on a group label expands/collapses the group
        return (
          <div
            className="sticky left-0 font-medium cursor-pointer select-none flex items-center"
            onClick={() => toggleGroup(row.name)}
          >
            <span>{row.isExpanded ? '▼ ' : '▶ '}</span>
            <span className="flex-1">{row.name}</span>{' '}
            {renderColumnSortToggle('group', row.name, null, null)}
          </div>
        )
      }
      if (row.type === 'subgroup') {
        return (
          <div className="sticky left-0 pl-6 font-medium flex items-center gap-1">
            <span className="flex-1">{row.name}</span>
            {renderColumnSortToggle('subgroup', row.groupId!, row.name, null)}
          </div>
        )
      }
      // Metric row (add column sorting for model order)
      return (
        <div className="sticky left-0 pl-12 font-medium flex items-center gap-1">
          <span className="flex-1">{row.name}</span>
          {renderColumnSortToggle(
            'metric',
            row.groupId ?? null,
            row.subgroupId ?? null,
            row.metricName ?? row.name
          )}
        </div>
      )
    },
  })

  // One data column per selected metric and model
  const valueColumns = overallMetrics
    .filter((metric) => selectedOverallMetrics.has(metric))
    .flatMap((metric) => {
      const metricModels: string[] = modelOrderByOverallMetric[metric] || models
      return metricModels.map((model: string) => {
        const columnId = `${metric}-${model}`
        return columnHelper.accessor((row) => row[columnId], {
          id: columnId,
          header: () => {
            const isSorted = rowSortState && rowSortState.columnId === columnId
            const direction = rowSortState ? rowSortState.direction : 'desc'
            let title =
              'Click to sort rows by values in this column (subgroups always stay with their parent group)'
            let glyph = '⇅'
            if (isSorted) {
              if (direction === 'asc') {
                title =
                  'Currently sorting rows by this column in ascending order (low to high). Click for descending order.'
                glyph = '↑'
              } else {
                title =
                  'Currently sorting rows by this column in descending order (high to low). Click to clear sort.'
                glyph = '↓'
              }
            }
            return (
              <div className="cursor-pointer select-none" onClick={() => handleRowSort(metric, model)}>
                {model}
                <span className="ml-1 font-bold" title={title}>
                  {glyph}
                </span>
              </div>
            )
          },
          cell: (info) => {
            const value = info.getValue() as { avg: number; stdDev: number } | null
            if (!value) return 'N/A'
            return `${value.avg.toFixed(3)} ± ${value.stdDev.toFixed(3)}`
          },
        })
      })
    })

  return [categoryColumn, ...valueColumns]
}, [
  selectedOverallMetrics,
  overallMetrics,
  modelOrderByOverallMetric,
  rowSortState,
  columnSortState,
  models,
])
// Create the table instance.
// Only the core row model is registered: rows are rendered in the order tableData
// provides them, so any sorting has to happen before the data reaches the table.
const table = useReactTable({
data: tableData,
columns,
getCoreRowModel: getCoreRowModel(),
})
// Render. An error message replaces everything else. Otherwise the metric filter is
// always shown, followed either by a hint (no model, no metric or no visible overall
// metric selected) or by the quality metrics table, the CSV export button and the main
// table. The main table has two header rows: one cell per selected metric spanning its
// model columns, then one cell per model taken from the TanStack headers after the first.
return (
<div className="rounded">
{error && <div className="text-red-500">{error}</div>}
{!error && (
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-4">
<OverallMetricFilter
overallMetrics={overallMetrics}
selectedOverallMetrics={selectedOverallMetrics}
setSelectedOverallMetrics={setSelectedOverallMetrics}
/>
{/* <LeaderboardFilter
groups={groupRows}
selectedMetrics={selectedMetrics}
setSelectedMetrics={setSelectedMetrics}
/> */}
</div>
{selectedModels.size === 0 ||
selectedMetrics.size === 0 ||
visibleMetrics.length === 0 ? (
<div className="text-center p-4 text-lg">
Please select at least one model and one metric to display the data
</div>
) : (
<>
{/* Quality metrics table */}
<QualityMetricsTable
qualityMetrics={findQualityMetrics()}
tableHeader={tableHeader}
selectedModels={selectedModels}
tableRows={rawRows}
/>
{/* Main metrics table */}
<div className="relative flex justify-end mb-6">
<button
className="absolute top-0 right-0 btn btn-ghost btn-circle"
title="Export CSV"
onClick={exportToCsv}
>
<ArrowDownTrayIcon className="h-6 w-6" />
</button>
</div>
<div className="overflow-x-auto max-h-[80vh] overflow-y-auto">
<table className="table w-full min-w-max border-separate border-spacing-0 border-gray-700 border">
<thead>
<tr>
<th className="sticky left-0 top-0 bg-base-100 z-20 border border-gray-700">
Attack Categories
</th>
{/* Add metric group headers */}
{overallMetrics
.filter((metric) => selectedOverallMetrics.has(metric))
.map((metric) => (
<th
key={`header-metric-${metric}`}
className="sticky top-0 bg-base-100 z-10 text-center text-xs border border-gray-700 select-none"
colSpan={(modelOrderByOverallMetric[metric] || models).length}
>
<div className="flex items-center justify-center">
<span>{metric}</span>
<MetricInfoIcon metricName={metric} />
</div>
</th>
))}
</tr>
{/* Add model headers */}
<tr>
<th className="sticky left-0 top-12 bg-base-100 z-30 border border-gray-700"></th>
{table
.getHeaderGroups()[0]
.headers.slice(1)
.map((header) => (
<th
key={header.id}
className="sticky top-12 bg-base-100 z-10 text-center text-xs border border-gray-700"
>
{header.isPlaceholder
? null
: flexRender(header.column.columnDef.header, header.getContext())}
</th>
))}
</tr>
</thead>
<tbody>
{table.getRowModel().rows.map((row) => (
<tr
key={row.id}
className={`${
row.original.type === 'group'
? 'bg-base-200 hover:bg-base-300'
: 'bg-base-100 hover:bg-base-200'
}`}
>
{row.getVisibleCells().map((cell) => (
<td
key={cell.id}
className={`${
cell.column.id === 'category'
? `sticky left-0 ${row.original.type === 'group' ? 'bg-base-200' : 'bg-base-100'} z-10`
: 'font-medium text-center'
} border-gray-700 border`}
>
{flexRender(cell.column.columnDef.cell, cell.getContext())}
</td>
))}
</tr>
))}
</tbody>
</table>
</div>
</>
)}
</div>
)}
</div>
)
}
export default LeaderboardTable