# CostalSegment/pipeline/ImgOutlier.py (uploaded by AveMujica, commit 40aaca9 "init")
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms as tr
from torchvision.models import vit_h_14
import cv2
class CosineSimilarity:
    """Flag test images whose cosine similarity to a reference vector is low.

    With ``vector='feature'`` images are compared through embeddings produced
    by a ViT-H/14 network whose final head layer has been removed. With any
    other value (conventionally ``'image'``) the flattened, normalised pixel
    tensors are compared directly and the network is never loaded.
    """

    def __init__(self, vector='feature', threshold=0.8, mean_vec=None, device=None):
        """
        Initialize the CosineSimilarity class.
        Args:
            vector (str): Type of vector to use ('feature' or 'image')
            threshold (float): Scores less than or equal to this value are
                reported as outliers
            mean_vec (sequence, numpy array or tensor): Preloaded reference
                vector for comparison. ``None`` or an empty sequence means the
                reference is computed from the reference images instead.
            device (str): Device to use for computation (default: 'mps' if
                available, else 'cuda' if available, else 'cpu')
        """
        if device is None:
            # Older torch builds have no ``torch.backends.mps`` attribute.
            mps_backend = getattr(torch.backends, 'mps', None)
            if mps_backend is not None and mps_backend.is_available():
                self.device = 'mps'
            elif torch.cuda.is_available():
                self.device = 'cuda'
            else:
                self.device = 'cpu'
        else:
            self.device = device
        self.vector = vector
        self.threshold = threshold
        # Built lazily by model() so that construction stays cheap.
        self.model_instance = None
        # A fresh list per instance: a ``[]`` default in the signature would be
        # shared by every instance created without an explicit mean_vec.
        self.mean_vec = [] if mean_vec is None else mean_vec

    def model(self):
        """Initialize (on first use) and return the ViT model."""
        if self.model_instance is None:
            wt = torchvision.models.ViT_H_14_Weights.DEFAULT
            net = vit_h_14(weights=wt)
            # Drop the last child of the head so the forward pass returns the
            # tensor that was fed to that layer instead of its output.
            net.heads = nn.Sequential(*list(net.heads.children())[:-1])
            net = net.to(self.device)
            # The network is only used for inference.
            net.eval()
            self.model_instance = net
        return self.model_instance

    def process_image(self, cv2_img):
        """
        Process a cv2 image for the model.
        Args:
            cv2_img: OpenCV image (BGR format)
        Returns:
            Processed tensor with a leading batch dimension of 1. In 'image'
            mode the pixels are flattened first; in 'feature' mode the tensor
            is moved to ``self.device``.
        """
        # Convert BGR to RGB, then to a PIL Image
        rgb_img = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
        pil_img = Image.fromarray(rgb_img)
        # NOTE(review): normalisation is applied before the resize; torchvision's
        # own preset for these weights appears to resize first - confirm whether
        # this ordering is intentional before changing it (it affects scores).
        transformations = tr.Compose([
            tr.ToTensor(),
            tr.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            tr.Resize((518, 518))
        ])
        img_tensor = transformations(pil_img).float()
        if self.vector == 'image':
            img_tensor = img_tensor.flatten()
        # Add the batch dimension
        img_tensor = img_tensor.unsqueeze_(0)
        if self.vector == 'feature':
            img_tensor = img_tensor.to(self.device)
        return img_tensor

    def _embed(self, images):
        """Return one CPU tensor per image according to the configured vector mode."""
        if self.vector != 'feature':
            # Pixel mode: the processed tensor itself is the embedding, so the
            # (very large) ViT weights are never loaded.
            return [self.process_image(img) for img in images]
        model = self.model()
        embeddings = []
        # No gradients are needed; skipping autograd bookkeeping saves memory.
        with torch.no_grad():
            for img in images:
                embeddings.append(model(self.process_image(img)).detach().cpu())
        return embeddings

    def get_embeddings(self, ref_images, test_images):
        """
        Get embeddings for reference and test images.
        Args:
            ref_images: List of cv2 reference images. Ignored when a non-empty
                ``mean_vec`` was supplied; must be non-empty otherwise.
            test_images: List of cv2 test images
        Returns:
            Reference embedding, list of test embeddings
        Raises:
            RuntimeError: from ``torch.stack`` when no ``mean_vec`` is loaded
                and ``ref_images`` is empty.
        """
        emb_test = self._embed(test_images)
        # If a reference vector is loaded, computing reference embeddings can
        # be skipped for efficiency.
        if self.mean_vec is not None and len(self.mean_vec) > 0:
            if isinstance(self.mean_vec, torch.Tensor):
                # Copy without the torch.tensor(tensor) warning, and keep the
                # reference on the CPU alongside the test embeddings.
                emb_ref = self.mean_vec.detach().clone().cpu()
            else:
                emb_ref = torch.tensor(self.mean_vec)
        else:
            # Average the per-image reference embeddings
            emb_ref = torch.mean(torch.stack(self._embed(ref_images)), dim=0)
        return emb_ref, emb_test

    def find_outliers(self, ref_images, test_images):
        """
        Find outliers in test images compared to reference images.
        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images
        Returns:
            mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image, rounded to 4 decimals
            emb_ref: The reference vector the test images were compared against
        """
        emb_ref, emb_test = self.get_embeddings(ref_images, test_images)
        scores = []
        mask = []
        for emb in emb_test:
            score_value = torch.nn.functional.cosine_similarity(emb_ref, emb).item()
            scores.append(round(score_value, 4))
            # True if it's an outlier (at or below threshold); the unrounded
            # score is used for the comparison.
            mask.append(score_value <= self.threshold)
        # Explicit dtype keeps the mask boolean even when there are no test images.
        return np.array(mask, dtype=bool), scores, emb_ref

    def filter_outliers(self, ref_images, test_images):
        """
        Filter out outliers from test images.
        Args:
            ref_images: List of cv2 reference images
            test_images: List of cv2 test images
        Returns:
            filtered_images: List of non-outlier test images
            outlier_mask: Boolean array where True indicates an outlier
            scores: Similarity scores for each test image
            mean: The reference vector the test images were compared against
        """
        outlier_mask, scores, mean = self.find_outliers(ref_images, test_images)
        # Keep only the images that were not flagged
        filtered_images = [
            img for img, is_outlier in zip(test_images, outlier_mask) if not is_outlier
        ]
        return filtered_images, outlier_mask, scores, mean
def detect_outliers(ref_imgs, imgs, mean_vec=None):
    """
    Detects outliers in a set of test images, can use a reference vector
    Args:
        ref_imgs: List of cv2 reference images
        imgs: List of cv2 test images
        mean_vec: optional pre-computed reference vector; None or an empty
            sequence means the reference is computed from ref_imgs
    Returns:
        filtered_images: List of non-outlier test images
        mean: the reference vector used (if a new reference vector should be saved)
    """
    # Pass an empty list rather than None so the "no reference vector" case
    # reaches CosineSimilarity in the form it checks with len().
    reference = [] if mean_vec is None else mean_vec
    similarity = CosineSimilarity(vector='feature', threshold=0.8, mean_vec=reference)
    # Get outlier mask and reference vector (the per-image scores are not needed here)
    outlier_mask, _, mean_vector = similarity.find_outliers(ref_imgs, imgs)
    # Keep only the images that were not flagged as outliers
    filtered_images = [img for img, is_outlier in zip(imgs, outlier_mask) if not is_outlier]
    return filtered_images, mean_vector