# NOTE(review): recovered from a Hugging Face Spaces page whose status banner
# read "Runtime error"; the lines below are the Space's source, reformatted.
from typing import Optional

import fiftyone as fo
import torch
import torchvision
import typer
from groundingdino.util.inference import load_image, load_model, predict
from tqdm import tqdm
def main(
    image_directory: str = 'test_grounding_dino',
    text_prompt: str = 'bus, car',
    box_threshold: float = 0.15,
    text_threshold: float = 0.10,
    export_dataset: bool = False,
    view_dataset: bool = False,
    export_annotated_images: bool = True,
    weights_path: str = "groundingdino_swint_ogc.pth",
    config_path: str = "../../GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    subsample: Optional[int] = None,
):
    """Run Grounding DINO zero-shot object detection over a directory of images.

    Loads every image in ``image_directory`` into a FiftyOne dataset, runs
    Grounding DINO with the given text prompt, and stores the predictions on
    each sample under the ``"detections"`` field.

    Args:
        image_directory: directory of images to process.
        text_prompt: comma-separated object phrases to detect.
        box_threshold: minimum box confidence kept by ``predict``.
        text_threshold: minimum text-match confidence kept by ``predict``.
        export_dataset: if True, export predictions as a COCO dataset.
        view_dataset: if True, launch the FiftyOne app and block until closed.
        export_annotated_images: if True, render boxes onto copies of the images.
        weights_path: path to the Grounding DINO checkpoint.
        config_path: path to the matching model config file.
        subsample: if given and smaller than the dataset, process only a random
            subset of this many images.
    """
    model = load_model(config_path, weights_path)
    dataset = fo.Dataset.from_images_dir(image_directory)

    if subsample is not None and subsample < len(dataset):
        # take() returns a view; clone() materializes it as its own dataset
        dataset = dataset.take(subsample).clone()

    for sample in tqdm(dataset):
        image_source, image = load_image(sample.filepath)
        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=text_prompt,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )
        detections = []
        for box, logit, phrase in zip(boxes, logits, phrases):
            # predict() yields normalized (cx, cy, w, h) boxes; FiftyOne
            # expects relative [top-left-x, top-left-y, width, height]
            rel_box = torchvision.ops.box_convert(box, 'cxcywh', 'xywh')
            detections.append(
                fo.Detection(
                    label=phrase,
                    # convert tensors to plain Python values so FiftyOne
                    # can serialize them to its backing store
                    bounding_box=rel_box.tolist(),
                    confidence=float(logit),
                ))
        # Store detections in a field name of your choice
        sample["detections"] = fo.Detections(detections=detections)
        sample.save()

    # loads the voxel fiftyone UI ready for viewing the dataset.
    if view_dataset:
        session = fo.launch_app(dataset)
        session.wait()

    # exports COCO dataset ready for training
    if export_dataset:
        dataset.export(
            'coco_dataset',
            dataset_type=fo.types.COCODetectionDataset,
        )

    # saves bounding boxes plotted on the input images to disk
    if export_annotated_images:
        dataset.draw_labels(
            'images_with_bounding_boxes',
            label_fields=['detections']
        )
if __name__ == "__main__":
    # Expose main() as a CLI: Typer builds flags from its keyword defaults.
    typer.run(main)