# ---------------------------------------------------------------
# This file includes code adapted from:
# https://github.com/jbdel/RadEval/blob/null-hypothesis/utils.py
# Original author: Justin Xu
# ---------------------------------------------------------------
import re
import nltk
import random
from typing import List, Dict, Tuple, Callable, Optional
from collections import defaultdict
nltk.download("punkt_tab", quiet=True)


def clean_numbered_list(text):
    """
    Clean a report if it's a numbered list by:
    1. Adding proper spacing between numbered items
    2. Removing the numbered list markers
    3. Adding spaces after periods between sentences

    Returns the cleaned text split into sentences via nltk.sent_tokenize.
    """
    # First, separate numbered items that are stuck together without spaces
    # Example: "text1.2. text2" -> "text1. 2. text2"
    text = re.sub(r'\.(\d+\.)', r'. \1', text)
    # Handle patterns where there's no period between numbered entries
    # Example: "1. item1 2. item2" -> "1. item1. 2. item2"
    text = re.sub(r'(\d+\.\s*[^.]+?)\s+(?=\d+\.)', r'\1. ', text)
    # Then remove the numbered list markers,
    # but avoid removing decimal numbers in measurements like "3.5 cm"
    text = re.sub(r'(?<!\d)\d+\.\s*', '', text)
    # Add spaces after periods between sentences if missing
    # Example: "sentence1.sentence2" -> "sentence1. sentence2"
    # But don't split decimal numbers like "3.5 cm"
    text = re.sub(r'\.([A-Za-z])', r'. \1', text)
    return nltk.sent_tokenize(text)
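
# Illustrative example (not part of the original file; the exact output
# assumes NLTK's punkt sentence tokenizer splits on the final period as usual):
#   clean_numbered_list("1. No acute findings.2. Stable cardiomegaly")
#   -> ['No acute findings.', 'Stable cardiomegaly']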


class PairedTest:
    """
    Paired significance testing for comparing radiology report generation systems.
    Supports paired approximate randomization (AR).
    """

    def __init__(self,
                 systems: Dict[str, List[str]],
                 metrics: Dict[str, Callable],
                 references: Optional[List[str]],
                 n_samples: int = 10000,
                 n_jobs: int = 1,
                 seed: int = 12345):
        """
        Args:
            systems: Dictionary mapping system names to their generated reports
            metrics: Dictionary mapping metric names to metric functions
            references: List of reference reports
            n_samples: Number of resampling trials (default: 10000)
            n_jobs: Number of parallel jobs (default: 1)
            seed: Random seed for reproducibility
        """
        self.systems = systems
        self.metrics = metrics
        self.references = references
        self.n_samples = n_samples
        self.n_jobs = n_jobs
        self.seed = seed
        random.seed(seed)

        if not systems:
            raise ValueError("At least one system is required")
        system_lengths = [len(outputs) for outputs in systems.values()]
        if len(set(system_lengths)) > 1:
            raise ValueError("All systems must have the same number of outputs")
        if references and len(references) != system_lengths[0]:
            raise ValueError("References must have same length as system outputs")
        self.n_instances = system_lengths[0]

    def __call__(self) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
        """
        Run the paired significance test.

        Returns:
            Tuple of (signatures, scores) where:
            - signatures: Dict mapping metric names to signature strings
            - scores: Dict mapping system names to metric scores and p-values
        """
        # Calculate baseline scores for all systems and metrics
        baseline_scores = self._calculate_baseline_scores()
        # Get baseline system (first system)
        baseline_name = list(self.systems.keys())[0]
        scores = {}
        signatures = {}
        # Calculate scores and p-values for each system
        for system_name in self.systems.keys():
            scores[system_name] = {}
            for metric_name in self.metrics.keys():
                score = baseline_scores[system_name][metric_name]
                scores[system_name][metric_name] = score
                if system_name != baseline_name:
                    p_value = self._calculate_p_value(
                        baseline_name, system_name, metric_name, baseline_scores
                    )
                    scores[system_name][f'{metric_name}_pvalue'] = p_value
        for metric_name in self.metrics.keys():
            signatures[metric_name] = f"{metric_name}|ar:{self.n_samples}|seed:{self.seed}"
        return signatures, scores

    def _calculate_baseline_scores(self) -> Dict[str, Dict[str, float]]:
        """Calculate baseline scores for all systems and metrics."""
        scores = defaultdict(dict)
        for system_name, outputs in self.systems.items():
            for metric_name, metric_func in self.metrics.items():
                if self.references:
                    score = metric_func(outputs, self.references)
                else:
                    score = metric_func(outputs)
                # Metrics may return a float, a dict, or a (score, ...) tuple/list;
                # normalize everything to a single float
                if isinstance(score, dict):
                    if 'score' in score:
                        scores[system_name][metric_name] = score['score']
                    else:
                        scores[system_name][metric_name] = list(score.values())[0]
                elif isinstance(score, (tuple, list)):
                    scores[system_name][metric_name] = score[0]
                else:
                    scores[system_name][metric_name] = score
        return scores

    def _calculate_p_value(self,
                           baseline_name: str,
                           system_name: str,
                           metric_name: str,
                           baseline_scores: Dict[str, Dict[str, float]]) -> float:
        """Calculate the p-value for a system vs. the baseline using the AR test."""
        baseline_outputs = self.systems[baseline_name]
        system_outputs = self.systems[system_name]
        metric_func = self.metrics[metric_name]
        baseline_score = baseline_scores[baseline_name][metric_name]
        system_score = baseline_scores[system_name][metric_name]
        original_delta = abs(system_score - baseline_score)
        return self._approximate_randomization_test(
            baseline_outputs, system_outputs, metric_func, original_delta
        )

    def _approximate_randomization_test(self,
                                        baseline_outputs: List[str],
                                        system_outputs: List[str],
                                        metric_func: Callable,
                                        original_delta: float) -> float:
        """
        Perform the AR test.

        For each trial, randomly swap outputs between systems and calculate
        the score difference. The p-value is the proportion of trials where
        the randomized delta >= the original delta.
        """
        count_greater = 0
        for _ in range(self.n_samples):
            randomized_baseline = []
            randomized_system = []
            for i in range(self.n_instances):
                if random.random() < 0.5:
                    # Don't swap
                    randomized_baseline.append(baseline_outputs[i])
                    randomized_system.append(system_outputs[i])
                else:
                    # Swap
                    randomized_baseline.append(system_outputs[i])
                    randomized_system.append(baseline_outputs[i])
            if self.references:
                rand_baseline_score = metric_func(randomized_baseline, self.references)
                rand_system_score = metric_func(randomized_system, self.references)
            else:
                rand_baseline_score = metric_func(randomized_baseline)
                rand_system_score = metric_func(randomized_system)
            # Normalize metric outputs to floats, as in _calculate_baseline_scores
            if isinstance(rand_baseline_score, dict):
                rand_baseline_score = rand_baseline_score.get('score', list(rand_baseline_score.values())[0])
            elif isinstance(rand_baseline_score, (tuple, list)):
                rand_baseline_score = rand_baseline_score[0]
            if isinstance(rand_system_score, dict):
                rand_system_score = rand_system_score.get('score', list(rand_system_score.values())[0])
            elif isinstance(rand_system_score, (tuple, list)):
                rand_system_score = rand_system_score[0]
            rand_delta = abs(rand_system_score - rand_baseline_score)
            if rand_delta >= original_delta:
                count_greater += 1
        return count_greater / self.n_samples
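
# Note: PairedTest can also be used directly, e.g.
#   signatures, scores = PairedTest(systems, metrics, references)()
# but compare_systems() below is the convenience entry point, since it also
# prints a formatted summary via print_significance_results().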


def print_significance_results(scores: Dict[str, Dict[str, float]],
                               signatures: Dict[str, str],
                               baseline_name: str,
                               significance_level: float = 0.05):
    """
    Print a formatted table of system scores and p-values.

    Args:
        scores: Dictionary of system scores and p-values
        signatures: Dictionary of metric signatures
        baseline_name: Name of the baseline system
        significance_level: Significance threshold (default: 0.05)
    """
    assert baseline_name in scores, f"Baseline system '{baseline_name}' not found in scores."
    metric_names = list(signatures.keys())
    system_names = list(scores.keys())

    print("=" * 80)
    print("PAIRED SIGNIFICANCE TEST RESULTS")
    print("=" * 80)
    header = f"{'System':<40}"
    for metric in metric_names:
        header += f"{metric:>15}"
    print(header)
    print("-" * len(header))

    baseline_row = f"{'Baseline: ' + baseline_name:<40}"
    for metric in metric_names:
        score = scores[baseline_name][metric]
        baseline_row += f"{score:>15.4f}"
    print(baseline_row)
    print("-" * len(header))

    for system_name in system_names:
        if system_name == baseline_name:
            continue
        system_row = f"{system_name:<40}"
        for metric in metric_names:
            score = scores[system_name].get(metric, 0.0)
            if isinstance(score, float):
                system_row += f"{score:>15.4f}"
            else:
                system_row += f"{str(score):>15}"
        print(system_row)

        # P-value row
        pvalue_row = " " * 40
        for metric in metric_names:
            pvalue_key = f"{metric}_pvalue"
            if pvalue_key in scores[system_name]:
                p_val = scores[system_name][pvalue_key]
                significance_marker = "*" if p_val < significance_level else ""
                pvalue_row += f"(p={p_val:.4f}){significance_marker:<2}".rjust(15)
            else:
                pvalue_row += " " * 15
        print(pvalue_row)
        print("-" * len(header))

    # Footer
    print(f"- Significance level: {significance_level}")
    print("- '*' indicates significant difference (p < significance level)")
    print("- Null hypothesis: systems are essentially the same")
    print("- Significant results suggest systems are meaningfully different\n")
    print("METRIC SIGNATURES:")
    for metric, signature in signatures.items():
        print(f"- {metric}: {signature}")


def compare_systems(systems: Dict[str, List[str]],
                    metrics: Dict[str, Callable],
                    references: Optional[List[str]] = None,
                    n_samples: int = 10000,
                    significance_level: float = 0.05,
                    seed: int = 12345,
                    print_results: bool = True) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
    """
    Compare systems with paired approximate randomization, treating the first
    entry in `systems` as the baseline.

    Args:
        systems: Dictionary mapping system names to their generated reports
        metrics: Dictionary mapping metric names to metric functions
        references: Optional list of reference reports
        n_samples: Number of resampling trials
        significance_level: Significance threshold for printing results
        seed: Random seed for reproducibility
        print_results: Whether to print formatted results

    Returns:
        Tuple of (signatures, scores)

    Example:
        ```python
        systems = {
            'baseline_model': baseline_reports,
            'new_model': new_model_reports,
            'other_model': other_model_reports
        }
        metrics = {
            'bleu': lambda hyp, ref: bleu_score(hyp, ref),
            'rouge': lambda hyp, ref: rouge_score(hyp, ref),
            'bertscore': lambda hyp, ref: bert_score(hyp, ref),
            'custom_metric': lambda hyp, ref: custom_metric(hyp, ref)
        }
        signatures, scores = compare_systems(
            systems, metrics, references,
            n_samples=10000
        )
        ```
    """
    paired_test = PairedTest(
        systems=systems,
        metrics=metrics,
        references=references,
        n_samples=n_samples,
        seed=seed
    )
    signatures, scores = paired_test()

    if print_results:
        baseline_name = list(systems.keys())[0]
        print_significance_results(scores, signatures, baseline_name, significance_level)

    return signatures, scores
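

# ---------------------------------------------------------------
# Minimal illustrative sketch (not part of the original module):
# runs compare_systems() on dummy reports with a toy token-overlap
# metric, just to show the expected call pattern and printed output.
# The metric and data below are made up for demonstration only.
# ---------------------------------------------------------------
if __name__ == "__main__":
    def token_overlap(hypotheses, references):
        """Toy corpus-level metric: mean fraction of reference tokens found in the hypothesis."""
        per_report = []
        for hyp, ref in zip(hypotheses, references):
            ref_tokens = set(ref.lower().split())
            hyp_tokens = set(hyp.lower().split())
            per_report.append(len(ref_tokens & hyp_tokens) / max(len(ref_tokens), 1))
        return sum(per_report) / max(len(per_report), 1)

    demo_references = [
        "No acute cardiopulmonary abnormality.",
        "Stable cardiomegaly without pulmonary edema.",
        "Small right pleural effusion.",
    ]
    demo_systems = {
        "baseline_model": [
            "No acute cardiopulmonary abnormality.",
            "Cardiomegaly is stable.",
            "Right pleural effusion is small.",
        ],
        "new_model": [
            "No acute abnormality.",
            "Stable cardiomegaly without edema.",
            "Small right pleural effusion.",
        ],
    }
    compare_systems(
        demo_systems,
        {"token_overlap": token_overlap},
        demo_references,
        n_samples=100,  # small number of AR trials, just for the demo
    )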