Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,894 Bytes
a7dedf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import open_clip
from tqdm import tqdm
import numpy as np
from typing import Union, Tuple, List
# Lookup table from the decimal string of an integer ("0" .. "100") to its English
# spelling. Keys are strings, not ints; num2word() converts its input with
# str(int(...)) before looking it up here.
num_to_word = {
"0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
"10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen",
"20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine",
"30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine",
"40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine",
"50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine",
"60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine",
"70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine",
"80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine",
"90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine",
"100": "one hundred"
}
# Leading phrases tried by format_count(). The empty string means "no prefix".
# format_count() swaps "There are" for "There is" when the count is 0 or 1.
prefixes = [
"",
"A photo of", "A block of", "An image of", "A picture of",
"There are",
"The image contains", "The photo contains", "The picture contains",
"The image shows", "The photo shows", "The picture shows",
]
# Whether format_count() writes numbers as digits (True) or spells them out
# through num2word() (False). Both variants are generated.
arabic_numeral = [True, False]
# Comparison phrases used by format_count() only for open-ended bins, i.e. bins
# whose upper bound is float("inf").
# NOTE(review): the list mixes strict ("more than") and inclusive ("at least")
# wordings; whether the lower bound is meant to be inclusive is not visible here.
compares = [
"more than", "greater than", "higher than", "larger than", "bigger than", "greater than or equal to",
"at least", "no less than", "not less than", "not fewer than", "not lower than", "not smaller than", "not less than or equal to",
"over", "above", "beyond", "exceeding", "surpassing",
]
# Trailing nouns tried by format_count(). The empty string means "no suffix".
# For counts of 0 or 1, format_count() turns "people" into "person" and strips
# the last character of every other suffix (e.g. "faces" -> "face").
suffixes = [
"people", "persons", "individuals", "humans", "faces", "heads", "figures", "",
]
def num2word(num: Union[int, str]) -> str:
    """
    Spell out a number in English using the module-level ``num_to_word`` table.

    The input is first normalised with ``int``, so ``7`` and ``"07"`` both look up the
    key ``"7"``. A value that has no entry in the table is returned as its plain
    decimal string instead (for example ``101 -> "101"``).
    """
    key = str(int(num))
    if key in num_to_word:
        return num_to_word[key]
    # No spelling available: fall back to the digits themselves.
    return key
def _bin_phrase(bin_spec, prefix, use_digits, compare, suffix) -> str:
    """Render one bin as a sentence for a single (prefix, numeral, compare, suffix) template."""

    def render(value):
        # Digits stay an int (formatted by the f-string); otherwise spell the number out.
        number = int(value)
        return number if use_digits else num2word(number)

    if isinstance(bin_spec, (int, float)):
        # The bin is a single exact count.
        if int(bin_spec) in (0, 1):
            # Counts of 0 and 1 use the singular wording.
            lead = "There is" if prefix == "There are" else prefix
            noun = "person" if suffix == "people" else suffix[:-1]
        else:
            lead, noun = prefix, suffix
        phrase = f"{lead} {render(bin_spec)} {noun}"
    elif bin_spec[1] == float("inf"):
        # Open-ended bin (lower_bound, inf): the only case that uses `compare`.
        phrase = f"{prefix} {compare} {render(bin_spec[0])} {suffix}"
    else:
        # Closed range (lower_bound, upper_bound).
        low, high = render(bin_spec[0]), render(bin_spec[1])
        phrase = f"{prefix} between {low} and {high} {suffix}"

    # An empty prefix or suffix leaves stray spaces at the ends; trim before the full stop.
    return phrase.strip() + "."


def format_count(
    bins: List[Union[float, Tuple[float, float]]],
) -> List[List[str]]:
    """
    Build every candidate set of text prompts for the given count bins.

    One inner list is produced per template, where a template is a combination of
    an entry from ``prefixes``, ``arabic_numeral``, ``compares`` and ``suffixes``
    (iterated in that nesting order). Each inner list holds one prompt per bin,
    in the order of ``bins``.

    A bin is either a single number (exact count) or a pair
    ``(lower_bound, upper_bound)``; an upper bound of ``float("inf")`` marks an
    open-ended bin. Templates that differ only in their ``compare`` phrase give
    identical prompts unless ``bins`` contains an open-ended bin.
    """
    return [
        [_bin_phrase(bin_spec, prefix, numeral, compare, suffix) for bin_spec in bins]
        for prefix in prefixes
        for numeral in arabic_numeral
        for compare in compares
        for suffix in suffixes
    ]
def encode_text(
    model_name: str,
    weight_name: str,
    text: List[str]
) -> Tensor:
    """
    Encode text prompts with an open_clip model and return L2-normalised features.

    Args:
        model_name: Model identifier passed to ``open_clip.get_tokenizer`` and
            ``open_clip.create_model_from_pretrained``.
        weight_name: Pretrained-weights identifier passed to
            ``open_clip.create_model_from_pretrained``.
        text: The prompts to encode.

    Returns:
        The text features, normalised to unit L2 norm along the last dimension,
        detached and moved to the CPU.

    NOTE: the model is created on every call, so callers that encode many
    batches pay the loading cost each time.
    """
    # Prefer CUDA, then Apple MPS, then CPU. The MPS probe goes through
    # torch.backends.mps, and is guarded with getattr so that PyTorch builds
    # lacking that backend module fall back to the CPU instead of raising
    # AttributeError.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

    tokens = open_clip.get_tokenizer(model_name)(text).to(device)
    model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device)
    model.eval()

    with torch.no_grad():
        text_feats = model.encode_text(tokens)
        text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu()
    return text_feats
def optimize_text_prompts(
    model_name: str,
    weight_name: str,
    flat_bins: List[Union[float, Tuple[float, float]]],
    batch_size: int = 1024,
) -> List[str]:
    """
    Pick the prompt template whose per-bin prompts are least similar to each other.

    Every template produced by ``format_count(flat_bins)`` is encoded with
    ``encode_text``. For each template the mean pairwise dot product between
    the features of its (distinct) prompts is computed, and the template with
    the smallest mean is returned.

    Args:
        model_name: Forwarded to ``encode_text``.
        weight_name: Forwarded to ``encode_text``.
        flat_bins: The count bins, forwarded to ``format_count``.
        batch_size: Number of prompts handed to ``encode_text`` per call.

    Returns:
        The list of prompts (one per bin) of the selected template.
    """
    candidates = format_count(flat_bins)

    print("Finding the best setup for text prompts...")
    # Encode all prompts of all templates in one flat sequence, batch by batch.
    flat_prompts = [prompt for group in candidates for prompt in group]
    encoded_chunks = []
    for start in tqdm(range(0, len(flat_prompts), batch_size)):
        chunk = flat_prompts[start: start + batch_size]  # slicing clamps at the end of the list
        encoded_chunks.append(encode_text(model_name, weight_name, chunk))
    all_feats = torch.cat(encoded_chunks, dim=0)

    scores = []
    for idx, group in enumerate(candidates):
        # Rows of this template inside the flat feature matrix; this indexing
        # relies on every template having the same number of prompts.
        size = len(group)
        feats = all_feats[idx * size: (idx + 1) * size]
        gram = torch.mm(feats, feats.T)
        # Drop the diagonal (each prompt against itself) before averaging.
        off_diagonal = ~torch.eye(gram.shape[0], dtype=bool)
        scores.append(gram[off_diagonal].mean().item())

    best = np.argmin(scores)
    optimal_prompts = candidates[best]
    sim = scores[best]
    print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})")
    return optimal_prompts
|