blumenstiel commited on
Commit
7f0a7e4
·
1 Parent(s): 652cb49

Update examples

Browse files
app.py CHANGED
@@ -1,5 +1,5 @@
1
-
2
  import base64
 
3
  from io import BytesIO
4
  from pathlib import Path
5
 
@@ -14,45 +14,53 @@ from matplotlib import rcParams
14
  from msclip.inference import run_inference_classification
15
 
16
  rcParams["font.size"] = 9
 
17
  IMG_PX = 300
18
 
19
  EXAMPLES = {
20
  "EuroSAT": {
21
  "images": glob.glob("examples/eurosat/*.tif"),
22
  "classes": [
23
- "AnnualCrop","Forest","HerbaceousVegetation","Highway","Industrial",
24
- "Pasture","PermanentCrop","Residential","River","SeaLake"
25
  ]
26
  },
27
  "Meter-ML": {
28
  "images": glob.glob("examples/meterml/*.tif"),
29
  "classes": [
30
- "Todo"
 
 
 
 
 
 
31
  ]
32
  },
33
  "TerraMesh": {
34
  "images": glob.glob("examples/terramesh/*.tif"),
35
  "classes": [
36
- "Agriculture", "Beach", "River", "Ice", "Fields"
37
  ]
38
  },
39
  }
40
 
41
 
42
  def load_eurosat_example():
43
- return EXAMPLES["EuroSAT"]["images"], ",".join(EXAMPLES["EuroSAT"]["classes"])
44
 
45
 
46
  def load_meterml_example():
47
- return EXAMPLES["Meter-ML"]["images"], ",".join(EXAMPLES["Meter-ML"]["classes"])
48
 
49
 
50
  def load_terramesh_example():
51
- return EXAMPLES["TerraMesh"]["images"], ",".join(EXAMPLES["TerraMesh"]["classes"])
52
 
53
 
54
  pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
55
 
 
56
  def build_colormap(class_names):
57
  return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}
58
 
@@ -74,7 +82,7 @@ def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
74
  array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
75
 
76
  # Update scaling params using a 10th of the tolerance for max value
77
- limit_low, limit_high = np.quantile(array, q=[tolerance/10, 1. - tolerance/10])
78
  limit_high = limit_high.clip(default, 20000) # Scale only pixels above default value
79
  limit_low = limit_low.clip(0, 500) # Scale only pixels below 500
80
  limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
@@ -121,7 +129,7 @@ def _img_to_b64(path: str | Path) -> str:
121
  return base64.b64encode(buf.getvalue()).decode()
122
 
123
 
124
- def _bar_chart(top_scores, cmap) -> str:
125
  scores = top_scores.values.tolist()
126
  labels = top_scores.index.tolist()
127
  while len(scores) < 3:
@@ -138,10 +146,16 @@ def _bar_chart(top_scores, cmap) -> str:
138
  ax.set_xlim(0, 1)
139
  ax.invert_yaxis()
140
  ax.axis("off")
 
 
 
 
141
 
142
  for i, (cls, val) in enumerate(zip(labels, scores)):
 
 
143
  if val > 0: # skip padded rows
144
- ax.text(0.02, i+0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
145
 
146
  buf = BytesIO()
147
  fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
@@ -154,7 +168,7 @@ def classify(images, class_text):
154
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
155
  cards = []
156
 
157
- df = run_inference_classification(image_path=images, class_names=class_names) # one row per call
158
  for img_path, (id, row) in zip(images, df.iterrows()):
159
  scores = row[2:].astype(float) # drop filename column
160
  top = scores.sort_values(ascending=False)[:3]
@@ -166,21 +180,18 @@ def classify(images, class_text):
166
  <img src="data:image/png;base64,{_img_to_b64(img_path)}"
167
  style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
168
  border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
169
- {_bar_chart(top, cmap)}
170
  </div>""")
171
 
172
  return (
173
- "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
174
- + "".join(cards)
175
- + "</div>"
176
  )
177
 
178
 
179
  # UI
180
 
181
- DEFAULT_CLASSES = ["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]
182
-
183
- # with gr.Blocks(css=".gradio-container") as demo:
184
  with gr.Blocks(
185
  css="""
186
  .gradio-container
@@ -188,8 +199,9 @@ with gr.Blocks(
188
  #result_box.gr-skeleton {min-height:280px !important;}
189
  """) as demo:
190
  gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
191
- gr.Markdown("Provide Sentinel-2 tif files with all 12 or 13 bands and define the class names. "
192
- "You can also load one of the three provided example sets with class names that you can modify. The example images are comming from [EuroSAT](), [Meter-ML](), and [TerraMesh](). "
 
193
  "The images are classified based on the similarity between the images embeddings and text embeddings. "
194
  "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
195
  with gr.Row():
@@ -197,7 +209,8 @@ with gr.Blocks(
197
  label="Upload S-2 images", file_count="multiple", type="filepath"
198
  )
199
  cls_in = gr.Textbox(
200
- value=", ".join(DEFAULT_CLASSES),
 
201
  label="Class names (comma‑separated)",
202
  )
203
 
@@ -243,6 +256,5 @@ with gr.Blocks(
243
  outputs=out_html,
244
  )
245
 
246
-
247
  if __name__ == "__main__":
248
  demo.launch()
 
 
1
  import base64
2
+ import os.path
3
  from io import BytesIO
4
  from pathlib import Path
5
 
 
14
  from msclip.inference import run_inference_classification
15
 
16
  rcParams["font.size"] = 9
17
+ rcParams["axes.titlesize"] = 9
18
  IMG_PX = 300
19
 
20
  EXAMPLES = {
21
  "EuroSAT": {
22
  "images": glob.glob("examples/eurosat/*.tif"),
23
  "classes": [
24
+ "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
25
+ "Pasture", "Permanent crop", "Residential", "River", "Sea lake"
26
  ]
27
  },
28
  "Meter-ML": {
29
  "images": glob.glob("examples/meterml/*.tif"),
30
  "classes": [
31
+ "Concentrated animal feeding operations",
32
+ "Landfills",
33
+ "Coal mines",
34
+ "Other features",
35
+ "Natural gas processing plants",
36
+ "Oil refineries and petroleum terminals",
37
+ "Wastewater treatment plants",
38
  ]
39
  },
40
  "TerraMesh": {
41
  "images": glob.glob("examples/terramesh/*.tif"),
42
  "classes": [
43
+ "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
44
  ]
45
  },
46
  }
47
 
48
 
49
  def load_eurosat_example():
50
+ return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"])
51
 
52
 
53
  def load_meterml_example():
54
+ return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"])
55
 
56
 
57
  def load_terramesh_example():
58
+ return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"])
59
 
60
 
61
  pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
62
 
63
+
64
  def build_colormap(class_names):
65
  return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}
66
 
 
82
  array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
83
 
84
  # Update scaling params using a 10th of the tolerance for max value
85
+ limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
86
  limit_high = limit_high.clip(default, 20000) # Scale only pixels above default value
87
  limit_low = limit_low.clip(0, 500) # Scale only pixels below 500
88
  limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
 
129
  return base64.b64encode(buf.getvalue()).decode()
130
 
131
 
132
+ def _bar_chart(top_scores, img_name, cmap) -> str:
133
  scores = top_scores.values.tolist()
134
  labels = top_scores.index.tolist()
135
  while len(scores) < 3:
 
146
  ax.set_xlim(0, 1)
147
  ax.invert_yaxis()
148
  ax.axis("off")
149
+ img_name = os.path.splitext(img_name)[0]
150
+ if len(img_name) > 25:
151
+ img_name = img_name[:23] + "..."
152
+ ax.set_title(img_name)
153
 
154
  for i, (cls, val) in enumerate(zip(labels, scores)):
155
+ if len(cls) > 25:
156
+ cls = cls[:23] + "..."
157
  if val > 0: # skip padded rows
158
+ ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
159
 
160
  buf = BytesIO()
161
  fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
 
168
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
169
  cards = []
170
 
171
+ df = run_inference_classification(image_path=images, class_names=class_names, verbose=False)
172
  for img_path, (id, row) in zip(images, df.iterrows()):
173
  scores = row[2:].astype(float) # drop filename column
174
  top = scores.sort_values(ascending=False)[:3]
 
180
  <img src="data:image/png;base64,{_img_to_b64(img_path)}"
181
  style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
182
  border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
183
+ {_bar_chart(top, os.path.basename(img_path), cmap)}
184
  </div>""")
185
 
186
  return (
187
+ "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
188
+ + "".join(cards)
189
+ + "</div>"
190
  )
191
 
192
 
193
  # UI
194
 
 
 
 
195
  with gr.Blocks(
196
  css="""
197
  .gradio-container
 
199
  #result_box.gr-skeleton {min-height:280px !important;}
200
  """) as demo:
201
  gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
202
+ gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
203
+ "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifing forests as sea because of the differrently scaled values). "
204
+ "We provide three sets of example images with class names that you can modify. The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). "
205
  "The images are classified based on the similarity between the images embeddings and text embeddings. "
206
  "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
207
  with gr.Row():
 
209
  label="Upload S-2 images", file_count="multiple", type="filepath"
210
  )
211
  cls_in = gr.Textbox(
212
+ value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
213
+ # some default classes
214
  label="Class names (comma‑separated)",
215
  )
216
 
 
256
  outputs=out_html,
257
  )
258
 
 
259
  if __name__ == "__main__":
260
  demo.launch()
examples/eurosat/AnnualCrop_2515.tif ADDED

Git LFS Details

  • SHA256: a56404f05fc2af9c7cebc9fd1fb8fdf5a00c6887e982d40061356fad1e5d47df
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/Forest_321.tif ADDED

Git LFS Details

  • SHA256: 3200b606545961c2b0c855744771cb7160db88c9bda16a0d93177a82d0585549
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/Highway_2219.tif ADDED

Git LFS Details

  • SHA256: 4aefc1c3a0e692756ec49892a6a6ff9d4ba49ec8eba33012803bd5896ca5ae70
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/Industrial_967.tif ADDED

Git LFS Details

  • SHA256: 72379c66959433855b5e72e2d9041dc4e8d5ce3bff51ec2a60d623404ac04e52
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/Pasture_742.tif ADDED

Git LFS Details

  • SHA256: 64c3e966774ed09aef63a0cfbd62b6ecd70a7b75848f23f5dd9ee2c1ad92b6ae
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/Residential_1511.tif ADDED

Git LFS Details

  • SHA256: f3c4d978bdb89a445a510a5a49828b8b73668475cf3e16262fb29d7765a9c924
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/River_1690.tif ADDED

Git LFS Details

  • SHA256: 94e19db500187220960b974f34d977e908bbe97b8684d4d81990d78d712b63a4
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/eurosat/SeaLake_2589.tif ADDED

Git LFS Details

  • SHA256: 460865c769ce04bbc9b335e1d4c3bfa203cae37769b8953a65ee0a66ddd301d9
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
examples/meterml/CAFOs_30.8439_-90.0443.tif ADDED

Git LFS Details

  • SHA256: 788417d8e932ff148e17913be7c1f69a55a266584f0aa2517e7bee6c34d25dcb
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/Landfills_31.36143_-86.27454.tif ADDED

Git LFS Details

  • SHA256: 1cb34b93a50495733224d2378a86b581857a6888b829cc9ba2b8a8c10c31a1f1
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/Mines_34.634_-117.307.tif ADDED

Git LFS Details

  • SHA256: e82ba21b3a665a0f6500220ecdccd4dc3e1b6f95fac5c2c5c63b6359179329d2
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/Others_30.0523_-89.8799.tif ADDED

Git LFS Details

  • SHA256: 6a6d6cfc05490192d01ea13073f239fbf912863f43397e02a6cc8ccb852a347e
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/Others_31.286_-91.068.tif ADDED

Git LFS Details

  • SHA256: dd4d28c04dec3ef76100cff4da9e2f9e6506b4fd9940349adadccf97de08cbad
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/ProcessingPlants_28.7516_-98.0138.tif ADDED

Git LFS Details

  • SHA256: a1e2a3c4e5536b4301a817b726022a3285bf25172174dd43b5a5b62a721f4baa
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/RefineriesAndTerminals_29.71373_-95.23472.tif ADDED

Git LFS Details

  • SHA256: af02bfdd79525afc76a0b8e108d17e98dbca402f1f748d6339ed8a0cee3cafec
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/meterml/WWTreatment_31.193_-91.012.tif ADDED

Git LFS Details

  • SHA256: 30d2423de08836825eee5a4c5848d83daa0d24dca44c6129a1459ee6e3194ed3
  • Pointer size: 131 Bytes
  • Size of remote file: 127 kB
examples/terramesh/{282D_485L_3_3.tif → majortom_282D_485L_3_3.tif} RENAMED
File without changes
examples/terramesh/{38D_378R_2_3.tif → majortom_38D_378R_2_3.tif} RENAMED
File without changes
examples/terramesh/{433D_629L_3_1.tif → majortom_433D_629L_3_1.tif} RENAMED
File without changes
examples/terramesh/{609U_541L_3_0.tif → majortom_609U_541L_3_0.tif} RENAMED
File without changes
examples/terramesh/{637U_59R_1_3.tif → majortom_637U_59R_1_3.tif} RENAMED
File without changes
examples/terramesh/ssl4eos12_0001175_2.tif ADDED

Git LFS Details

  • SHA256: 60451be0cfd7b1b3b2597d2121fcc0cd59e3803069ad794000cf218bf3677341
  • Pointer size: 132 Bytes
  • Size of remote file: 1.58 MB
examples/terramesh/ssl4eos12_0001732_0.tif ADDED

Git LFS Details

  • SHA256: f769694202e021f912f786f3f072ed391a754a4b713e4ca305e31be5c2bc1d8e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.58 MB
examples/terramesh/ssl4eos12_0028834_3.tif ADDED

Git LFS Details

  • SHA256: 18f2080170d835848a6dafa5ece36d3d5a1bea9894c51a804d5773fd3bafcc48
  • Pointer size: 132 Bytes
  • Size of remote file: 1.58 MB