import base64
import os
import sys
import csv
import spaces  # keep early: ZeroGPU's `spaces` must be imported before CUDA/torch-initialising packages
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from io import BytesIO
from pathlib import Path
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
from msclip.inference.utils import build_model

rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9

# Side length (pixels) of the square RGB preview shown in each result card.
IMG_PX = 300

# Raise the csv module's per-field size limit (NOTE(review): presumably needed by
# msclip's CSV handling — not visible from here). sys.maxsize overflows the C
# long used by the csv module where long is 32-bit (e.g. Windows), so cap it.
csv.field_size_limit(min(sys.maxsize, 2**31 - 1))

# Init Llama3-MS-CLIP from Hugging Face
model, preprocess, tokenizer = build_model()

# Bundled example sets: image paths (sorted for a stable display order) and class names.
EXAMPLES = {
    "EuroSAT": {
        "images": sorted(glob.glob("examples/eurosat/*.tif")),
        "classes": [
            "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
            "Pasture", "Permanent crop", "Residential", "River", "Sea lake",
        ],
    },
    "Meter-ML": {
        "images": sorted(glob.glob("examples/meterml/*.tif")),
        "classes": [
            "Concentrated animal feeding operations", "Landfills", "Coal mines", "Other features",
            "Natural gas processing plants", "Oil refineries and petroleum terminals",
            "Wastewater treatment plants",
        ],
    },
    "TerraMesh": {
        "images": sorted(glob.glob("examples/terramesh/*.tif")),
        "classes": [
            "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert",
        ],
    },
}

pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]


def build_colormap(class_names):
    """Map each class name to a Pastel1 hex colour.

    Names are sorted first so the colour of a class does not depend on the order
    the user typed them in; colours repeat cyclically when there are more classes
    than palette entries.
    """
    return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}


def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """Contrast-stretch raw reflectance values for display.

    Quantiles are computed over all values of ``array`` (no axis argument), so the
    layout ([C, H, W] or [H, W, C]) does not matter.

    Args:
        array: numpy array of raw reflectance values.
        tolerance: fraction of values treated as outliers at each end.
        scaling: slope applied to values beyond the outlier thresholds.
        default: minimum upper stretch limit; images with median <= default / 2
            are treated as dark and are not darkened further.

    Returns:
        Float array scaled to roughly [0, 1]; values slightly outside this range
        are possible and are clipped by the caller.
    """
    # First pass: thresholds used to soften the extremes.
    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    limit_high = limit_high.clip(default)  # upper limit is at least `default`
    limit_low = limit_low.clip(0, 1000)  # lower limit is kept within [0, 1000]
    limit_low = np.where(median > default / 2, limit_low, 0)  # make image only darker if it is not dark already

    # Compress very dark and very bright values linearly with slope `scaling`.
    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)

    # Second pass: recompute limits on the softened data with a 10x smaller tolerance.
    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    limit_high = limit_high.clip(default, 20000)  # upper limit within [default, 20000]
    limit_low = limit_low.clip(0, 500)  # lower limit within [0, 500]
    limit_low = np.where(median > default / 2, limit_low, 0)  # make image only darker if it is not dark already

    # Linear stretch to ~[0, 1]; the caller converts to 0-255.
    array = (array - limit_low) / (limit_high - limit_low)
    return array


def _s2_to_rgb(data, smooth_quantiles=True):
    """Convert a multi-band Sentinel-2 array to a uint8 RGB image of shape [H, W, 3].

    Bands 3, 2, 1 are selected — assuming the standard S-2 band order
    (B1, B2, B3, B4, ...), that is B4/B3/B2 = red/green/blue.
    Arrays whose first dimension exceeds 13 are treated as channel-last,
    otherwise as channel-first.
    """
    if data.shape[0] > 13:
        # assuming channel last
        rgb = data[:, :, [3, 2, 1]]
    else:
        # assuming channel first
        rgb = data[[3, 2, 1]].transpose((1, 2, 0))

    if smooth_quantiles:
        rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000

    # to uint8
    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
    return rgb


def _img_to_b64(path: str | Path) -> str:
    """Render a GeoTIFF as a base64-encoded PNG preview.

    The RGB image is centred on a white square canvas (padding the shorter side)
    and always resized to IMG_PX x IMG_PX pixels.
    """
    with rio.open(path) as src:
        data = src.read()
    rgb = _s2_to_rgb(data)
    img = Image.fromarray(rgb)

    side = max(img.size)
    # create square canvas, paste centred, then resize
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    canvas = canvas.resize((IMG_PX, IMG_PX))

    buf = BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


def _bar_chart(top_scores, img_name, cmap) -> str:
    """Render the top class scores of one image as a horizontal bar chart.

    Args:
        top_scores: pandas Series of scores in [0, 1] indexed by class name,
            at most three entries, sorted descending by the caller.
        img_name: file name used as the chart title (extension stripped,
            truncated to 25 characters).
        cmap: mapping class name -> bar colour, as built by ``build_colormap``.

    Returns:
        An HTML ``<img>`` tag embedding the chart as a base64 PNG.
    """
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # Pad to exactly three rows so every chart has the same layout.
    while len(scores) < 3:
        scores.append(0)
        labels.append("")

    fig, ax = plt.subplots(figsize=(3, 1))
    y_pos = np.arange(3)
    # Padded (zero) rows get a fully transparent bar.
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0) for cls, val in zip(labels, scores)]
    ax.barh(y_pos, scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()  # highest score on top
    ax.axis("off")

    img_name = os.path.splitext(img_name)[0]
    if len(img_name) > 25:
        img_name = img_name[:23] + "..."
    ax.set_title(img_name)

    for i, (cls, val) in enumerate(zip(labels, scores)):
        if len(cls) > 25:
            cls = cls[:23] + "..."
        if val > 0:  # skip padded rows
            ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")

    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    plt.close(fig)  # release the figure; this runs once per image
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="width:100%;"/>'


@spaces.GPU
def classify(images, class_text):
    """Run zero-shot classification and render one result card per image.

    Args:
        images: list of GeoTIFF file paths (None when nothing is uploaded).
        class_text: comma-separated class names; blank entries are ignored.

    Returns:
        An HTML string with one card (RGB preview + top-3 bar chart) per image,
        or a short message when images or class names are missing.
    """
    class_names = [c.strip() for c in class_text.split(",") if c.strip()]
    if not images:
        return "<p>Please upload at least one Sentinel-2 image.</p>"
    if not class_names:
        return "<p>Please provide at least one class name.</p>"

    df = run_inference_classification(
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        image_path=images,
        class_names=class_names,
        verbose=False,
    )

    cmap = build_colormap(class_names)  # identical for every card
    cards = []
    # Assumes df rows are in the same order as `images`.
    for img_path, (_, row) in zip(images, df.iterrows()):
        # Drop the two leading non-score columns (e.g. filename); the rest are
        # per-class scores. NOTE(review): confirm against msclip's output schema.
        scores = row.iloc[2:].astype(float)
        top = scores.sort_values(ascending=False).head(3)
        top = top[top > 0.01]  # filter low scores
        cards.append(
            f'<div style="width:{IMG_PX}px; text-align:center;">'
            f'<img src="data:image/png;base64,{_img_to_b64(img_path)}" '
            f'style="width:{IMG_PX}px; height:{IMG_PX}px;"/>'
            f"{_bar_chart(top, os.path.basename(img_path), cmap)}"
            "</div>"
        )

    return (
        '<div style="display:flex; flex-wrap:wrap; gap:16px;">'
        + "".join(cards)
        + "</div>"
    )


# Cache examples: compute their results once at startup so the example buttons respond instantly.
terramesh_html = classify(EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]))
eurosat_html = classify(EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]))
meterml_html = classify(EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]))


def load_eurosat_example():
    """Return (image paths, class text, cached HTML) for the EuroSAT example."""
    return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]), eurosat_html


def load_meterml_example():
    """Return (image paths, class text, cached HTML) for the Meter-ML example."""
    return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]), meterml_html


def load_terramesh_example():
    """Return (image paths, class text, cached HTML) for the TerraMesh example."""
    return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]), terramesh_html


# UI
with gr.Blocks(
    css="""
    .gradio-container #result_box, #result_box.gr-skeleton {min-height:280px !important;}
    """
) as demo:
    gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
    gr.Markdown(
        "Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
        "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifying forests as sea because of the differently scaled values). "
        "We provide three sets of example images with class names and cached outputs. "
        "The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). "
        "The images are classified based on the similarity between the image embeddings and text embeddings. "
        "You can find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969)."
    )

    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images",
            file_count="multiple",
            type="filepath",
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            label="Class names (comma‑separated)",
        )

    run_btn = gr.Button("Classify", variant="primary")

    # Examples
    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")

    out_html = gr.HTML(label="Results", elem_id="result_box", min_height=280)

    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)

    btn_terramesh.click(
        load_terramesh_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_eurosat.click(
        load_eurosat_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_meterml.click(
        load_meterml_example,
        outputs=[img_in, cls_in, out_html],
    )

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)