Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Union | |
| import torch | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
class ContrastiveHead(BaseModule):
    """Head for contrastive learning.

    The contrastive loss is implemented in this head and is used in SimCLR,
    MoCo, DenseCL, etc.

    Args:
        loss (dict): Config dict for module of loss functions.
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Defaults to 0.1.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 temperature: float = 0.1,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # The loss module is built from its config through the registry.
        self.loss_module = MODELS.build(loss)
        self.temperature = temperature

    def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
        """Forward function to compute contrastive loss.

        The positive and negative similarities are concatenated along
        dim 1 (positive first), scaled by ``1 / temperature`` and passed to
        the configured loss module together with all-zero labels.

        Args:
            pos (torch.Tensor): Nx1 positive similarity.
            neg (torch.Tensor): Nxk negative similarity.

        Returns:
            torch.Tensor: The contrastive loss.
        """
        N = pos.size(0)
        # ``torch.cat`` returns a new tensor, so the in-place division below
        # does not modify ``pos`` or ``neg``.
        logits = torch.cat((pos, neg), dim=1)
        logits /= self.temperature
        # The positive similarity sits in column 0 of ``logits``, hence the
        # target index is 0 for every sample. The labels are created directly
        # on the device of ``pos`` instead of on the CPU followed by a copy.
        labels = torch.zeros((N, ), dtype=torch.long, device=pos.device)
        return self.loss_module(logits, labels)