Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,287 Bytes
69af19f e0af351 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f e0af351 69af19f 51488de 69af19f 51488de 69af19f 51488de 69af19f 60fe69b 69af19f 60fe69b 69af19f 60fe69b 69af19f 60fe69b 69af19f 51488de e0af351 69af19f e0af351 69af19f e0af351 60fe69b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
#!/usr/bin/env python
import pathlib
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import visualize_object_predictions
from transformers import AutoImageProcessor, DetaForObjectDetection
# Markdown heading shown at the top of the demo page.
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"

# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_ID = "jozhang97/deta-swin-large"

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Processor and model are loaded once at import time and shared by every call to `run`.
image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)
@spaces.GPU
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image file and return the image with detections drawn.

    Args:
        image_path: Path of the image file to analyse.
        threshold: Score threshold forwarded to the processor's
            ``post_process_object_detection`` step.

    Returns:
        The ``"image"`` entry of the ``visualize_object_predictions`` result
        for the RGB image and the collected predictions.
    """
    # Open inside a context manager so the file handle is released promptly,
    # and normalise to RGB: grayscale, palette and RGBA files would otherwise
    # reach the processor and the visualizer with a channel count other than 3.
    with PIL.Image.open(image_path) as source:
        image = source.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL reports size as (width, height); reverse it to (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    # Round all boxes to integer pixel coordinates in one vectorised step.
    boxes = np.round(results["boxes"].cpu().numpy()).astype(int)
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    preds = [
        ObjectPrediction(
            bbox=box,
            category_id=cat_id,
            category_name=model.config.id2label[cat_id],
            score=score,
        )
        for box, score, cat_id in zip(boxes, scores, cat_ids, strict=True)
    ]
    return visualize_object_predictions(np.asarray(image), preds)["image"]
# Gradio UI: inputs and a run button, a result image, and clickable examples.
# NOTE(review): the nesting below was reconstructed from a whitespace-flattened
# copy of this file — confirm the result image belongs in the Row beside the
# input Column and that Examples / click sit directly under Blocks.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(
                label="Score threshold",
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.1,
            )
            run_button = gr.Button()
        result = gr.Image(label="Result")

    # One example row per JPEG in ./images, each paired with a 0.1 threshold.
    _example_rows = [[jpg, 0.1] for jpg in sorted(pathlib.Path("images").glob("*.jpg"))]
    gr.Examples(examples=_example_rows, inputs=[image, threshold], outputs=result, fn=run)

    # The button runs the same function and exposes it under the "predict" API name.
    run_button.click(fn=run, inputs=[image, threshold], outputs=result, api_name="predict")

if __name__ == "__main__":
    demo.launch()
|