sujitpal committed
Commit 1768dbe
0 Parent(s):

Duplicate from sujitpal/clip-rsicd-demo


Co-authored-by: Sujit Pal <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff
.gitattributes ADDED
@@ -0,0 +1,16 @@
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,34 @@
+ ---
+ title: CLIP-RSICD Demo
+ emoji: 🛰️
+ colorFrom: green
+ colorTo: purple
+ sdk: streamlit
+ app_file: app.py
+ pinned: false
+ duplicated_from: sujitpal/clip-rsicd-demo
+ ---
+
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio` or `streamlit`
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,26 @@
+ import dashboard_text2image
+ import dashboard_image2image
+ import dashboard_featurefinder
+
+ import streamlit as st
+
+ PAGES = {
+     "Retrieve Images given Text": dashboard_text2image,
+     "Retrieve Images given Image": dashboard_image2image,
+     "Find Feature in Image": dashboard_featurefinder,
+ }
+
+ st.sidebar.title("CLIP-RSICD")
+ st.sidebar.image("thumbnail.jpg")
+ st.sidebar.markdown("""
+ We have fine-tuned the CLIP model (see [Model card](https://huggingface.co/flax-community/clip-rsicd-v2))
+ using remote sensing images and captions from the [RSICD dataset](https://github.com/201528014227051/RSICD_optimal).
+ The CLIP model from OpenAI is trained in a self-supervised manner using contrastive learning to project images
+ and caption text onto a common embedding space.
+
+ Please click here for [more information about our project](https://github.com/arampacha/CLIP-rsicd).
+
+ """)
+ selection = st.sidebar.radio("Go to", list(PAGES.keys()))
+ page = PAGES[selection]
+ page.app()
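
Each module listed in `PAGES` is expected to expose a parameterless `app()` function that renders its page when selected from the sidebar. A minimal sketch of a page module that would plug into this dispatch (a hypothetical `dashboard_example.py`, not part of this commit):

```python
# dashboard_example.py -- hypothetical page module, not part of this commit
import streamlit as st

def app():
    # Each dashboard module only needs to expose app(); app.py calls it
    # after the user picks the page from the sidebar radio control.
    st.title("Example Page")
    st.markdown("Page content goes here.")
```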
dashboard_featurefinder.py ADDED
@@ -0,0 +1,160 @@
+ import jax
+ import flax
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import os
+ import requests
+ import streamlit as st
+
+ from tempfile import NamedTemporaryFile
+ from torchvision.transforms import Compose, Resize, ToPILImage
+ from transformers import CLIPProcessor, FlaxCLIPModel
+ from PIL import Image
+
+ import utils
+
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+ MODEL_PATH = "flax-community/clip-rsicd-v2"
+
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+
+ IMAGES_DIR = "./images"
+ DEMO_IMAGES_DIR = "./demo-images"
+
+
+ def split_image(X):
+     # partition the image into non-overlapping 224x224 tiles, dropping any
+     # leftover pixels on the right and bottom edges
+     num_rows = X.shape[0] // 224
+     num_cols = X.shape[1] // 224
+     Xc = X[0 : num_rows * 224, 0 : num_cols * 224, :]
+     patches = []
+     for j in range(num_rows):
+         for i in range(num_cols):
+             patches.append(Xc[j * 224 : (j + 1) * 224,
+                               i * 224 : (i + 1) * 224,
+                               :])
+     return num_rows, num_cols, patches
+
+
+ def get_patch_probabilities(patches, searched_feature,
+                             image_preprocessor,
+                             model, processor):
+     # score every tile against the text prompt and softmax across tiles
+     images = [image_preprocessor(patch) for patch in patches]
+     text = "An aerial image of {:s}".format(searched_feature)
+     inputs = processor(images=images,
+                        text=text,
+                        return_tensors="jax",
+                        padding=True)
+     outputs = model(**inputs)
+     probs = jax.nn.softmax(outputs.logits_per_text, axis=-1)
+     probs_np = np.asarray(probs)[0]
+     return probs_np
+
+
+ def get_image_ranks(probs):
+     # rank of each tile, 0 = most probable
+     temp = np.argsort(-probs)
+     ranks = np.empty_like(temp)
+     ranks[temp] = np.arange(len(probs))
+     return ranks
+
+
+ def download_and_prepare_image(image_url):
+     """
+     Download the input image and resize/crop it to 672x896 (portrait)
+     or 896x672 (landscape).
+     """
+     try:
+         image_raw = requests.get(image_url, stream=True).raw
+         image = Image.open(image_raw).convert("RGB")
+         width, height = image.size
+         # print("WID,HGT:", width, height)
+         if width < 224 or height < 224:
+             return None
+         # take the short edge and reduce to 672
+         if width < height:
+             resize_factor = 672 / width
+             image = image.resize((672, int(height * resize_factor)))
+             image = image.crop((0, 0, 672, 896))
+         else:
+             resize_factor = 672 / height
+             image = image.resize((int(width * resize_factor), 896))
+             image = image.crop((0, 0, 896, 672))
+         return np.asarray(image)
+     except Exception as e:
+         # print(e)
+         return None
+
+
+ def app():
+     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+
+     st.title("Find Features in Images")
+     st.markdown("""
+ This demo shows the ability of the model to find specific features
+ (specified as text queries) in an image. As an example, say you wish to
+ find the parts of the following image that contain a `beach`, `houses`,
+ or `ships`. We partition the image into (224, 224) tiles and report
+ how likely each tile is to contain each text feature.
+ """)
+     st.image("demo-images/st_tropez_1.png")
+     st.image("demo-images/st_tropez_2.png")
+     st.markdown("""
+ For this image and the queries listed above, our model reports that the
+ two left tiles are most likely to contain a `beach`, the two top right
+ tiles are most likely to contain `houses`, and the two bottom right tiles
+ are likely to contain `boats`.
+
+ We have provided a few representative images from [Unsplash](https://unsplash.com/s/photos/aerial-view)
+ that you can experiment with. Use the image name as a hint for an initial
+ feature to look for; once the original image is displayed, you will get
+ more ideas for features you can ask the model to identify.
+ """)
+     image_file = st.selectbox(
+         "Sample Image File",
+         options=[
+             "-- select one --",
+             "St-Tropez-Port.jpg",
+             "Acopulco-Bay.jpg",
+             "Highway-through-Forest.jpg",
+             "Forest-with-River.jpg",
+             "Eagle-Bay-Coastline.jpg",
+             "Multistoreyed-Buildings.jpg",
+             "Street-View-Malayasia.jpg",
+         ])
+     image_url = st.text_input(
+         "OR provide an image URL",
+         value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")
+     searched_feature = st.text_input("Feature to find", value="beach")
+
+     if st.button("Find"):
+         if image_file.startswith("--"):
+             image = download_and_prepare_image(image_url)
+         else:
+             image = plt.imread(os.path.join("demo-images", image_file))
+
+         if image is None:
+             st.error("Image could not be downloaded, please try another one")
+         else:
+             st.image(image, caption="Input Image")
+             st.markdown("---")
+             image_preprocessor = Compose([
+                 ToPILImage(),
+                 Resize(224)
+             ])
+             num_rows, num_cols, patches = split_image(image)
+             patch_probs = get_patch_probabilities(
+                 patches,
+                 searched_feature,
+                 image_preprocessor,
+                 model,
+                 processor)
+             patch_ranks = get_image_ranks(patch_probs)
+             pid = 0
+             for i in range(num_rows):
+                 cols = st.columns(num_cols)
+                 for col in cols:
+                     caption = "#{:d} p({:s})={:.3f}".format(
+                         patch_ranks[pid] + 1, searched_feature, patch_probs[pid])
+                     col.image(patches[pid], caption=caption)
+                     pid += 1
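
The `utils` module referenced above falls outside this 50-file view. Based on how `load_model` is called here, it presumably caches a fine-tuned `FlaxCLIPModel` together with the baseline `CLIPProcessor`; a minimal sketch under that assumption (the actual implementation may differ):

```python
# Sketch of utils.load_model -- an assumption, since utils.py is not shown in this view.
import streamlit as st
from transformers import CLIPProcessor, FlaxCLIPModel

@st.cache(allow_output_mutation=True)
def load_model(model_path, baseline_model):
    # Fine-tuned weights come from model_path (e.g. flax-community/clip-rsicd-v2);
    # the processor (tokenizer + image transforms) is the unchanged baseline CLIP one.
    model = FlaxCLIPModel.from_pretrained(model_path)
    processor = CLIPProcessor.from_pretrained(baseline_model)
    return model, processor
```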
dashboard_image2image.py ADDED
@@ -0,0 +1,167 @@
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import os
+ import requests
+ import streamlit as st
+
+ from PIL import Image
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ import utils
+
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+ MODEL_PATH = "flax-community/clip-rsicd-v2"
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+ IMAGES_DIR = "./images"
+ CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")
+
+ @st.cache(allow_output_mutation=True)
+ def load_example_images():
+     # pick one random image from each of (up to) 10 image classes
+     example_images = {}
+     image_names = os.listdir(IMAGES_DIR)
+     for image_name in image_names:
+         if image_name.find("_") < 0:
+             continue
+         image_class = image_name.split("_")[0]
+         if image_class in example_images.keys():
+             example_images[image_class].append(image_name)
+         else:
+             example_images[image_class] = [image_name]
+     example_image_list = sorted([v[np.random.randint(0, len(v))]
+                                  for k, v in example_images.items()][0:10])
+     return example_image_list
+
+
+ def get_image_thumbnail(image_filename):
+     image = Image.open(os.path.join(IMAGES_DIR, image_filename))
+     image = image.resize((100, 100))
+     return image
+
+
+ def download_and_prepare_image(image_url):
+     # download the image, scale its short edge to 224, then center-crop to 224x224
+     try:
+         image_raw = requests.get(image_url, stream=True).raw
+         image = Image.open(image_raw).convert("RGB")
+         width, height = image.size
+         resize_mult = width / 224 if width < height else height / 224
+         image = image.resize((int(width // resize_mult),
+                               int(height // resize_mult)))
+         width, height = image.size
+         left = int((width - 224) // 2)
+         top = int((height - 224) // 2)
+         right = int((width + 224) // 2)
+         bottom = int((height + 224) // 2)
+         image = image.crop((left, top, right, bottom))
+         return image
+     except Exception as e:
+         return None
+
+ def app():
+     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
+     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+     image2caption = utils.load_captions(CAPTIONS_FILE)
+
+     example_image_list = load_example_images()
+
+     st.title("Retrieve Images given Images")
+     st.markdown("""
+ This demo shows the image to image retrieval capabilities of this model, i.e.,
+ given an image as the query, we use our fine-tuned CLIP model to project the
+ query image to the image/caption embedding space and search for nearby images
+ (by cosine similarity) in this space.
+
+ Our fine-tuned CLIP model was used ahead of time to generate the image vectors
+ for this demo, and NMSLib is used for fast nearest-neighbor search over them.
+
+ Here are some randomly selected images from our corpus; click the button below
+ an image to find images similar to it. Alternatively, you can provide the URL
+ of an image on the Internet.
+ """)
+
+     suggest_idx = -1
+     col0, col1, col2, col3, col4 = st.columns(5)
+     col0.image(get_image_thumbnail(example_image_list[0]))
+     col1.image(get_image_thumbnail(example_image_list[1]))
+     col2.image(get_image_thumbnail(example_image_list[2]))
+     col3.image(get_image_thumbnail(example_image_list[3]))
+     col4.image(get_image_thumbnail(example_image_list[4]))
+     col0t, col1t, col2t, col3t, col4t = st.columns(5)
+     with col0t:
+         if st.button("Image-1"):
+             suggest_idx = 0
+     with col1t:
+         if st.button("Image-2"):
+             suggest_idx = 1
+     with col2t:
+         if st.button("Image-3"):
+             suggest_idx = 2
+     with col3t:
+         if st.button("Image-4"):
+             suggest_idx = 3
+     with col4t:
+         if st.button("Image-5"):
+             suggest_idx = 4
+     col5, col6, col7, col8, col9 = st.columns(5)
+     col5.image(get_image_thumbnail(example_image_list[5]))
+     col6.image(get_image_thumbnail(example_image_list[6]))
+     col7.image(get_image_thumbnail(example_image_list[7]))
+     col8.image(get_image_thumbnail(example_image_list[8]))
+     col9.image(get_image_thumbnail(example_image_list[9]))
+     col5t, col6t, col7t, col8t, col9t = st.columns(5)
+     with col5t:
+         if st.button("Image-6"):
+             suggest_idx = 5
+     with col6t:
+         if st.button("Image-7"):
+             suggest_idx = 6
+     with col7t:
+         if st.button("Image-8"):
+             suggest_idx = 7
+     with col8t:
+         if st.button("Image-9"):
+             suggest_idx = 8
+     with col9t:
+         if st.button("Image-10"):
+             suggest_idx = 9
+
+     image_url = st.text_input(
+         "OR provide an image URL",
+         value="https://static.eos.com/wp-content/uploads/2019/04/Main.jpg")
+
+     submit_button = st.button("Find Similar")
+
+     if submit_button or suggest_idx > -1:
+         image_name = None
+         if suggest_idx > -1:
+             image_name = example_image_list[suggest_idx]
+             image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_name)))
+         else:
+             image = download_and_prepare_image(image_url)
+
+         if image is None:
+             st.error("Image could not be downloaded, please try another one!")
+         else:
+             st.image(image, caption="Input Image")
+             st.markdown("---")
+             inputs = processor(images=image, return_tensors="jax", padding=True)
+             query_vec = model.get_image_features(**inputs)
+             query_vec = np.asarray(query_vec)
+             # ask for one extra neighbor since the query image itself may be returned
+             ids, distances = index.knnQuery(query_vec, k=11)
+             result_filenames = [filenames[id] for id in ids]
+             rank = 0
+             for result_filename, score in zip(result_filenames, distances):
+                 if image_name is not None and result_filename == image_name:
+                     continue
+                 caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+                 col1, col2, col3 = st.columns([2, 10, 10])
+                 col1.markdown("{:d}.".format(rank + 1))
+                 col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                            caption=caption)
+                 caption_text = []
+                 for caption in image2caption[result_filename]:
+                     caption_text.append("* {:s}\n".format(caption))
+                 col3.markdown("".join(caption_text))
+                 rank += 1
+                 st.markdown("---")
+     suggest_idx = -1
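
Likewise, `utils.load_index` is not part of this view. Given the TSV format written by `demo-image-encoder.py` (one `filename<TAB>comma-separated-vector` row per image) and the `1.0 - score` display above, it plausibly builds an NMSLib index over cosine distance; a hedged sketch under those assumptions, not the actual implementation:

```python
# Sketch of utils.load_index -- an assumption, since utils.py is not shown in this view.
import nmslib
import numpy as np
import streamlit as st

@st.cache(allow_output_mutation=True)
def load_index(vector_file):
    # Each row of the TSV is "filename\tv1,v2,...,vN", as written by demo-image-encoder.py.
    filenames, vectors = [], []
    with open(vector_file, "r") as fvec:
        for line in fvec:
            filename, vec_s = line.strip().split("\t")
            filenames.append(filename)
            vectors.append([float(x) for x in vec_s.split(",")])
    # Cosine-similarity space: knnQuery then returns distance = 1 - cosine similarity,
    # which would match the dashboards displaying 1.0 - score.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(np.array(vectors))
    index.createIndex({"post": 2}, print_progress=False)
    return filenames, index
```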
dashboard_text2image.py ADDED
@@ -0,0 +1,88 @@
+ import matplotlib.pyplot as plt
+ import nmslib
+ import numpy as np
+ import os
+ import streamlit as st
+
+ from PIL import Image
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+ import utils
+
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+ MODEL_PATH = "flax-community/clip-rsicd-v2"
+ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
+ IMAGES_DIR = "./images"
+ CAPTIONS_FILE = os.path.join(IMAGES_DIR, "test-captions.json")
+
+ def app():
+     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
+     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+     image2caption = utils.load_captions(CAPTIONS_FILE)
+
+     st.title("Retrieve Images given Text")
+     st.markdown("""
+ This demo shows the text to image retrieval capabilities of this model, i.e.,
+ given a text query, we use our fine-tuned CLIP model to project the text query
+ to the image/caption embedding space and search for nearby images (by
+ cosine similarity) in this space.
+
+ Our fine-tuned CLIP model was used ahead of time to generate the image vectors
+ for this demo, and NMSLib is used for fast nearest-neighbor search over them.
+
+ """)
+     suggested_query = [
+         "ships",
+         "school house",
+         "military installation",
+         "mountains",
+         "beaches",
+         "airports",
+         "lakes"
+     ]
+     st.text("Some suggested queries to start you off with...")
+     col0, col1, col2, col3, col4, col5, col6 = st.columns(7)
+     # [1, 1.1, 1.3, 1.1, 1, 1, 1])
+     suggest_idx = -1
+     with col0:
+         if st.button(suggested_query[0]):
+             suggest_idx = 0
+     with col1:
+         if st.button(suggested_query[1]):
+             suggest_idx = 1
+     with col2:
+         if st.button(suggested_query[2]):
+             suggest_idx = 2
+     with col3:
+         if st.button(suggested_query[3]):
+             suggest_idx = 3
+     with col4:
+         if st.button(suggested_query[4]):
+             suggest_idx = 4
+     with col5:
+         if st.button(suggested_query[5]):
+             suggest_idx = 5
+     with col6:
+         if st.button(suggested_query[6]):
+             suggest_idx = 6
+     query = st.text_input("OR enter a text query:")
+     query = suggested_query[suggest_idx] if suggest_idx > -1 else query
+
+     if st.button("Query") or suggest_idx > -1:
+         inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
+         query_vec = model.get_text_features(**inputs)
+         query_vec = np.asarray(query_vec)
+         ids, distances = index.knnQuery(query_vec, k=10)
+         result_filenames = [filenames[id] for id in ids]
+         for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
+             # the index returns distances; 1.0 - distance is displayed as the similarity score
+             caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+             col1, col2, col3 = st.columns([2, 10, 10])
+             col1.markdown("{:d}.".format(rank + 1))
+             col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                        caption=caption)
+             caption_text = []
+             for caption in image2caption[result_filename]:
+                 caption_text.append("* {:s}\n".format(caption))
+             col3.markdown("".join(caption_text))
+             st.markdown("---")
+     suggest_idx = -1
demo-image-encoder.py ADDED
@@ -0,0 +1,69 @@
+ import argparse
+ import jax
+ import jax.numpy as jnp
+ import json
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import requests
+ import os
+
+ from PIL import Image
+ from transformers import CLIPProcessor, FlaxCLIPModel
+
+
+ def encode_image(image_file, model, processor):
+     image = Image.fromarray(plt.imread(os.path.join(IMAGES_DIR, image_file)))
+     inputs = processor(images=image, return_tensors="jax")
+     image_vec = model.get_image_features(**inputs)
+     return np.array(image_vec).reshape(-1)
+
+
+ DATA_DIR = "/home/shared/data"
+ IMAGES_DIR = os.path.join(DATA_DIR, "rsicd_images")
+ CAPTIONS_FILE = os.path.join(DATA_DIR, "dataset_rsicd.json")
+ VECTORS_DIR = os.path.join(DATA_DIR, "vectors")
+ BASELINE_MODEL = "openai/clip-vit-base-patch32"
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("model_dir", help="Path to model to use for encoding")
+ args = parser.parse_args()
+
+ print("Loading image list...", end="")
+ image2captions = {}
+ with open(CAPTIONS_FILE, "r") as fcap:
+     data = json.loads(fcap.read())
+     for image in data["images"]:
+         if image["split"] == "test":
+             filename = image["filename"]
+             sentences = []
+             for sentence in image["sentences"]:
+                 sentences.append(sentence["raw"])
+             image2captions[filename] = sentences
+
+ print("{:d} images".format(len(image2captions)))
+
+
+ print("Loading model...")
+ if args.model_dir == "baseline":
+     model = FlaxCLIPModel.from_pretrained(BASELINE_MODEL)
+ else:
+     model = FlaxCLIPModel.from_pretrained(args.model_dir)
+ processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
+
+
+ model_basename = "-".join(args.model_dir.split("/")[-2:])
+ vector_file = os.path.join(VECTORS_DIR, "test-{:s}.tsv".format(model_basename))
+ print("Writing vectors to {:s}".format(vector_file))
+ num_written = 0
+ fvec = open(vector_file, "w")
+ for image_file in image2captions.keys():
+     if num_written % 100 == 0:
+         print("{:d} images processed".format(num_written))
+     image_vec = encode_image(image_file, model, processor)
+     # one TSV row per image: filename <TAB> comma-separated vector components
+     image_vec_s = ",".join(["{:.7e}".format(x) for x in image_vec])
+     fvec.write("{:s}\t{:s}\n".format(image_file, image_vec_s))
+     num_written += 1
+
+ print("{:d} images processed, COMPLETE".format(num_written))
+ fvec.close()
+
demo-images/Acopulco-Bay.jpg ADDED
demo-images/Eagle-Bay-Coastline.jpg ADDED
demo-images/Forest-with-River.jpg ADDED
demo-images/Highway-through-Forest.jpg ADDED
demo-images/Multistoreyed-Buildings.jpg ADDED
demo-images/St-Tropez-Port.jpg ADDED
demo-images/Street-View-Malayasia.jpg ADDED
demo-images/st_tropez_1.png ADDED
demo-images/st_tropez_2.png ADDED
images/00623.jpg ADDED
images/00624.jpg ADDED
images/00625.jpg ADDED
images/00626.jpg ADDED
images/00627.jpg ADDED
images/00628.jpg ADDED
images/00629.jpg ADDED
images/00630.jpg ADDED
images/00631.jpg ADDED
images/00632.jpg ADDED
images/00633.jpg ADDED
images/00634.jpg ADDED
images/00635.jpg ADDED
images/00636.jpg ADDED
images/00637.jpg ADDED
images/00638.jpg ADDED
images/00639.jpg ADDED
images/00640.jpg ADDED
images/00641.jpg ADDED
images/00642.jpg ADDED
images/00643.jpg ADDED
images/00644.jpg ADDED
images/00646.jpg ADDED
images/00647.jpg ADDED
images/00650.jpg ADDED
images/00651.jpg ADDED
images/00652.jpg ADDED
images/00653.jpg ADDED
images/00655.jpg ADDED
images/00656.jpg ADDED
images/00657.jpg ADDED
images/00658.jpg ADDED
images/00659.jpg ADDED
images/00660.jpg ADDED