# app.py — CLIPSeg text-prompted image segmentation demo (Gradio).
# Provenance (from the hosting page): author ancebuc, commit c6cb1c5
# ("Update app.py"), 1.13 kB.
from PIL import Image
import requests
import torch
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
# Single source of truth for the CLIPSeg checkpoint, so the processor and the
# model can never be loaded from mismatched weights.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Loaded once at import time and reused by every call to `segment`:
# the processor prepares text prompts and images, the model produces masks.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def visualize_segmentation(image, prompts, preds):
    """Plot *image* next to one sigmoid heat-map per prompt.

    Args:
        image: The input image; anything ``Axes.imshow`` accepts (e.g. a PIL
            image or a NumPy array).
        prompts: Sequence of prompt strings; one panel is drawn per prompt.
        preds: Per-prompt logits, indexed as ``preds[i][0]`` to obtain a 2-D
            map — presumably the ``(num_prompts, 1, H, W)`` tensor built in
            ``segment``; confirm against the caller.

    Returns:
        The Matplotlib ``Figure`` so callers can display, save, or close it
        (pyplot otherwise keeps every created figure alive).
    """
    n_panels = len(prompts) + 1
    # squeeze=False keeps `axes` a 2-D array even with a single panel (no
    # prompts); with the default, plt.subplots returns a bare Axes there and
    # indexing/flatten() would fail.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(3 * n_panels, 4), squeeze=False
    )
    row = axes[0]

    for ax in row:
        ax.axis("off")

    row[0].imshow(image)

    for ax, prompt, pred in zip(row[1:], prompts, preds):
        # detach()/cpu() so tensors that still track gradients, or live on a
        # GPU, can be turned into the NumPy array imshow needs.
        heatmap = torch.sigmoid(pred[0]).detach().cpu().numpy()
        ax.imshow(heatmap)
        # Label placed just above the panel, in data coordinates.
        ax.text(0, -15, prompt)

    return fig
def segment(img, clases):
    """Run CLIPSeg on *img* once per comma-separated class name in *clases*.

    Args:
        img: Image from the Gradio ``"image"`` input, which delivers a NumPy
            array by default (assumed uint8 H×W×C or H×W — confirm if the
            input component's ``type`` changes).
        clases: Comma-separated text prompts, e.g. ``"cat, dog, sky"``.

    Returns:
        A greeting string built from the first prompt. The per-prompt logits
        are computed but not yet shown in the UI.

    Raises:
        gr.Error: if no image was supplied or no non-empty class name was
            given, so Gradio shows a readable message instead of a traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    # Strip whitespace around each class and drop empty entries, so inputs
    # like " cat, dog, " become ["cat", "dog"].
    prompts = [p.strip() for p in (clases or "").split(",") if p.strip()]
    if not prompts:
        raise gr.Error("Please enter at least one class name.")

    # Let PIL infer the mode from the array, then normalize to RGB. This
    # handles grayscale and RGBA inputs instead of forcing the raw buffer to
    # be reinterpreted as 3-channel RGB.
    image = Image.fromarray(img).convert("RGB")

    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )

    # Inference only: no_grad avoids building an autograd graph per request
    # and yields tensors that can be plotted directly.
    with torch.no_grad():
        outputs = model(**inputs)

    # Normalize to (num_prompts, 1, H, W). Depending on the transformers
    # version, a single prompt may yield logits without a batch dimension,
    # in which case unsqueeze(1) would give the wrong layout.
    logits = outputs.logits
    preds = logits.reshape(len(prompts), 1, *logits.shape[-2:])
    # TODO: surface `preds` in the UI (e.g. via visualize_segmentation) once
    # the Interface output is no longer plain text.

    return "Hello " + prompts[0] + "!!"
# Web UI: takes an image and a comma-separated list of class names, shows text.
demo = gr.Interface(
    fn=segment,
    inputs=[
        # type="numpy" pins the array type `segment` expects (it calls
        # Image.fromarray on this value).
        gr.Image(type="numpy"),
        gr.Textbox(label="Classes (comma-separated)"),
    ],
    outputs="text",
)

# Start the (blocking) server only when run as a script, so importing this
# module does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()