import torch from torch import Tensor, nn import torch.nn.functional as F import open_clip from tqdm import tqdm import numpy as np from typing import Union, Tuple, List num_to_word = { "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine", "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen", "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine", "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine", "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine", "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine", "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine", "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine", "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine", "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine", "100": "one hundred" } prefixes = [ "", "A photo of", "A block of", "An image of", "A picture of", "There are", "The image contains", "The photo contains", "The picture contains", "The image shows", "The photo shows", "The picture shows", ] arabic_numeral = [True, False] compares = [ "more than", "greater than", "higher than", "larger than", "bigger than", "greater than or equal to", "at least", "no less than", "not less than", "not fewer than", "not lower than", "not smaller than", "not less than or equal to", "over", "above", "beyond", "exceeding", "surpassing", ] suffixes = [ "people", "persons", "individuals", "humans", "faces", "heads", "figures", "", ] def num2word(num: Union[int, str]) -> str: """ Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc. """ num = str(int(num)) return num_to_word.get(num, num) def format_count( bins: List[Union[float, Tuple[float, float]]], ) -> List[List[str]]: text_prompts = [] for prefix in prefixes: for numeral in arabic_numeral: for compare in compares: for suffix in suffixes: prompts = [] for bin in bins: if isinstance(bin, (int, float)): # count is a single number count = int(bin) if count == 0 or count == 1: count = num2word(count) if not numeral else count prefix_ = "There is" if prefix == "There are" else prefix suffix_ = "person" if suffix == "people" else suffix[:-1] prompt = f"{prefix_} {count} {suffix_}" else: # count > 1 count = num2word(count) if not numeral else count prompt = f"{prefix} {count} {suffix}" elif bin[1] == float("inf"): # count is (lower_bound, inf) count = int(bin[0]) count = num2word(count) if not numeral else count prompt = f"{prefix} {compare} {count} {suffix}" else: # bin is (lower_bound, upper_bound) left, right = int(bin[0]), int(bin[1]) left, right = num2word(left) if not numeral else left, num2word(right) if not numeral else right prompt = f"{prefix} between {left} and {right} {suffix}" # Remove starting and trailing whitespaces prompt = prompt.strip() + "." prompts.append(prompt) text_prompts.append(prompts) return text_prompts def encode_text( model_name: str, weight_name: str, text: List[str] ) -> Tensor: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") text = open_clip.get_tokenizer(model_name)(text).to(device) model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device) model.eval() with torch.no_grad(): text_feats = model.encode_text(text) text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu() return text_feats def optimize_text_prompts( model_name: str, weight_name: str, flat_bins: List[Union[float, Tuple[float, float]]], batch_size: int = 1024, ) -> List[str]: text_prompts = format_count(flat_bins) # Find the template that has the smallest average similarity of bin prompts. print("Finding the best setup for text prompts...") text_prompts_ = [prompt for prompts in text_prompts for prompt in prompts] # flatten the list text_feats = [] for i in tqdm(range(0, len(text_prompts_), batch_size)): text_feats.append(encode_text(model_name, weight_name, text_prompts_[i: min(i + batch_size, len(text_prompts_))])) text_feats = torch.cat(text_feats, dim=0) sims = [] for idx, prompts in enumerate(text_prompts): text_feats_ = text_feats[idx * len(prompts): (idx + 1) * len(prompts)] sim = torch.mm(text_feats_, text_feats_.T) sim = sim[~torch.eye(sim.shape[0], dtype=bool)].mean().item() sims.append(sim) optimal_prompts = text_prompts[np.argmin(sims)] sim = sims[np.argmin(sims)] print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})") return optimal_prompts