import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os
import numpy as np
import warnings
# Silence the CLIP package's PyTorch-version warning; the filter is set before `import clip` below.
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended")
import clip
join = os.path.join


class SimTests:
    """CLIP similarity, directional CLIP similarity, and mean L1 distance
    between an input image and an output image."""

    def __init__(self, clip_model_name="ViT-B/16", ac_scale=336, device="cuda:0"):
        # Note: ac_scale is accepted but not used below.
        self.device = device
        self.clip_size = 224  # input resolution expected by ViT-B/16
        self.clip_model_name = clip_model_name
        # Load CLIP once, in eval mode, with gradients disabled.
        self.clip_model = (
            clip.load(self.clip_model_name, device=self.device, jit=False)[0]
            .eval()
            .requires_grad_(False)
        )

    def read_image(self, img_path, dest_size=None):
        """Load an RGB image as a (1, 3, H, W) float tensor in [0, 1] on self.device."""
        image = Image.open(img_path).convert("RGB")
        if (dest_size is not None) and (image.size != dest_size):
            image = image.resize(dest_size, Image.LANCZOS)
        image = np.array(image)
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with batch dim
        image = torch.from_numpy(image).to(self.device)
        return image

    def assert_properties(self, image_in, image_out):
        for image in (image_in, image_out):
            assert torch.is_tensor(image)
            assert image.min() >= 0 and image.max() <= 1

    # CLIP similarity and directional CLIP similarity (higher is better)
    @torch.no_grad()
    def clip_sim(self, image_in, image_out, text_in, text_out):
        """Calculates:
        1. CLIP similarity between the output image and the output caption
        2. Directional CLIP similarity between the change from the input image to the
           output image and the change from the input caption to the output caption
        """
        self.assert_properties(image_in, image_out)
        text_in = [text_in]
        text_out = [text_out]
        image_features_in = self.encode_image(image_in)
        image_features_out = self.encode_image(image_out)
        text_features_in = self.encode_text(text_in)
        text_features_out = self.encode_text(text_out)
        sim_out = F.cosine_similarity(image_features_out, text_features_out)
        sim_direction = F.cosine_similarity(
            image_features_out - image_features_in, text_features_out - text_features_in
        )
        return sim_out, sim_direction

    def encode_text(self, prompt):
        """Tokenize a list/tuple of prompts and return L2-normalized CLIP text features."""
        assert isinstance(prompt, (list, tuple))
        text = clip.tokenize(prompt).to(self.device)
        text_features = self.clip_model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image):
        """Resize and normalize an image tensor, then return L2-normalized CLIP image features."""
        image_transform = transforms.Compose(
            [
                transforms.Resize(self.clip_size, interpolation=Image.BICUBIC),
                # CLIP's published channel means and standard deviations
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        # clip.load() keeps the model in fp16 on CUDA, so cast the input to half.
        image = image_transform(image).half()
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    # L1 distance (lower is better)
    def L1_dist(self, image_in, image_out):
        """
        Mean L1 pixel distance between the input and output images, measuring how
        much the image as a whole has changed.
        """
        self.assert_properties(image_in, image_out)
        return F.l1_loss(image_in, image_out).unsqueeze(0)
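

if __name__ == "__main__":
    # Minimal usage sketch: the file paths and captions below are hypothetical
    # placeholders, and a CUDA device is assumed because encode_image casts
    # inputs to half precision.
    tester = SimTests()
    image_in = tester.read_image("input.png", dest_size=(512, 512))
    image_out = tester.read_image("output.png", dest_size=(512, 512))
    sim_out, sim_direction = tester.clip_sim(
        image_in, image_out, "a photo of a dog", "a photo of a dog wearing a hat"
    )
    l1 = tester.L1_dist(image_in, image_out)
    print(f"CLIP(image_out, text_out): {sim_out.item():.4f}")
    print(f"Directional CLIP similarity: {sim_direction.item():.4f}")
    print(f"Mean L1 distance: {l1.item():.4f}")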