import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os
import numpy as np
import warnings

# Suppress the torch-version warning that the CLIP package can emit on import
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended")
import clip

join = os.path.join

class SimTests:
    """CLIP similarity, directional CLIP similarity, and L1 distance metrics."""

    def __init__(self, clip_model_name="ViT-B/16", ac_scale=336, device="cuda:0"):
        self.device = device
        self.clip_size = 224
        self.clip_model_name = clip_model_name
        # ac_scale is currently unused
        # Load CLIP once, in eval mode, with gradients disabled
        self.clip_model = (
            clip.load(self.clip_model_name, device=self.device, jit=False)[0]
            .eval()
            .requires_grad_(False)
        )

    def read_image(self, img_path, dest_size=None):
        """Load an image as a float tensor in [0, 1] with shape (1, 3, H, W)."""
        image = Image.open(img_path).convert("RGB")
        if (dest_size is not None) and (image.size != dest_size):
            image = image.resize(dest_size, Image.LANCZOS)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
        image = torch.from_numpy(image).to(self.device)
        return image

    def assert_properties(self, image_in, image_out):
        for image in (image_in, image_out):
            assert torch.is_tensor(image)
            assert image.min() >= 0 and image.max() <= 1

    # CLIP similarity and directional CLIP similarity (higher is better)
    def clip_sim(self, image_in, image_out, text_in, text_out):
        """Calculates:
        1. CLIP similarity between the output image and the output caption
        2. Directional CLIP similarity: the cosine similarity between the change in
           image embeddings (output - input) and the change in text embeddings
           (output caption - input caption)
        """
        self.assert_properties(image_in, image_out)
        text_in = [text_in]
        text_out = [text_out]
        image_features_in = self.encode_image(image_in)
        image_features_out = self.encode_image(image_out)
        text_features_in = self.encode_text(text_in)
        text_features_out = self.encode_text(text_out)
        sim_out = F.cosine_similarity(image_features_out, text_features_out)
        sim_direction = F.cosine_similarity(
            image_features_out - image_features_in, text_features_out - text_features_in
        )
        return sim_out, sim_direction

    def encode_text(self, prompt):
        assert isinstance(prompt, (list, tuple))
        text = clip.tokenize(prompt).to(self.device)
        text_features = self.clip_model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image):
        # Resize with an int scales the shorter side to clip_size, so square inputs
        # are assumed here; then apply CLIP's normalization constants
        image_transform = transforms.Compose(
            [
                transforms.Resize(self.clip_size, interpolation=Image.BICUBIC),
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        # .half() matches the fp16 weights that clip.load uses on CUDA
        image = image_transform(image).half()
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    # L1 distance (lower is better)
    def L1_dist(self, image_in, image_out):
        """
        Mean L1 pixel distance between the input and output images, measuring the
        amount of change across the entire image
        """
        self.assert_properties(image_in, image_out)
        # unsqueeze so the result is a 1-element tensor, matching the clip_sim outputs
        return F.l1_loss(image_in, image_out).unsqueeze(0)
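

# Example usage sketch: the file names, sizes, and captions below are placeholders
# for illustration, not assets from this repo; a CUDA device is assumed, since
# clip.load puts the model in fp16 on GPU.
if __name__ == "__main__":
    metrics = SimTests(clip_model_name="ViT-B/16", device="cuda:0")
    image_in = metrics.read_image("input.png", dest_size=(512, 512))
    image_out = metrics.read_image("edited.png", dest_size=(512, 512))
    sim_out, sim_direction = metrics.clip_sim(
        image_in,
        image_out,
        text_in="a photo of a dog",
        text_out="a photo of a dog wearing a hat",
    )
    l1 = metrics.L1_dist(image_in, image_out)
    print(f"CLIP(output image, output text): {sim_out.item():.4f}")
    print(f"Directional CLIP similarity:     {sim_direction.item():.4f}")
    print(f"Mean L1 pixel distance:          {l1.item():.4f}")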