Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.dist import all_gather, get_rank | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
class MoCoV3Head(BaseModule):
    """Head for MoCo v3 Pre-training.

    This head builds a predictor, which can be any registered neck component.
    It also implements latent contrastive loss between two forward features.

    Part of the code is modified from:
    `<https://github.com/facebookresearch/moco-v3/blob/main/moco/builder.py>`_.

    Args:
        predictor (dict): Config dict for module of predictor.
        loss (dict): Config dict for module of loss functions.
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Defaults to 1.0.
    """

    def __init__(self,
                 predictor: dict,
                 loss: dict,
                 temperature: float = 1.0) -> None:
        super().__init__()
        # Both sub-modules are resolved through the MODELS registry, so the
        # config dicts decide which concrete predictor / loss is used.
        self.predictor = MODELS.build(predictor)
        self.loss_module = MODELS.build(loss)
        self.temperature = temperature

    def loss(self, base_out: torch.Tensor,
             momentum_out: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        The base features are passed through the predictor, both feature
        sets are L2-normalized along ``dim=1``, and the momentum features
        from all ranks are concatenated to serve as contrastive targets.
        Temperature-scaled similarity logits and the matching labels are
        then handed to the configured loss module.

        Args:
            base_out (torch.Tensor): NxC features from base_encoder.
            momentum_out (torch.Tensor): NxC features from momentum_encoder.

        Returns:
            torch.Tensor: The loss tensor.
        """
        # predictor computation; the predictor is called with a list and its
        # first output is used
        pred = self.predictor([base_out])[0]

        # normalize
        pred = nn.functional.normalize(pred, dim=1)
        target = nn.functional.normalize(momentum_out, dim=1)

        # get negative samples: targets gathered from every rank are stacked
        # along the batch dimension
        target = torch.cat(all_gather(target), dim=0)

        # Einstein sum is more intuitive; logits[n, m] is the dot product of
        # local prediction n with gathered target m, scaled by temperature
        logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature

        # generate labels: sample i on this rank is matched with gathered
        # index ``i + batch_size * rank``.
        # NOTE(review): this assumes all_gather returns per-rank tensors in
        # rank order and that every rank has the same batch size -- confirm.
        batch_size = logits.shape[0]
        # Allocate the labels directly on the logits' device so no
        # host-to-device copy is needed on every step.
        labels = torch.arange(
            batch_size, dtype=torch.long,
            device=logits.device) + batch_size * get_rank()

        loss = self.loss_module(logits, labels)
        return loss