Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| from mmengine.model import BaseModule | |
| from mmpretrain.registry import MODELS | |
# NOTE(review): no ``@MODELS.register_module()`` decorator is visible on this
# class. If configs build it by ``type='CAEHead'``, confirm that it is
# registered elsewhere.
class CAEHead(BaseModule):
    """Head for CAE Pre-training.

    Compute the align loss and the main loss. In addition, this head
    generates the decoder's prediction target from the DALL-E logits.

    Args:
        loss (dict): The config of the loss module. It is built with
            ``MODELS.build``. The built module is called as
            ``loss_module(logits, target, latent_pred, latent_target)`` and
            must return a ``(loss_main, loss_align)`` pair.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

    def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target from DALL-E logits.

        The target is the ``argmax`` over dim 1. Every dimension after the
        batch dimension is then flattened into one.

        Args:
            logits_target (torch.Tensor): The logits generated by DALL-E.
                They must have at least 3 dims, because ``flatten(1)`` runs
                after dim 1 is reduced. Presumably the shape is
                ``(B, C, H, W)`` with the vocabulary on dim 1 (TODO: confirm
                against the caller).

        Returns:
            torch.Tensor: The argmax indices (int64), shaped ``(B, N)``.
        """
        target = torch.argmax(logits_target, dim=1)
        return target.flatten(1)

    def loss(self, logits: torch.Tensor, logits_target: torch.Tensor,
             latent_pred: torch.Tensor, latent_target: torch.Tensor,
             mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss.

        Args:
            logits (torch.Tensor): Logits generated by decoder.
            logits_target (torch.Tensor): Logits generated by DALL-E. The
                decoder's prediction target is derived from them.
            latent_pred (torch.Tensor): Latent prediction by regressor.
            latent_target (torch.Tensor): Target for latent prediction,
                generated by teacher.
            mask (torch.Tensor): Selects which flattened target positions
                are kept, via ``target[mask]``. Presumably this is a boolean
                tensor shaped like the flattened target (TODO: confirm). A
                non-bool integer tensor would be treated as indices, not as
                a mask.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The tuple of loss.

            - ``loss_main`` (torch.Tensor): Main loss for the decoder,
              documented as a cross entropy loss.
            - ``loss_align`` (torch.Tensor): Align loss for the regressor,
              documented as an MSE loss.
        """
        target = self._generate_target(logits_target)  # target features
        # Keep only the masked positions. The argmax ids carry no gradient,
        # so detach() is a safeguard rather than a necessity.
        target = target[mask].detach()

        # loss main for decoder, loss align for regressor
        loss_main, loss_align = self.loss_module(logits, target, latent_pred,
                                                 latent_target)
        return (loss_main, loss_align)