blumenstiel committed on
Commit 9d4302e · 1 Parent(s): a0a9e81
Files changed (4)
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +248 -0
  4. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.tif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
 ---
 title: Llama3 MS CLIP Demo
-emoji: 🐨
-colorFrom: blue
-colorTo: yellow
+emoji: 🛰️
+colorFrom: cyan
+colorTo: blue
 sdk: gradio
 sdk_version: 5.38.2
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: Zero-Chot Classification with Llama3-MS-CLIP
+short_description: Zero-Shot Classification with Llama3-MS-CLIP
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,248 @@
+import base64
+from io import BytesIO
+from pathlib import Path
+
+import glob
+import numpy as np
+import gradio as gr
+import rasterio as rio
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from PIL import Image
+from matplotlib import rcParams
+from msclip.inference import run_inference_classification
+
+rcParams["font.size"] = 9
+IMG_PX = 300
+
+EXAMPLES = {
+    "EuroSAT": {
+        "images": glob.glob("examples/eurosat/*.tif"),
+        "classes": [
+            "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
+            "Pasture", "PermanentCrop", "Residential", "River", "SeaLake"
+        ]
+    },
+    "Meter-ML": {
+        "images": glob.glob("examples/meterml/*.tif"),
+        "classes": [
+            "Todo"
+        ]
+    },
+    "TerraMesh": {
+        "images": glob.glob("examples/terramesh/*.tif"),
+        "classes": [
+            "Agriculture", "Beach", "River", "Ice", "Fields"
+        ]
+    },
+}
+
+
+def load_eurosat_example():
+    return EXAMPLES["EuroSAT"]["images"], ",".join(EXAMPLES["EuroSAT"]["classes"])
+
+
+def load_meterml_example():
+    return EXAMPLES["Meter-ML"]["images"], ",".join(EXAMPLES["Meter-ML"]["classes"])
+
+
+def load_terramesh_example():
+    return EXAMPLES["TerraMesh"]["images"], ",".join(EXAMPLES["TerraMesh"]["classes"])
+
+
+pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
+
+
+def build_colormap(class_names):
+    return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}
+
+
+def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
+    """
+    Smooth very dark and very bright values and scale the array to roughly 0-1.
+
+    array: numpy array of Sentinel-2 reflectance values (applied elementwise)
+    """
+    # Get scaling thresholds for smoothing the brightness
+    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
+    limit_high = limit_high.clip(default)  # Scale only pixels above the default value
+    limit_low = limit_low.clip(0, 1000)  # Scale only pixels below 1000
+    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
+
+    # Smooth very dark and bright values using linear scaling
+    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
+    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
+
+    # Update the limits, using a tenth of the tolerance for the extremes
+    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
+    limit_high = limit_high.clip(default, 20000)  # Scale only pixels above the default value
+    limit_low = limit_low.clip(0, 500)  # Scale only pixels below 500
+    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
+
+    # Scale data to 0-1
+    array = (array - limit_low) / (limit_high - limit_low)
+
+    return array
+
+
+def _s2_to_rgb(data, smooth_quantiles=True):
+    # Select the RGB bands (B04, B03, B02)
+    if data.shape[0] > 13:
+        # assuming channel-last data
+        rgb = data[:, :, [3, 2, 1]]
+    else:
+        # assuming channel-first data
+        rgb = data[[3, 2, 1]].transpose((1, 2, 0))
+
+    if smooth_quantiles:
+        rgb = _rgb_smooth_quantiles(rgb)
+    else:
+        rgb = rgb / 2000
+
+    # Convert to uint8
+    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
+
+    return rgb
+
+
+def _img_to_b64(path: str | Path) -> str:
+    """Render a Sentinel-2 tif as an RGB thumbnail and encode it as a base64 PNG."""
+    with rio.open(path) as src:
+        data = src.read()
+        rgb = _s2_to_rgb(data)
+        img = Image.fromarray(rgb)
+        side = max(img.size)
+        # Create a square canvas, paste the image centred, then resize
+        canvas = Image.new("RGB", (side, side), (255, 255, 255))
+        canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
+        canvas = canvas.resize((IMG_PX, IMG_PX))
+        buf = BytesIO()
+        canvas.save(buf, format="PNG")
+        return base64.b64encode(buf.getvalue()).decode()
+
+
+def _bar_chart(top_scores, cmap) -> str:
+    scores = top_scores.values.tolist()
+    labels = top_scores.index.tolist()
+    while len(scores) < 3:
+        scores.append(0)
+        labels.append("")
+
+    fig, ax = plt.subplots(figsize=(3, 1))
+    y_pos = np.arange(3)
+
+    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
+              for cls, val in zip(labels, scores)]
+
+    ax.barh(y_pos, scores, height=0.7, color=colors)
+    ax.set_xlim(0, 1)
+    ax.invert_yaxis()
+    ax.axis("off")
+
+    for i, (cls, val) in enumerate(zip(labels, scores)):
+        if val > 0:  # skip padded rows
+            ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
+
+    buf = BytesIO()
+    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
+    plt.close(fig)
+    b64 = base64.b64encode(buf.getvalue()).decode()
+    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
+
+
+def classify(images, class_text):
+    class_names = [c.strip() for c in class_text.split(",") if c.strip()]
+    cmap = build_colormap(class_names)
+    cards = []
+
+    df = run_inference_classification(image_path=images, class_names=class_names)  # one row per image
+    for img_path, (_, row) in zip(images, df.iterrows()):
+        scores = row[2:].astype(float)  # drop the leading non-score columns
+        top = scores.sort_values(ascending=False)[:3]
+        top = top[top > 0.01]  # filter low scores
+
+        cards.append(f"""
+        <div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
+          <img src="data:image/png;base64,{_img_to_b64(img_path)}"
+               style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
+                      border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
+          {_bar_chart(top, cmap)}
+        </div>""")
+
+    return (
+        "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
+        + "".join(cards)
+        + "</div>"
+    )
+
+
+# UI
+
+DEFAULT_CLASSES = ["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]
+
+with gr.Blocks(
+        css="""
+        .gradio-container #result_box,
+        #result_box.gr-skeleton {min-height:280px !important;}
+        """) as demo:
+    gr.Markdown("## Zero-shot Classification with Llama3-MS-CLIP")
+    gr.Markdown("Provide Sentinel-2 tif files with all 12 or 13 bands and define the class names. "
+                "You can also load one of the three provided example sets with class names that you can modify. "
+                "The example images come from [EuroSAT](), [Meter-ML](), and [TerraMesh](). "
+                "The images are classified based on the similarity between the image embeddings and the text embeddings. "
+                "You can find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969).")
+    with gr.Row():
+        img_in = gr.File(
+            label="Upload S-2 images", file_count="multiple", type="filepath"
+        )
+        cls_in = gr.Textbox(
+            value=", ".join(DEFAULT_CLASSES),
+            label="Class names (comma-separated)",
+        )
+
+    run_btn = gr.Button("Classify", variant="primary")
+
+    # Examples
+    gr.Markdown("#### Load examples")
+    with gr.Row():
+        btn_terramesh = gr.Button("TerraMesh")
+        btn_eurosat = gr.Button("EuroSAT")
+        btn_meterml = gr.Button("Meter-ML")
+
+    out_html = gr.HTML(label="Results",
+                       elem_id="result_box",
+                       min_height=280)
+
+    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)
+
+    btn_terramesh.click(
+        load_terramesh_example,
+        outputs=[img_in, cls_in],
+    ).then(
+        classify,
+        inputs=[img_in, cls_in],
+        outputs=out_html,
+    )
+
+    btn_eurosat.click(
+        load_eurosat_example,
+        outputs=[img_in, cls_in],
+    ).then(
+        classify,
+        inputs=[img_in, cls_in],
+        outputs=out_html,
+    )
+
+    btn_meterml.click(
+        load_meterml_example,
+        outputs=[img_in, cls_in],
+    ).then(
+        classify,
+        inputs=[img_in, cls_in],
+        outputs=out_html,
+    )
+
+
+if __name__ == "__main__":
+    demo.launch()
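
As the in-app description notes, images are scored by comparing image embeddings with text embeddings of the class names. For context, here is a minimal sketch of that scoring step, with random vectors standing in for real Llama3-MS-CLIP encoder outputs; the names `image_emb` and `text_embs` are illustrative, not part of the msclip API, and the logit scale of 100 is the usual CLIP convention, not confirmed for this model:

```python
import numpy as np

rng = np.random.default_rng(0)
class_names = ["Forest", "River", "Buildings"]

# Stand-ins for real model outputs: one image embedding, and one text
# embedding per class name (real values would come from the encoders).
image_emb = rng.normal(size=512)
text_embs = rng.normal(size=(len(class_names), 512))

# L2-normalise so the dot product equals the cosine similarity
image_emb /= np.linalg.norm(image_emb)
text_embs /= np.linalg.norm(text_embs, axis=1, keepdims=True)

# Cosine similarity of the image to each class prompt, then a
# numerically stable softmax over the scaled similarities
logits = 100 * (text_embs @ image_emb)
probs = np.exp(logits - logits.max())
probs /= probs.sum()

for name, p in zip(class_names, probs):
    print(f"{name}: {p:.2%}")
```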
requirements.txt ADDED
@@ -0,0 +1,5 @@
+gradio>=4.31.0
+plotly
+rasterio
+msclip@git+https://github.com/IBM/MS-CLIP.git
+matplotlib
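
With these requirements installed, the same classification can be run outside the Gradio UI using the call from `classify()` in app.py. A minimal sketch, assuming an example tif is available (the file path below is a placeholder, and the two leading metadata columns are inferred from the `row[2:]` slicing in app.py):

```python
from msclip.inference import run_inference_classification

# Same call pattern as classify() in app.py above
df = run_inference_classification(
    image_path=["examples/eurosat/example.tif"],  # placeholder path
    class_names=["Forest", "River", "Residential"],
)

# One row per image; following app.py, the first two columns are
# assumed to be metadata rather than class scores
for _, row in df.iterrows():
    scores = row[2:].astype(float)
    print(scores.sort_values(ascending=False).head(3))
```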