onuralpszr's picture
feat: ✨ space library and text updates
cba65f8 verified
raw
history blame
2.58 kB
import os

import cv2
import gradio as gr
import numpy as np
import PIL.Image
# NOTE(review): `space` is not referenced anywhere in the code shown here,
# while the `@spaces.GPU` decorator needs `spaces`; this looks like a typo.
# Kept for now - remove once confirmed that nothing relies on it.
import space
import spaces
import supervision as sv
import torch
import transformers
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
# Hugging Face Hub checkpoint shared by the model and its processor.
model_id = "google/paligemma2-3b-pt-448"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lower-case alias kept for backward compatibility; it is not referenced in
# the code shown here.
device = DEVICE

# Pre- and post-processing (tokenisation, image preparation, decoding).
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Load the weights once at import time, switch to eval mode and move the
# model to the selected device.
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model = model.eval().to(DEVICE)

# supervision annotators, created once and reused for every request.
BOX_ANNOTATOR = sv.BoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
def _parse_class_names(class_names):
    """Split a comma-separated string into a list of clean class names.

    Whitespace around each name is stripped and empty entries are dropped,
    so ``"person, dog,"`` yields ``["person", "dog"]``. ``None`` or a blank
    string yields an empty list.
    """
    if not class_names:
        return []
    return [name.strip() for name in class_names.split(",") if name.strip()]


@spaces.GPU
def process_image(input_image, input_text, class_names):
    """Run PaliGemma on an image and draw the parsed detections on it.

    Args:
        input_image: PIL image to analyse.
        input_text: Prompt passed to the processor together with the image.
        class_names: Comma-separated class names handed to
            ``sv.Detections.from_lmm`` as ``classes``.

    Returns:
        A ``(annotated_image, result)`` tuple: the annotated picture as a PIL
        image and the decoded text generated by the model.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        # Fail with a readable message instead of an opaque numpy/OpenCV error.
        raise gr.Error("Please provide an input image.")

    class_list = _parse_class_names(class_names)

    # The colour conversion below expects exactly three RGB channels.
    rgb_image = input_image.convert("RGB")
    # The annotated array is converted back from BGR to RGB before returning.
    scene = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

    # Cast floating-point inputs to the dtype the model was actually loaded
    # with rather than a hard-coded one, so inputs and weights always agree.
    model_inputs = (
        processor(text=input_text, images=rgb_image, return_tensors="pt")
        .to(model.dtype)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **model_inputs, max_new_tokens=100, do_sample=False
        )
    # Drop the prompt tokens; keep only what the model generated.
    generation = generation[0][input_len:]
    result = processor.decode(generation, skip_special_tokens=True)

    detections = sv.Detections.from_lmm(
        sv.LMM.PALIGEMMA,
        result,
        resolution_wh=(rgb_image.width, rgb_image.height),
        classes=class_list,
    )

    # Apply the annotators in the original order: boxes, labels, masks.
    for annotator in (BOX_ANNOTATOR, LABEL_ANNOTATOR, MASK_ANNOTATOR):
        scene = annotator.annotate(scene=scene, detections=detections)

    annotated_image = Image.fromarray(cv2.cvtColor(scene, cv2.COLOR_BGR2RGB))
    return annotated_image, result
# Gradio UI: the three inputs are passed positionally to process_image
# (image, prompt, class names) and its two return values fill the outputs.
app = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(
            lines=2,
            placeholder="Enter text here...",
            label="Enter prompt for example 'detect person;dog'",
        ),
        gr.Textbox(
            lines=1,
            placeholder="Enter class names separated by commas...",
            label="Class Names",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Annotated Image"),
        gr.Textbox(label="Detection Result"),
    ],
    title="PaliGemma2 Image Detection with Supervision",
    description="Detect objects in an image using PaliGemma2 model.",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    app.launch()