File size: 3,520 Bytes
6df18f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os
import numpy as np
import warnings
warnings.filterwarnings("ignore", message="PyTorch version 1.7.1 or higher is recommended")
import clip


# Short module-level alias for os.path.join (not referenced in the visible part of this file).
join = os.path.join


class SimTests:
    def __init__(self, clip_model_name="ViT-B/16", ac_scale=336, device="cuda:0"):
        self.device = device
        self.clip_size = 224
        self.clip_model_name = clip_model_name
        self.clip_model = (
            clip.load(self.clip_model_name, device=self.device, jit=False)[0]
            .eval()
            .requires_grad_(False)
        )

    def read_image(self, img_path, dest_size=None):
        image = Image.open(img_path).convert("RGB")
        if (dest_size is not None) and (image.size != dest_size):
            image = image.resize(dest_size, Image.LANCZOS)
        image = np.array(image)
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(self.device)
        return image

    def assert_properties(self, image_in, image_out):
        for image in (image_in, image_out):
            assert torch.is_tensor(image)
            assert image.min() >= 0 and image.max() <= 1

    # CLIP similarity and Directional CLIP similarity (higher is better)
    @torch.no_grad()
    def clip_sim(self, image_in, image_out, text_in, text_out):
        """Calculates:
        1. CLIP similarity between output image and output caption
        2. Directional CLIP similarity of the change between input and output images,
           with the change between input and output captions
        """
        self.assert_properties(image_in, image_out)
        text_in = [text_in]
        text_out = [text_out]
        image_features_in = self.encode_image(image_in)
        image_features_out = self.encode_image(image_out)
        text_features_in = self.encode_text(text_in)
        text_features_out = self.encode_text(text_out)
        sim_out = F.cosine_similarity(image_features_out, text_features_out)
        sim_direction = F.cosine_similarity(
            image_features_out - image_features_in, text_features_out - text_features_in
        )
        return sim_out, sim_direction

    def encode_text(self, prompt):
        assert type(prompt) in (list, tuple)
        text = clip.tokenize(prompt).to(self.device)
        text_features = self.clip_model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features

    def encode_image(self, image):
        image_transform = transforms.Compose(
            [
                transforms.Resize(self.clip_size, interpolation=Image.BICUBIC),
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        image = image_transform(image).half()
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    # L1 distance (lower is better)
    def L1_dist(self, image_in, image_out):
        """
        Mean L1 pixel distance between input and output images, to measure the
        amount of change in the entire image
        """
        self.assert_properties(image_in, image_out)
        return F.l1_loss(image_in, image_out).unsqueeze(0)