|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | from typing import List, Optional, Sequence, Tuple | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from detectron2.layers.nms import batched_nms | 
					
						
						|  | from detectron2.structures.instances import Instances | 
					
						
						|  |  | 
					
						
						|  | from densepose.converters import ToChartResultConverterWithConfidences | 
					
						
						|  | from densepose.structures import ( | 
					
						
						|  | DensePoseChartResultWithConfidences, | 
					
						
						|  | DensePoseEmbeddingPredictorOutput, | 
					
						
						|  | ) | 
					
						
						|  | from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer | 
					
						
						|  | from densepose.vis.densepose_outputs_vertex import DensePoseOutputsVertexVisualizer | 
					
						
						|  | from densepose.vis.densepose_results import DensePoseResultsVisualizer | 
					
						
						|  |  | 
					
						
						|  | from .base import CompoundVisualizer | 
					
						
						|  |  | 
					
						
						|  | Scores = Sequence[float] | 
					
						
						|  | DensePoseChartResultsWithConfidences = List[DensePoseChartResultWithConfidences] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def extract_scores_from_instances(instances: Instances, select=None): | 
					
						
						|  | if instances.has("scores"): | 
					
						
						|  | return instances.scores if select is None else instances.scores[select] | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def extract_boxes_xywh_from_instances(instances: Instances, select=None): | 
					
						
						|  | if instances.has("pred_boxes"): | 
					
						
						|  | boxes_xywh = instances.pred_boxes.tensor.clone() | 
					
						
						|  | boxes_xywh[:, 2] -= boxes_xywh[:, 0] | 
					
						
						|  | boxes_xywh[:, 3] -= boxes_xywh[:, 1] | 
					
						
						|  | return boxes_xywh if select is None else boxes_xywh[select] | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_extractor(visualizer: object): | 
					
						
						|  | """ | 
					
						
						|  | Create an extractor for the provided visualizer | 
					
						
						|  | """ | 
					
						
						|  | if isinstance(visualizer, CompoundVisualizer): | 
					
						
						|  | extractors = [create_extractor(v) for v in visualizer.visualizers] | 
					
						
						|  | return CompoundExtractor(extractors) | 
					
						
						|  | elif isinstance(visualizer, DensePoseResultsVisualizer): | 
					
						
						|  | return DensePoseResultExtractor() | 
					
						
						|  | elif isinstance(visualizer, ScoredBoundingBoxVisualizer): | 
					
						
						|  | return CompoundExtractor([extract_boxes_xywh_from_instances, extract_scores_from_instances]) | 
					
						
						|  | elif isinstance(visualizer, BoundingBoxVisualizer): | 
					
						
						|  | return extract_boxes_xywh_from_instances | 
					
						
						|  | elif isinstance(visualizer, DensePoseOutputsVertexVisualizer): | 
					
						
						|  | return DensePoseOutputsExtractor() | 
					
						
						|  | else: | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  | logger.error(f"Could not create extractor for {visualizer}") | 
					
						
						|  | return None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class BoundingBoxExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts bounding boxes from instances | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, instances: Instances): | 
					
						
						|  | boxes_xywh = extract_boxes_xywh_from_instances(instances) | 
					
						
						|  | return boxes_xywh | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ScoredBoundingBoxExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts bounding boxes from instances | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, instances: Instances, select=None): | 
					
						
						|  | scores = extract_scores_from_instances(instances) | 
					
						
						|  | boxes_xywh = extract_boxes_xywh_from_instances(instances) | 
					
						
						|  | if (scores is None) or (boxes_xywh is None): | 
					
						
						|  | return (boxes_xywh, scores) | 
					
						
						|  | if select is not None: | 
					
						
						|  | scores = scores[select] | 
					
						
						|  | boxes_xywh = boxes_xywh[select] | 
					
						
						|  | return (boxes_xywh, scores) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DensePoseResultExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts DensePose chart result with confidences from instances | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __call__( | 
					
						
						|  | self, instances: Instances, select=None | 
					
						
						|  | ) -> Tuple[Optional[DensePoseChartResultsWithConfidences], Optional[torch.Tensor]]: | 
					
						
						|  | if instances.has("pred_densepose") and instances.has("pred_boxes"): | 
					
						
						|  | dpout = instances.pred_densepose | 
					
						
						|  | boxes_xyxy = instances.pred_boxes | 
					
						
						|  | boxes_xywh = extract_boxes_xywh_from_instances(instances) | 
					
						
						|  | if select is not None: | 
					
						
						|  | dpout = dpout[select] | 
					
						
						|  | boxes_xyxy = boxes_xyxy[select] | 
					
						
						|  | converter = ToChartResultConverterWithConfidences() | 
					
						
						|  | results = [converter.convert(dpout[i], boxes_xyxy[[i]]) for i in range(len(dpout))] | 
					
						
						|  | return results, boxes_xywh | 
					
						
						|  | else: | 
					
						
						|  | return None, None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DensePoseOutputsExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts DensePose result from instances | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __call__( | 
					
						
						|  | self, | 
					
						
						|  | instances: Instances, | 
					
						
						|  | select=None, | 
					
						
						|  | ) -> Tuple[ | 
					
						
						|  | Optional[DensePoseEmbeddingPredictorOutput], Optional[torch.Tensor], Optional[List[int]] | 
					
						
						|  | ]: | 
					
						
						|  | if not (instances.has("pred_densepose") and instances.has("pred_boxes")): | 
					
						
						|  | return None, None, None | 
					
						
						|  |  | 
					
						
						|  | dpout = instances.pred_densepose | 
					
						
						|  | boxes_xyxy = instances.pred_boxes | 
					
						
						|  | boxes_xywh = extract_boxes_xywh_from_instances(instances) | 
					
						
						|  |  | 
					
						
						|  | if instances.has("pred_classes"): | 
					
						
						|  | classes = instances.pred_classes.tolist() | 
					
						
						|  | else: | 
					
						
						|  | classes = None | 
					
						
						|  |  | 
					
						
						|  | if select is not None: | 
					
						
						|  | dpout = dpout[select] | 
					
						
						|  | boxes_xyxy = boxes_xyxy[select] | 
					
						
						|  | if classes is not None: | 
					
						
						|  | classes = classes[select] | 
					
						
						|  |  | 
					
						
						|  | return dpout, boxes_xywh, classes | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CompoundExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts data for CompoundVisualizer | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, extractors): | 
					
						
						|  | self.extractors = extractors | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, instances: Instances, select=None): | 
					
						
						|  | datas = [] | 
					
						
						|  | for extractor in self.extractors: | 
					
						
						|  | data = extractor(instances, select) | 
					
						
						|  | datas.append(data) | 
					
						
						|  | return datas | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class NmsFilteredExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts data in the format accepted by NmsFilteredVisualizer | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, extractor, iou_threshold): | 
					
						
						|  | self.extractor = extractor | 
					
						
						|  | self.iou_threshold = iou_threshold | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, instances: Instances, select=None): | 
					
						
						|  | scores = extract_scores_from_instances(instances) | 
					
						
						|  | boxes_xywh = extract_boxes_xywh_from_instances(instances) | 
					
						
						|  | if boxes_xywh is None: | 
					
						
						|  | return None | 
					
						
						|  | select_local_idx = batched_nms( | 
					
						
						|  | boxes_xywh, | 
					
						
						|  | scores, | 
					
						
						|  | torch.zeros(len(scores), dtype=torch.int32), | 
					
						
						|  | iou_threshold=self.iou_threshold, | 
					
						
						|  | ).squeeze() | 
					
						
						|  | select_local = torch.zeros(len(boxes_xywh), dtype=torch.bool, device=boxes_xywh.device) | 
					
						
						|  | select_local[select_local_idx] = True | 
					
						
						|  | select = select_local if select is None else (select & select_local) | 
					
						
						|  | return self.extractor(instances, select=select) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ScoreThresholdedExtractor: | 
					
						
						|  | """ | 
					
						
						|  | Extracts data in the format accepted by ScoreThresholdedVisualizer | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, extractor, min_score): | 
					
						
						|  | self.extractor = extractor | 
					
						
						|  | self.min_score = min_score | 
					
						
						|  |  | 
					
						
						|  | def __call__(self, instances: Instances, select=None): | 
					
						
						|  | scores = extract_scores_from_instances(instances) | 
					
						
						|  | if scores is None: | 
					
						
						|  | return None | 
					
						
						|  | select_local = scores > self.min_score | 
					
						
						|  | select = select_local if select is None else (select & select_local) | 
					
						
						|  | data = self.extractor(instances, select=select) | 
					
						
						|  | return data | 
					
						
						|  |  |