|
import torch |
|
from torchmetrics import Metric |
|
from ocr import OCR |
|
import Levenshtein |
|
|
|
|
|
class CharacterErrorRate(Metric):
    """Character error rate (CER) between OCR transcriptions of image pairs.

    Each predicted image and its target image are transcribed with the same
    OCR engine. The metric accumulates the Levenshtein edit distance between
    the two transcriptions, and the total length of the target
    transcriptions. ``compute`` returns errors / characters.

    Args:
        ocr: Object exposing ``predict(image)``; presumably returns a ``str``
            transcription (NOTE(review): confirm against the OCR class).
        dist_sync_on_step: Forwarded to :class:`torchmetrics.Metric`
            (controls whether state is synced across processes on each
            ``forward`` step).
    """

    # Lower CER is better. The OCR + edit-distance pipeline is not
    # differentiable.
    higher_is_better = False
    is_differentiable = False
    # update() only adds to running sums, so torchmetrics does not need the
    # full global state to compute per-batch values in forward().
    full_state_update = False

    def __init__(self, ocr, dist_sync_on_step=False):
        # Forward the flag; it was previously accepted but silently ignored.
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Running sums; summed across processes when states are synced.
        self.add_state("total_errors", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum")

        self.ocr = ocr

    def update(self, pred_images, target_images):
        """Accumulate edit distance and target length for a batch of pairs.

        Args:
            pred_images: Iterable of predicted images passed to
                ``self.ocr.predict``.
            target_images: Iterable of reference images, paired positionally
                with ``pred_images``.

        Note:
            Pairing uses ``zip``, so extra items in the longer iterable are
            silently ignored.
        """
        for pred_img, target_img in zip(pred_images, target_images):
            pred_text = self.ocr.predict(pred_img)
            target_text = self.ocr.predict(target_img)

            # Edit distance between the two transcriptions (insertions,
            # deletions and substitutions).
            self.total_errors += Levenshtein.distance(pred_text, target_text)
            self.total_chars += len(target_text)

    def compute(self):
        """Return accumulated errors divided by accumulated target characters.

        Returns:
            A scalar tensor. When no target characters were seen, this is 0
            (even if errors were counted), created with the same device and
            dtype as the metric state.
        """
        if self.total_chars == 0:
            # Match the state's device/dtype instead of always creating a
            # CPU tensor.
            return torch.zeros_like(self.total_errors)

        return self.total_errors / self.total_chars
|
|