import torch

import matplotlib.pyplot as plt

import gradio as gr

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg processor and model (rd64-refined checkpoint).
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def visualize_segmentation(image, prompts, preds):
    # One panel for the input image plus one heatmap per prompt.
    fig, ax = plt.subplots(1, len(prompts) + 1, figsize=(3 * (len(prompts) + 1), 4))
    for a in ax.flatten():
        a.axis('off')
    ax[0].imshow(image)
    for i, prompt in enumerate(prompts):
        # Sigmoid turns the raw logits into a [0, 1] probability map.
        ax[i + 1].imshow(torch.sigmoid(preds[i][0]))
        ax[i + 1].text(0, -15, prompt)
    return fig

def segment(img, classes):
    # Comma-separated text prompts, one segmentation map per prompt.
    prompts = [p.strip() for p in classes.split(',')]

    # CLIPSeg expects one copy of the image per text prompt.
    inputs = processor(text=prompts, images=[img] * len(prompts),
                       padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # With a single prompt the logits come back squeezed to (H, W);
    # restore the batch dimension before adding the channel dimension.
    preds = outputs.logits
    if preds.dim() == 2:
        preds = preds.unsqueeze(0)
    preds = preds.unsqueeze(1)

    return visualize_segmentation(img, prompts, preds)

demo = gr.Interface(fn=segment, inputs=["image", "text"], outputs="plot")
demo.launch()
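
# Usage sketch (an assumption, not part of the original file): with gradio,
# torch, and transformers installed, `python app.py` starts the demo; open
# the printed local URL, upload an image, and enter comma-separated prompts
# such as "cat, blanket, remote" to see one segmentation heatmap per prompt.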