Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class WeightedCrossEntropyLoss(nn.Module): | |
| def __init__(self, class_weights=None): | |
| """ | |
| Инициализация класса для кросс-энтропийной потери с возможностью взвешивания классов. | |
| :param class_weights: Вектор весов для классов (опционально) | |
| """ | |
| super(WeightedCrossEntropyLoss, self).__init__() | |
| self.class_weights = class_weights | |
| def forward(self, y_pred, y_true): | |
| """ | |
| Вычисление кросс-энтропийной потери с (или без) взвешиванием классов. | |
| :param y_true: Точные метки классов (вектор или одна метка) | |
| :param y_pred: Вероятностный вектор предсказаний | |
| :return: Значение потери | |
| """ | |
| y_true = y_true.to(torch.long) # Приводим метки к типу Long | |
| y_pred = y_pred.to(torch.float32) # Приводим предсказания к типу Float32 | |
| if self.class_weights is not None: | |
| class_weights = torch.tensor(self.class_weights).float().to(y_true.device) | |
| loss = F.cross_entropy(y_pred, y_true, weight=class_weights) | |
| else: | |
| loss = F.cross_entropy(y_pred, y_true) | |
| return loss |