import torch
from torchmetrics import Metric

import Levenshtein

class WordErrorRate(Metric):
    """Word error rate (WER) between OCR transcriptions of predicted and target images."""

    def __init__(self, ocr, dist_sync_on_step=False):
        # Forward dist_sync_on_step to the base Metric so the flag actually takes effect.
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ocr = ocr
        self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_words", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, pred_images, target_images):
        for pred_img, target_img in zip(pred_images, target_images):
            # Transcribe both images with the injected OCR engine.
            pred_text = self.ocr.predict(pred_img)
            target_text = self.ocr.predict(target_img)

            pred_words = pred_text.strip().split()
            target_words = target_text.strip().split()

            # WER needs a word-level edit distance, but Levenshtein.distance operates on
            # characters. Map each unique word to a single placeholder character so the
            # distance between the mapped strings counts word-level insertions,
            # deletions, and substitutions.
            vocab = {word: chr(i) for i, word in enumerate(set(pred_words + target_words))}
            dist = Levenshtein.distance(
                "".join(vocab[w] for w in pred_words),
                "".join(vocab[w] for w in target_words),
            )

            self.total_errors += dist
            self.total_words += len(target_words)

    def compute(self):
        # Guard against division by zero when no reference words have been seen yet.
        if self.total_words == 0:
            return torch.tensor(0.0)
        return self.total_errors / self.total_words
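

# Usage sketch (illustrative, not part of the original module). It assumes only that
# the `ocr` object passed to WordErrorRate exposes a `predict(image) -> str` method,
# which is the sole interface the metric relies on. The stub OCR engine below is
# hypothetical and simply looks transcripts up by key to keep the example self-contained.
if __name__ == "__main__":
    class _StubOCR:
        """Toy OCR engine that 'reads' an image identifier by returning a stored string."""

        def __init__(self, transcripts):
            self.transcripts = transcripts

        def predict(self, image):
            return self.transcripts[image]

    ocr_engine = _StubOCR({
        "pred": "the quick brown fox",
        "target": "the quick brown fox jumps",
    })
    metric = WordErrorRate(ocr_engine)
    metric.update(["pred"], ["target"])  # one word missing out of five reference words
    print(metric.compute())              # tensor(0.2000)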