import numpy as np
import cv2
import torch
import gradio as gr
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# Load the CLIPSeg processor and segmentation model once at startup.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
def create_mask(image, image_mask, alpha=0.7):
    """Overlay a single-channel mask on an RGB image."""
    mask = np.zeros_like(image)
    # Copy the single-channel mask into all three color channels.
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    # Blend the mask with the original image.
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image
def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    """Rescale a (min_row, min_col, max_row, max_col) box from the model's
    square output resolution back to the original image resolution."""
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]
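# Illustrative example (values assumed, not from the app): a regionprops box
# (100, 150, 200, 250) on the 352x352 CLIPSeg output maps to a 1024x1024
# image like this:
#   rescale_bbox((100, 150, 200, 250), orig_image_shape=(1024, 1024), model_shape=352)
#   -> [290, 436, 581, 727]   # [y1, x1, y2, x2]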
def detect_using_clip(image, prompts=[], threshold=0.4):
    """Run CLIPSeg on the image once per prompt and convert each predicted
    mask into bounding boxes in the original image's coordinates."""
    model_detections = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    for i, prompt in enumerate(prompts):
        # Threshold the sigmoid of the logits into a binary mask.
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        predicted_image = np.where(predicted_image > threshold, 255, 0)
        # Extract connected components and their bounding boxes from the mask.
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
    return model_detections
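# The returned dictionary maps each prompt to a list of [y1, x1, y2, x2] boxes
# in original-image coordinates; a prompt with no region above the threshold
# maps to an empty list.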
def visualize_images(image, detections, prompt):
    """Draw the boxes detected for the selected prompt on a copy of the image."""
    H, W = image.shape[:2]
    image_copy = image.copy()
    if prompt not in detections.keys():
        print("Selected category was not among the detected prompts.")
        return image_copy
    for bbox in detections[prompt]:
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
        cv2.putText(image_copy, str(prompt), (int(bbox[1]), int(bbox[0])), cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
    return image_copy
def shot(image, labels_text, selected_category):
    """Gradio handler: parse the comma-separated prompts, run detection, and
    draw the boxes for the selected category."""
    prompts = labels_text.split(',')
    prompts = list(map(lambda x: x.strip(), prompts))
    model_detections = detect_using_clip(image, prompts=prompts)
    category_image = visualize_images(image=image, detections=model_detections, prompt=selected_category)
    return category_image
iface = gr.Interface(
    fn=shot,
    inputs=["image", "text", "text"],
    outputs="image",
    description="Upload an image and give a comma-separated list of categories to detect; boxes for the selected category are drawn on the output.",
    title="Zero-shot Object Detection with Text Prompts",
    examples=[
        ["images/room.jpg", "bed, table, plant, light, window", "plant"],
        ["images/image2.png", "banner, building, door, sign", "sign"],
    ],
    # allow_flagging=False,
    # analytics_enabled=False,
)
iface.launch()
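# Minimal local test sketch (not part of the original app; assumes the example
# image "images/room.jpg" from above is available on disk):
# if __name__ == "__main__":
#     img = cv2.cvtColor(cv2.imread("images/room.jpg"), cv2.COLOR_BGR2RGB)
#     out = shot(img, "bed, table, plant", "plant")
#     cv2.imwrite("detections.png", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))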