File size: 9,566 Bytes
9d4302e
ef153ea
 
 
 
9aec2d7
9d4302e
 
 
 
 
 
ef153ea
 
9d4302e
 
 
ef153ea
9d4302e
 
7f0a7e4
9d4302e
 
cbc0399
 
ef153ea
 
 
9d4302e
 
 
 
7f0a7e4
 
9d4302e
 
 
 
 
7f0a7e4
 
 
 
 
 
 
9d4302e
 
 
 
 
7f0a7e4
9d4302e
 
 
 
 
 
 
7f0a7e4
9d4302e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f0a7e4
9d4302e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f0a7e4
9d4302e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f0a7e4
 
 
 
9d4302e
 
7f0a7e4
 
9d4302e
7f0a7e4
9d4302e
 
 
 
 
 
 
9aec2d7
 
9d4302e
 
 
 
ef153ea
 
 
 
 
 
 
 
9d4302e
 
 
 
 
 
 
 
 
 
 
7f0a7e4
9d4302e
 
 
7f0a7e4
 
 
9d4302e
 
 
cbc0399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d4302e
cbc0399
9d4302e
 
 
 
 
 
 
7f0a7e4
 
cbc0399
 
9d4302e
 
 
 
 
 
 
7f0a7e4
9d4302e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbc0399
9d4302e
 
 
 
cbc0399
9d4302e
 
 
 
cbc0399
9d4302e
 
 
325902e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271

import base64
import os
import sys
import csv
import spaces
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from io import BytesIO
from pathlib import Path
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
from msclip.inference.utils import build_model

# Compact fonts for the small (3 x 1 inch) score charts drawn by _bar_chart.
rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9
# Side length (px) of the square thumbnails; also the width of each result card and chart.
IMG_PX = 300


def _max_out_csv_field_size_limit():
    """Raise the csv module's per-field size limit as high as the platform accepts.

    ``sys.maxsize`` is tried first (the original behaviour). Where that value does not
    fit the C ``long`` the csv module stores the limit in (e.g. Windows, where ``long``
    is 32-bit), ``csv.field_size_limit`` raises ``OverflowError``; the limit is then
    divided by 10 until it is accepted instead of crashing at import time.
    """
    # NOTE(review): nothing visible in this file parses CSV; presumably msclip does — confirm.
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return
        except OverflowError:
            limit //= 10


_max_out_csv_field_size_limit()

# Init Llama3-MS-CLIP from Hugging Face
model, preprocess, tokenizer = build_model()

# Bundled example sets: GeoTIFF paths (relative to the working directory) plus the class
# names each set is classified with. glob.glob() returns files in arbitrary,
# filesystem-dependent order, so the paths are sorted to keep the example galleries and
# their cached results stable across machines.
EXAMPLES = {
    "EuroSAT": {
        "images": sorted(glob.glob("examples/eurosat/*.tif")),
        "classes": [
            "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
            "Pasture", "Permanent crop", "Residential", "River", "Sea lake"
        ]
    },
    "Meter-ML": {
        "images": sorted(glob.glob("examples/meterml/*.tif")),
        "classes": [
            "Concentrated animal feeding operations",
            "Landfills",
            "Coal mines",
            "Other features",
            "Natural gas processing plants",
            "Oil refineries and petroleum terminals",
            "Wastewater treatment plants",
        ]
    },
    "TerraMesh": {
        "images": sorted(glob.glob("examples/terramesh/*.tif")),
        "classes": [
            "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
        ]
    },
}


# Hex colour strings of matplotlib's qualitative "Pastel1" palette (the class-bar colours).
pastel1_hex = list(map(mpl.colors.to_hex, mpl.colormaps["Pastel1"].colors))


def build_colormap(class_names, palette=None):
    """Assign a colour to every class name, cycling through a palette.

    Names are sorted before colours are handed out, so the same set of classes always
    gets the same colours regardless of the order they were typed in.

    Args:
        class_names: iterable of class-name strings.
        palette: optional sequence of matplotlib colour specs. Defaults to the
            module-level ``pastel1_hex`` (the behaviour before this parameter existed).

    Returns:
        dict mapping each class name to a colour; colours repeat when there are more
        classes than palette entries.
    """
    colours = pastel1_hex if palette is None else palette
    return {name: colours[i % len(colours)] for i, name in enumerate(sorted(class_names))}


def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """
    array: numpy array with dimensions [C, H, W]
    returns 0-1 scaled array
    """

    # Get scaling thresholds for smoothing the brightness
    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    limit_high = limit_high.clip(default)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 1000)  # Scale only pixels below 1000
    limit_low = np.where(median > default / 2, limit_low, 0)  # Make image only darker if it is not dark already

    # Smooth very dark and bright values using linear scaling
    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)

    # Update scaling params using a 10th of the tolerance for max value
    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    limit_high = limit_high.clip(default, 20000)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 500)  # Scale only pixels below 500
    limit_low = np.where(median > default / 2, limit_low, 0)  # Make image only darker if it is not dark already

    # Scale data to 0-255
    array = (array - limit_low) / (limit_high - limit_low)

    return array


def _s2_to_rgb(data, smooth_quantiles=True):
    # Select
    if data.shape[0] > 13:
        # assuming channel last
        rgb = data[:, :, [3, 2, 1]]
    else:
        # assuming channel first
        rgb = data[[3, 2, 1]].transpose((1, 2, 0))

    if smooth_quantiles:
        rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000

    # to uint8
    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)

    return rgb


def _img_to_b64(path: str | Path) -> str:
    """Render a raster file as a square RGB PNG thumbnail and return it base64-encoded.

    All bands are read and converted with ``_s2_to_rgb``. The picture is centred on a
    white square canvas, so non-square inputs are padded rather than distorted. The
    canvas is then always resized (up or down) to ``IMG_PX`` x ``IMG_PX`` pixels.
    """
    with rio.open(path) as src:
        bands = src.read()
    thumb = Image.fromarray(_s2_to_rgb(bands))

    edge = max(thumb.size)
    offset = ((edge - thumb.width) // 2, (edge - thumb.height) // 2)
    square = Image.new("RGB", (edge, edge), (255, 255, 255))
    square.paste(thumb, offset)

    png = BytesIO()
    square.resize((IMG_PX, IMG_PX)).save(png, format="PNG")
    return base64.b64encode(png.getvalue()).decode()


def _bar_chart(top_scores, img_name, cmap) -> str:
    """Draw up to three class scores as a horizontal bar chart inside an HTML <img> tag.

    Args:
        top_scores: pandas Series mapping class name -> score, presumably already sorted
            in descending order and in [0, 1] (the x-axis is fixed to 0-1). Only the
            first three entries are drawn; shorter input is padded with invisible rows.
        img_name: file name of the classified image; the extension is dropped and the
            rest shortened to at most 25 characters for the chart title.
        cmap: dict mapping class name -> matplotlib colour (unknown names get "none").

    Returns:
        An ``<img>`` element with the chart inlined as a base64-encoded PNG.
    """
    # A Figure created directly (not through pyplot) is not registered in pyplot's global
    # figure manager: it needs no plt.close(), cannot leak if rendering raises, and does
    # not touch global pyplot state from the web server's worker threads.
    from matplotlib.figure import Figure

    n_rows = 3
    scores = top_scores.values.tolist()[:n_rows]
    labels = top_scores.index.tolist()[:n_rows]
    # Pad to a fixed number of rows so every card's chart has the same height.
    scores += [0] * (n_rows - len(scores))
    labels += [""] * (n_rows - len(labels))

    fig = Figure(figsize=(3, 1))
    ax = fig.subplots()
    y_pos = np.arange(n_rows)

    # Padded (zero-score) rows get a fully transparent bar.
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
              for cls, val in zip(labels, scores)]

    ax.barh(y_pos, scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()  # first entry on top
    ax.axis("off")

    title = os.path.splitext(img_name)[0]
    if len(title) > 25:
        title = title[:23] + "..."
    # parse_math=False: file and class names are user-controlled, and a pair of "$"
    # would otherwise be parsed as mathtext (which can make savefig raise).
    ax.set_title(title, parse_math=False)

    for i, (cls, val) in enumerate(zip(labels, scores)):
        if val <= 0:  # padded row: no label
            continue
        if len(cls) > 25:
            cls = cls[:23] + "..."
        ax.text(0.02, i + 0.03, f"{cls}  ({round(val * 100)}%)",
                ha="left", va="center", parse_math=False)

    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'


@spaces.GPU
def classify(images, class_text):
    """Run zero-shot classification and render one HTML result card per image.

    Args:
        images: list of image file paths (as supplied by the multi-file ``gr.File``
            with ``type="filepath"``); may be None or empty when nothing was uploaded.
        class_text: comma-separated class names; blank entries are ignored.

    Returns:
        HTML string: a centred flex container with one card per image (thumbnail plus
        a chart of up to three classes scoring above 0.01), or a short notice when
        there are no images or no class names to classify.
    """
    class_names = [c.strip() for c in (class_text or "").split(",") if c.strip()]
    if not images or not class_names:
        # E.g. "Classify" clicked before uploading: show a hint instead of handing
        # empty input to the model.
        return ("<p style='text-align:center;'>"
                "Please upload at least one image and enter at least one class name.</p>")

    df = run_inference_classification(
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        image_path=images,
        class_names=class_names,
        verbose=False
    )
    # Colours depend only on the class list, so build the mapping once for all cards.
    cmap = build_colormap(class_names)

    cards = []
    # NOTE(review): images and result rows are paired by position, i.e. this assumes the
    # returned DataFrame keeps the input order — confirm in msclip.
    for img_path, (_, row) in zip(images, df.iterrows()):
        # Drop the two leading non-score columns (one holds the file name; TODO confirm
        # the other against msclip's output). .iloc makes the positional slice explicit.
        scores = row.iloc[2:].astype(float)
        top = scores.sort_values(ascending=False).iloc[:3]
        top = top[top > 0.01]  # hide near-zero scores

        cards.append(f"""
        <div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
          <img src="data:image/png;base64,{_img_to_b64(img_path)}"
               style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
                      border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
          {_bar_chart(top, os.path.basename(img_path), cmap)}
        </div>""")

    return (
            "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
            + "".join(cards)
            + "</div>"
    )


def _classify_example(name):
    """Classify one bundled example set with its own class list; returns the result HTML."""
    example = EXAMPLES[name]
    return classify(example["images"], ", ".join(example["classes"]))


# Precompute the bundled examples once at start-up so the example buttons can show
# their results without running the model again.
terramesh_html = _classify_example("TerraMesh")
eurosat_html = _classify_example("EuroSAT")
meterml_html = _classify_example("Meter-ML")


def _example_payload(name, cached_html):
    """Return (image paths, comma-separated class names, cached HTML) for one example set."""
    example = EXAMPLES[name]
    return example["images"], ", ".join(example["classes"]), cached_html


def load_eurosat_example():
    return _example_payload("EuroSAT", eurosat_html)


def load_meterml_example():
    return _example_payload("Meter-ML", meterml_html)


def load_terramesh_example():
    return _example_payload("TerraMesh", terramesh_html)


# UI
with gr.Blocks(
        css="""
            .gradio-container
            #result_box,            
            #result_box.gr-skeleton {min-height:280px !important;}
            """) as demo:
    gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
    gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
                "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifying forests as sea because of the differently scaled values). "
                "We provide three sets of example images with class names and cached outputs. "
                "The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (we downloaded S-2 L2A images for the same locations). "
                "The images are classified based on the similarity between the image embeddings and text embeddings. "
                "You can find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images", file_count="multiple", type="filepath"
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            label="Class names (comma‑separated)",
        )

    run_btn = gr.Button("Classify", variant="primary")

    # Examples
    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")

    out_html = gr.HTML(label="Results",
                       elem_id="result_box",
                       min_height=280)

    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)

    # Each example button fills the inputs and shows the precomputed result HTML.
    for example_btn, loader in (
            (btn_terramesh, load_terramesh_example),
            (btn_eurosat, load_eurosat_example),
            (btn_meterml, load_meterml_example),
    ):
        example_btn.click(
            loader,
            outputs=[img_in, cls_in, out_html],
        )

if __name__ == "__main__":
    # Launch the app with a share link requested and server-side rendering turned off.
    demo.launch(ssr_mode=False, share=True)