segmentation / app.py
Alex
first commit
7077928
raw
history blame
1.67 kB
import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
import torch
# Carica il modello e il processore
model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model.to("cpu") # Assicurati che il modello giri su CPU
def segment_image(image):
# Converti l'immagine in formato compatibile
img = Image.fromarray(np.uint8(image)).convert("RGB")
# Prepara l'input per SAM (segmentazione automatica)
inputs = processor(img, return_tensors="pt").to("cpu")
# Inferenza
with torch.no_grad():
outputs = model(**inputs, multimask_output=False)
# Post-processa per ottenere la maschera
mask = processor.image_processor.post_process_masks(
outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
)[0][0].cpu().numpy()
# Converti la maschera in un'immagine visibile
mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# Annotazioni (es. bounding box o etichette semplici)
annotations = {"mask": mask.tolist(), "label": "object"} # Personalizza come necessario
return mask_img, str(annotations)
# Interfaccia Gradio
interface = gr.Interface(
fn=segment_image,
inputs=gr.Image(type="numpy", label="Carica un'immagine"),
outputs=[
gr.Image(type="pil", label="Maschera di segmentazione"),
gr.Textbox(label="Annotazioni per inpainting")
],
title="Segmentazione di vestiti e oggetti con SAM",
description="Carica un'immagine per ottenere la segmentazione e le annotazioni."
)
interface.launch()