VTBench / evaluations /word_error_rate.py
huaweilin's picture
update
14ce5a9
import torch
from torchmetrics import Metric
import Levenshtein
class WordErrorRate(Metric):
def __init__(self, ocr, dist_sync_on_step=False):
# super().__init__(dist_sync_on_step=dist_sync_on_step)
super().__init__()
self.ocr = ocr
self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total_words", default=torch.tensor(0.0), dist_reduce_fx="sum")
def update(self, pred_images, target_images):
for pred_img, target_img in zip(pred_images, target_images):
pred_text = self.ocr.predict(pred_img)
target_text = self.ocr.predict(target_img)
pred_words = pred_text.strip().split()
target_words = target_text.strip().split()
dist = Levenshtein.distance(" ".join(pred_words), " ".join(target_words))
self.total_errors += dist
self.total_words += len(target_words)
def compute(self):
if self.total_words == 0:
return torch.tensor(0.0)
return self.total_errors / self.total_words