import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os
import numpy as np
import warnings
# Suppress the torch-version warning that the clip package may emit at import time.
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended")
import clip
join = os.path.join
class SimTests:
    """CLIP-based similarity metrics (plus a pixel-space L1 distance) for evaluating image edits."""

    def __init__(self, clip_model_name="ViT-B/16", ac_scale=336, device="cuda:0"):
        self.device = device
        self.clip_size = 224
        self.clip_model_name = clip_model_name
        # Load CLIP once, in eval mode, with gradients disabled.
        self.clip_model = (
            clip.load(self.clip_model_name, device=self.device, jit=False)[0]
            .eval()
            .requires_grad_(False)
        )

    def read_image(self, img_path, dest_size=None):
        image = Image.open(img_path).convert("RGB")
        if (dest_size is not None) and (image.size != dest_size):
            image = image.resize(dest_size, Image.LANCZOS)
        # HWC uint8 -> NCHW float tensor in [0, 1], with a batch dimension.
        image = np.array(image)
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(self.device)
        return image

    def assert_properties(self, image_in, image_out):
        for image in (image_in, image_out):
            assert torch.is_tensor(image)
            assert image.min() >= 0 and image.max() <= 1

    # CLIP similarity and directional CLIP similarity (higher is better)
    @torch.no_grad()
    def clip_sim(self, image_in, image_out, text_in, text_out):
        """Calculates:
        1. CLIP similarity between the output image and the output caption
        2. Directional CLIP similarity between the change from input to output image
           and the change from input to output caption
        """
        self.assert_properties(image_in, image_out)
        text_in = [text_in]
        text_out = [text_out]
        image_features_in = self.encode_image(image_in)
        image_features_out = self.encode_image(image_out)
        text_features_in = self.encode_text(text_in)
        text_features_out = self.encode_text(text_out)
        sim_out = F.cosine_similarity(image_features_out, text_features_out)
        sim_direction = F.cosine_similarity(
            image_features_out - image_features_in, text_features_out - text_features_in
        )
        return sim_out, sim_direction

    def encode_text(self, prompt):
        assert type(prompt) in (list, tuple)
        text = clip.tokenize(prompt).to(self.device)
        text_features = self.clip_model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image):
        image_transform = transforms.Compose(
            [
                transforms.Resize(self.clip_size, interpolation=Image.BICUBIC),
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        # clip.load keeps fp16 weights on CUDA, so cast the input to half as well.
        image = image_transform(image).half()
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    # L1 distance (lower is better)
    def L1_dist(self, image_in, image_out):
        """
        Mean L1 pixel distance between the input and output images, measuring the
        amount of change over the entire image
        """
        self.assert_properties(image_in, image_out)
        return F.l1_loss(image_in, image_out).unsqueeze(0)
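

# Minimal usage sketch: the image paths and captions below are placeholders, and a
# CUDA device with the ViT-B/16 CLIP weights available is assumed.
if __name__ == "__main__":
    sim_tests = SimTests(device="cuda:0")
    # The two images are assumed to share the same resolution (required by L1_dist).
    image_in = sim_tests.read_image("input.png")
    image_out = sim_tests.read_image("output.png")
    sim_out, sim_direction = sim_tests.clip_sim(
        image_in,
        image_out,
        text_in="a photo of a dog",
        text_out="a photo of a dog wearing a hat",
    )
    l1 = sim_tests.L1_dist(image_in, image_out)
    print(f"CLIP similarity (output image vs. output caption): {sim_out.item():.4f}")
    print(f"Directional CLIP similarity: {sim_direction.item():.4f}")
    print(f"Mean L1 pixel distance: {l1.item():.4f}")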