File size: 313 Bytes
9e15541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from abc import ABC, abstractmethod

import torch


class BaseLoss(ABC):
    """Abstract interface that every loss implementation derives from.

    A concrete subclass has to override both ``get_loss_metric_names`` and
    ``__call__``; until it does, instantiation fails with ``TypeError``
    (standard ``abc.ABC`` behaviour).
    """

    def __init__(self, config) -> None:
        """Initialise the base class.

        Args:
            config: Accepted for the benefit of subclasses; this base
                implementation neither reads nor stores it.
        """
        super().__init__()

    @abstractmethod
    def get_loss_metric_names(self) -> list[str]:
        """Return a list of names as strings.

        NOTE(review): presumably these are the keys of the dict returned by
        ``__call__`` -- confirm against the concrete subclasses.
        """

    @abstractmethod
    def __call__(self, data) -> dict[str, torch.Tensor]:
        """Evaluate the loss on ``data``.

        Args:
            data: Input handed to the loss; its structure is defined by the
                concrete subclass, not by this interface.

        Returns:
            A mapping from string keys to ``torch.Tensor`` values.
        """