Spaces: Runtime error
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| from detectron2.config import configurable | |
| from detectron2.structures import ImageList, Instances, Boxes | |
| from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY | |
| from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN | |
@META_ARCH_REGISTRY.register()
class GRiT(GeneralizedRCNN):
    """GeneralizedRCNN variant registered as the ``GRiT`` meta-architecture.

    Differences from the parent that are visible in this class:

    * a proposal generator is mandatory (checked in ``__init__``);
    * during training every input of a batch must carry the same ``"task"``
      value, which is forwarded to ``self.roi_heads`` as ``targets_task``;
    * ``self.roi_heads`` is expected to return a 2-tuple in both training
      (``(proposals, losses)``) and inference (``(results, _)``).
    """

    @configurable
    def __init__(self, **kwargs):
        """Build the model from the keyword arguments GeneralizedRCNN accepts.

        ``@configurable`` additionally allows construction from a config
        object (``GRiT(cfg)``). In that case ``from_config`` turns the config
        into the keyword arguments.

        Raises:
            AssertionError: if no proposal generator was configured.
        """
        super().__init__(**kwargs)
        # Both ``inference`` and ``forward`` call the proposal generator
        # unconditionally, so it must exist.
        assert self.proposal_generator is not None

    @classmethod
    def from_config(cls, cfg):
        """Translate ``cfg`` into keyword arguments for ``__init__``.

        Currently identical to ``GeneralizedRCNN.from_config``. It is kept as
        an extension point for GRiT-specific options.
        """
        ret = super().from_config(cfg)
        return ret

    def inference(
        self,
        batched_inputs: List[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """Run the model in eval mode.

        Args:
            batched_inputs: per-image input dicts, as consumed by
                ``self.preprocess_image``.
            detected_instances: must be ``None``; precomputed detections are
                not supported.
            do_postprocess: if True, pass the ROI-head results through
                ``GRiT._postprocess`` together with ``batched_inputs`` and the
                preprocessed image sizes.

        Returns:
            The post-processed results, or the raw ROI-head results when
            ``do_postprocess`` is False.
        """
        assert not self.training
        assert detected_instances is None

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        # No ground truth is available at inference time.
        proposals, _ = self.proposal_generator(images, features, None)
        results, _ = self.roi_heads(features, proposals)

        if not do_postprocess:
            return results
        assert not torch.jit.is_scripting(), \
            "Scripting is not supported for postprocess."
        return GRiT._postprocess(results, batched_inputs, images.image_sizes)

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """Compute training losses, or delegate to ``inference`` in eval mode.

        In training mode every element of ``batched_inputs`` must provide an
        ``"instances"`` entry and a ``"task"`` entry. The instances are moved
        to ``self.device``. All elements must share the same ``"task"``,
        which is passed to the ROI heads as ``targets_task``.

        Returns:
            In eval mode, whatever ``self.inference`` returns. In training
            mode, a dict merging the ROI-head losses with the proposal losses;
            proposal-loss entries win on key collisions.

        Raises:
            AssertionError: if the batch mixes different ``"task"`` values.
        """
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

        # The ROI heads receive a single task per call, so the batch must be
        # homogeneous in its "task" field.
        targets_task = batched_inputs[0]["task"]
        assert all(targets_task == x["task"] for x in batched_inputs), \
            "All inputs in a training batch must share the same 'task'."

        features = self.backbone(images.tensor)
        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances)
        # Only the losses from the ROI heads are needed during training.
        _, roihead_textdecoder_losses = self.roi_heads(
            features, proposals, gt_instances, targets_task=targets_task)

        losses = {}
        losses.update(roihead_textdecoder_losses)
        losses.update(proposal_losses)
        return losses