feat: add cosine similarity measure
Browse files
- app.py +7 -5
- lib/utils/model.py +0 -1
app.py
CHANGED
|
@@ -3,6 +3,7 @@ from lib.utils.model import get_model, get_similarities
|
|
| 3 |
from PIL import Image
|
| 4 |
|
| 5 |
st.title('IRRA Text-To-Image-Retrival')
|
|
|
|
| 6 |
|
| 7 |
st.header('Inputs')
|
| 8 |
caption = st.text_input('Description Input')
|
|
@@ -12,7 +13,7 @@ if images is not None:
|
|
| 12 |
st.image(images) # type: ignore
|
| 13 |
|
| 14 |
st.header('Options')
|
| 15 |
-
st.subheader('Ranks')
|
| 16 |
|
| 17 |
ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
|
| 18 |
|
|
@@ -26,15 +27,16 @@ if button:
|
|
| 26 |
st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
| 27 |
|
| 28 |
with st.spinner('Computing and ranking similarities'):
|
| 29 |
-
similarities = get_similarities(caption, images, model)
|
| 30 |
|
| 31 |
-
indices = similarities.argsort(descending=True).
|
| 32 |
|
| 33 |
for i, idx in enumerate(indices):
|
| 34 |
-
c1, c2 = st.columns(
|
| 35 |
with c1:
|
| 36 |
st.text(f'Rank {i + 1}')
|
| 37 |
with c2:
|
| 38 |
st.image(images[idx])
|
| 39 |
-
|
|
|
|
| 40 |
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
|
| 5 |
st.title('IRRA Text-To-Image-Retrival')
|
| 6 |
+
st.markdown('A text-to-image retrieval model implemented from [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
| 7 |
|
| 8 |
st.header('Inputs')
|
| 9 |
caption = st.text_input('Description Input')
|
|
|
|
| 13 |
st.image(images) # type: ignore
|
| 14 |
|
| 15 |
st.header('Options')
|
| 16 |
+
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
| 17 |
|
| 18 |
ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed',value=5)
|
| 19 |
|
|
|
|
| 27 |
st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
| 28 |
|
| 29 |
with st.spinner('Computing and ranking similarities'):
|
| 30 |
+
similarities = get_similarities(caption, images, model).squeeze(0)
|
| 31 |
|
| 32 |
+
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
| 33 |
|
| 34 |
for i, idx in enumerate(indices):
|
| 35 |
+
c1, c2, c3 = st.columns(3)
|
| 36 |
with c1:
|
| 37 |
st.text(f'Rank {i + 1}')
|
| 38 |
with c2:
|
| 39 |
st.image(images[idx])
|
| 40 |
+
with c3:
|
| 41 |
+
st.text(f'Cosine sim {similarities[idx].cpu():.2f}')
|
| 42 |
|
lib/utils/model.py
CHANGED
|
@@ -24,7 +24,6 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
|
| 24 |
txt = tokenize(text, tokenizer)
|
| 25 |
imgs = prepare_images(images)
|
| 26 |
|
| 27 |
-
print(imgs.shape)
|
| 28 |
image_feats = model.encode_image(imgs)
|
| 29 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
| 30 |
|
|
|
|
| 24 |
txt = tokenize(text, tokenizer)
|
| 25 |
imgs = prepare_images(images)
|
| 26 |
|
|
|
|
| 27 |
image_feats = model.encode_image(imgs)
|
| 28 |
text_feats = model.encode_text(txt.unsqueeze(0))
|
| 29 |
|