File size: 7,127 Bytes
			
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# pyre-unsafe
from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import CfgNode
from detectron2.structures import Instances
from densepose.data.meshes.catalog import MeshCatalog
from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix
from .embed_utils import PackedCseAnnotations
from .mask import extract_data_for_mask_loss_from_matches
def _create_pixel_dist_matrix(grid_size: int) -> torch.Tensor:
    """
    Build the pairwise distance matrix between all pixels of a square grid

    Pixels are enumerated in row-major order: the pixel with flat index `i`
    has coordinates [row, col], where
        row = i // grid_size
        col = i % grid_size

    Args:
        grid_size (int): side length of the square pixel grid
    Return:
        The result of `squared_euclidean_distance_matrix` applied to the float
        pixel coordinates of shape [grid_size * grid_size, 2] with themselves
        (presumably a [grid_size^2, grid_size^2] matrix of squared distances
        between pixel pairs - confirm against the helper)
    """
    axis = torch.arange(grid_size)
    # The coordinate list is built directly instead of via `torch.meshgrid`,
    # which emits a warning in recent PyTorch versions when called without
    # an explicit `indexing` argument. The values are the same as for the
    # default ("ij") meshgrid flattened in row-major order.
    row_coords = axis.repeat_interleave(grid_size)  # 0, 0, ..., 1, 1, ...
    col_coords = axis.repeat(grid_size)  # 0, 1, ..., 0, 1, ...
    pix_coords = torch.stack((row_coords, col_coords), dim=-1).float()
    return squared_euclidean_distance_matrix(pix_coords, pix_coords)
def _sample_fg_pixels_randperm(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground pixel indices without replacement using a random permutation

    Args:
        fg_mask (Tensor): mask of arbitrary shape; nonzero entries are treated
            as foreground
        sample_size (int): maximum number of indices to return; a value <= 0
            disables sampling
    Return:
        1D tensor of indices into the flattened mask. All foreground indices
        (in increasing order) are returned if sampling is disabled or the number
        of foreground pixels does not exceed `sample_size`; otherwise
        `sample_size` distinct foreground indices in random order.
    """
    fg_pixel_indices = fg_mask.reshape((-1,)).nonzero(as_tuple=True)[0]
    # The foreground count is taken from the index list rather than from the
    # mask sum: for masks holding values other than 0 / 1 the sum overestimates
    # the count and the permutation would index past the end of the list.
    num_pixels = fg_pixel_indices.numel()
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_pixel_indices
    sample_indices = torch.randperm(num_pixels, device=fg_mask.device)[:sample_size]
    return fg_pixel_indices[sample_indices]
def _sample_fg_pixels_multinomial(fg_mask: torch.Tensor, sample_size: int) -> torch.Tensor:
    """
    Sample foreground pixel indices without replacement using `multinomial`

    Args:
        fg_mask (Tensor): mask of arbitrary shape; nonzero entries are treated
            as foreground
        sample_size (int): maximum number of indices to return; a value <= 0
            disables sampling
    Return:
        1D tensor of indices into the flattened mask. All foreground indices
        (in increasing order) are returned if sampling is disabled or the number
        of foreground pixels does not exceed `sample_size`; otherwise
        `sample_size` distinct foreground indices drawn with equal weights.
    """
    # Binarize first, so that both the foreground count and the sampling
    # weights are well defined for masks holding values other than 0 / 1
    # (for 0 / 1 masks this is identical to using the mask values directly).
    fg_flags = fg_mask.reshape((-1,)) != 0
    num_pixels = int(fg_flags.sum().item())
    if (sample_size <= 0) or (num_pixels <= sample_size):
        return fg_flags.nonzero(as_tuple=True)[0]
    return fg_flags.float().multinomial(sample_size, replacement=False)
class PixToShapeCycleLoss(nn.Module):
    """
    Cycle loss for pixel-vertex correspondence

    For every instance and every considered mesh, foreground pixels are sampled,
    soft pixel-to-vertex and vertex-to-pixel assignments are computed from
    embedding similarities, and the composed pixel-to-vertex-to-pixel assignment
    is weighted by the precomputed pixel-to-pixel distances and reduced with
    a p-norm. The final loss is the mean over all (instance, mesh) pairs.
    """

    def __init__(self, cfg: CfgNode):
        """
        Read loss settings from the config and precompute pixel distances

        Args:
            cfg (CfgNode): configuration; the values are read from
                `MODEL.ROI_DENSEPOSE_HEAD.CSE`, its `PIX_TO_SHAPE_CYCLE_LOSS`
                sub-node and `MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE`
        """
        super().__init__()
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        cse_cfg = head_cfg.CSE
        loss_cfg = cse_cfg.PIX_TO_SHAPE_CYCLE_LOSS
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        self.embed_size = cse_cfg.EMBED_SIZE
        self.norm_p = loss_cfg.NORM_P
        self.use_all_meshes_not_gt_only = loss_cfg.USE_ALL_MESHES_NOT_GT_ONLY
        self.num_pixels_to_sample = loss_cfg.NUM_PIXELS_TO_SAMPLE
        # NOTE: stored, but not read anywhere else in this class
        self.pix_sigma = loss_cfg.PIXEL_SIGMA
        self.temperature_pix_to_vertex = loss_cfg.TEMPERATURE_PIXEL_TO_VERTEX
        self.temperature_vertex_to_pix = loss_cfg.TEMPERATURE_VERTEX_TO_PIXEL
        # plain tensor attribute (not a registered buffer); it is moved to
        # the device of the embeddings on demand in `forward`
        self.pixel_dists = _create_pixel_dist_matrix(head_cfg.HEATMAP_SIZE)

    def forward(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        embedder: nn.Module,
    ):
        """
        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
                * coarse_segm - passed on to the GT mask extraction
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor: mean of the cycle losses over all (instance, mesh)
            pairs; a zero obtained from the pixel embeddings if there are no
            such pairs
        """
        pix_embeds = densepose_predictor_outputs.embedding
        if self.pixel_dists.device != pix_embeds.device:
            # should normally be done only once
            self.pixel_dists = self.pixel_dists.to(device=pix_embeds.device)
        with torch.no_grad():
            mask_loss_data = extract_data_for_mask_loss_from_matches(
                proposals_with_gt, densepose_predictor_outputs.coarse_segm
            )
        # GT masks - tensor of shape [N, S, S] of int64
        masks_gt = mask_loss_data.masks_gt.long()  # pyre-ignore[16]
        assert len(pix_embeds) == len(masks_gt), (
            f"Number of instances with embeddings {len(pix_embeds)} != "
            f"number of instances with GT masks {len(masks_gt)}"
        )
        if self.use_all_meshes_not_gt_only:
            mesh_names = self.shape_names
        else:
            mesh_names = [
                MeshCatalog.get_mesh_name(mesh_id.item())
                for mesh_id in packed_annotations.vertex_mesh_ids_gt.unique()
            ]
        # Vertex embeddings depend only on the mesh, not on the instance, so
        # they are computed at most once per mesh and reused for all instances
        # (assumes the embedder is deterministic within one forward pass).
        vertex_embeddings_by_mesh = {}
        losses = []
        for pixel_embeddings, mask_gt in zip(pix_embeds, masks_gt):
            # pixel_embeddings [D, S, S]
            # mask_gt [S, S]
            for mesh_name in mesh_names:
                if mesh_name not in vertex_embeddings_by_mesh:
                    vertex_embeddings_by_mesh[mesh_name] = embedder(mesh_name)
                mesh_vertex_embeddings = vertex_embeddings_by_mesh[mesh_name]
                # pixel indices [M]
                pixel_indices_flattened = _sample_fg_pixels_randperm(
                    mask_gt, self.num_pixels_to_sample
                )
                # pixel distances [M, M]: entry [i, j] is the distance between
                # sampled pixels i and j; broadcasting a column of indices
                # against a row of indices selects the same sub-matrix as
                # indexing with a meshgrid of the indices.
                # `self.pixel_dists` is already on the embeddings' device here.
                pixel_dists = self.pixel_dists[
                    pixel_indices_flattened[:, None], pixel_indices_flattened[None, :]
                ]
                # pixel embeddings [M, D]
                pixel_embeddings_sampled = normalize_embeddings(
                    pixel_embeddings.reshape((self.embed_size, -1))[:, pixel_indices_flattened].T
                )
                # pixel-vertex similarity [M, K]
                sim_matrix = pixel_embeddings_sampled.mm(mesh_vertex_embeddings.T)
                # soft assignment of every pixel to vertices [M, K]
                c_pix_vertex = F.softmax(sim_matrix / self.temperature_pix_to_vertex, dim=1)
                # soft assignment of every vertex to pixels [K, M]
                c_vertex_pix = F.softmax(sim_matrix.T / self.temperature_vertex_to_pix, dim=1)
                # pixel -> vertex -> pixel assignment [M, M]
                c_cycle = c_pix_vertex.mm(c_vertex_pix)
                loss_cycle = torch.norm(pixel_dists * c_cycle, p=self.norm_p)
                losses.append(loss_cycle)
        if not losses:
            # no (instance, mesh) pairs: zero that is still derived from the embeddings
            return pix_embeds.sum() * 0
        return torch.stack(losses, dim=0).mean()

    def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
        """
        Produce a zero-valued loss derived from the embedder and predictor outputs

        The value is the mean of zeros obtained from the vertex embeddings of
        every mesh in `embedder.mesh_names` and from the predicted pixel
        embeddings (presumably so that the corresponding parameters stay in
        the autograd graph when no real loss can be computed - confirm with
        the caller).

        Args:
            densepose_predictor_outputs: object with an `embedding` tensor attribute
            embedder (nn.Module): module that computes vertex embeddings for
                a given mesh name and exposes `mesh_names`
        Return:
            Scalar tensor equal to zero
        """
        losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        losses.append(densepose_predictor_outputs.embedding.sum() * 0)
        return torch.mean(torch.stack(losses))
 | 
 
			
