# Source revision: commit ef153ea by blumenstiel ("Init models only once").
import base64
import os
import sys
import csv
import spaces
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from io import BytesIO
from pathlib import Path
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
from msclip.inference.utils import build_model
# Matplotlib defaults for the per-image score charts.
rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9

# Edge length (px) of the square thumbnails and of the rendered score charts.
IMG_PX = 300


def _max_csv_field_size_limit():
    """Raise the csv module's field size limit as high as the platform accepts.

    The limit must fit in a C ``long``. Where that type is 32-bit (e.g. Windows)
    ``sys.maxsize`` overflows and ``csv.field_size_limit`` raises OverflowError,
    so back off by factors of ten until the value is accepted.

    Returns:
        The limit that was finally set.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 10


# NOTE(review): the raised limit is presumably needed by msclip's inference
# code (large CSV fields) — confirm.
_max_csv_field_size_limit()

# Init Llama3-MS-CLIP from Hugging Face once at import time; every request
# below reuses these three objects.
model, preprocess, tokenizer = build_model()
EXAMPLES = {
"EuroSAT": {
"images": glob.glob("examples/eurosat/*.tif"),
"classes": [
"Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
"Pasture", "Permanent crop", "Residential", "River", "Sea lake"
]
},
"Meter-ML": {
"images": glob.glob("examples/meterml/*.tif"),
"classes": [
"Concentrated animal feeding operations",
"Landfills",
"Coal mines",
"Other features",
"Natural gas processing plants",
"Oil refineries and petroleum terminals",
"Wastewater treatment plants",
]
},
"TerraMesh": {
"images": glob.glob("examples/terramesh/*.tif"),
"classes": [
"Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
]
},
}
# Hex strings for the colours of matplotlib's qualitative "Pastel1" colormap.
pastel1_hex = list(map(mpl.colors.to_hex, mpl.colormaps["Pastel1"].colors))
def build_colormap(class_names, *, palette=None):
    """Assign a colour to each distinct class name.

    Names are deduplicated and sorted, then coloured in order, cycling through
    the palette when there are more names than colours. Sorting makes the
    mapping independent of the order the names were entered in.

    Args:
        class_names: iterable of class-name strings.
        palette: optional sequence of colours; defaults to ``pastel1_hex``.

    Returns:
        dict mapping class name -> colour.

    Raises:
        ValueError: if the palette is empty.
    """
    colours = pastel1_hex if palette is None else palette
    if not colours:
        raise ValueError("palette must contain at least one colour")
    # set(): duplicate names would otherwise consume (and skip) palette slots.
    return {c: colours[i % len(colours)] for i, c in enumerate(sorted(set(class_names)))}
def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
"""
array: numpy array with dimensions [C, H, W]
returns 0-1 scaled array
"""
# Get scaling thresholds for smoothing the brightness
limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
limit_high = limit_high.clip(default) # Scale only pixels above default value
limit_low = limit_low.clip(0, 1000) # Scale only pixels below 1000
limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
# Smooth very dark and bright values using linear scaling
array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
# Update scaling params using a 10th of the tolerance for max value
limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
limit_high = limit_high.clip(default, 20000) # Scale only pixels above default value
limit_low = limit_low.clip(0, 500) # Scale only pixels below 500
limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
# Scale data to 0-255
array = (array - limit_low) / (limit_high - limit_low)
return array
def _s2_to_rgb(data, smooth_quantiles=True):
# Select
if data.shape[0] > 13:
# assuming channel last
rgb = data[:, :, [3, 2, 1]]
else:
# assuming channel first
rgb = data[[3, 2, 1]].transpose((1, 2, 0))
if smooth_quantiles:
rgb = _rgb_smooth_quantiles(rgb)
else:
rgb = rgb / 2000
# to uint8
rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
return rgb
def _img_to_b64(path: str | Path) -> str:
    """Render a raster file as a square PNG thumbnail and return it base64-encoded.

    The RGB composite from ``_s2_to_rgb`` is centred on a white square canvas,
    which letterboxes non-square images. The canvas is then resized to
    ``IMG_PX`` x ``IMG_PX``, so the image may be enlarged as well as shrunk.
    """
    with rio.open(path) as dataset:
        bands = dataset.read()
    thumb = Image.fromarray(_s2_to_rgb(bands))

    width, height = thumb.size
    edge = max(width, height)
    square = Image.new("RGB", (edge, edge), (255, 255, 255))
    square.paste(thumb, ((edge - width) // 2, (edge - height) // 2))
    square = square.resize((IMG_PX, IMG_PX))

    png_buffer = BytesIO()
    square.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode()
def _truncate(text, max_len=25):
    """Shorten chart labels: text longer than ``max_len`` keeps ``max_len - 2`` chars plus '...'."""
    return text if len(text) <= max_len else text[:max_len - 2] + "..."


def _bar_chart(top_scores, img_name, cmap) -> str:
    """Render one image's top class scores as an inline PNG ``<img>`` tag.

    Args:
        top_scores: pandas Series mapping class name -> score, highest first.
            The x-axis is fixed to 0-1. At most three rows are drawn; shorter
            series are padded with empty rows so every chart has the same height.
        img_name: file name used as the chart title (extension stripped).
        cmap: dict mapping class name -> bar colour (see ``build_colormap``).

    Returns:
        An HTML ``<img>`` element with the chart embedded as a base64 data URI.
    """
    # A standalone Figure avoids pyplot's global figure registry. That registry
    # is not thread-safe under concurrent requests, and it leaks figures when an
    # error occurs before plt.close().
    from matplotlib.figure import Figure

    n_rows = 3
    scores = top_scores.values.tolist()[:n_rows]
    labels = top_scores.index.tolist()[:n_rows]
    padding = n_rows - len(scores)
    scores += [0] * padding
    labels += [""] * padding

    fig = Figure(figsize=(3, 1))
    ax = fig.subplots()
    # Padded (zero) rows get a fully transparent bar.
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
              for cls, val in zip(labels, scores)]
    ax.barh(np.arange(n_rows), scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.axis("off")

    # parse_math=False: file and class names are user input, and a pair of '$'
    # characters must not switch matplotlib into mathtext rendering.
    ax.set_title(_truncate(os.path.splitext(img_name)[0]), parse_math=False)
    for i, (cls, val) in enumerate(zip(labels, scores)):
        if val > 0:  # skip padded rows
            ax.text(0.02, i + 0.03, f"{_truncate(cls)} ({round(val * 100)}%)",
                    ha="left", va="center", parse_math=False)

    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
@spaces.GPU
def classify(images, class_text):
    """Zero-shot classify the given images and render the results as HTML cards.

    Args:
        images: list of image file paths (from the ``gr.File`` input with
            ``type="filepath"``). May be ``None`` or empty when nothing was uploaded.
        class_text: comma-separated class names. Surrounding whitespace, empty
            entries and repeated names are dropped.

    Returns:
        HTML string: a flex container with one card per image (thumbnail plus
        top-3 score chart). Returns a short notice instead when there are no
        images or no class names.
    """
    # dict.fromkeys keeps the first occurrence of each name in input order.
    class_names = list(dict.fromkeys(
        name.strip() for name in (class_text or "").split(",") if name.strip()
    ))
    if not images or not class_names:
        return ("<p style='text-align:center;'>Please upload at least one image "
                "and enter at least one class name.</p>")

    df = run_inference_classification(
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        image_path=images,
        class_names=class_names,
        verbose=False
    )
    cmap = build_colormap(class_names)  # same for every card; build it once

    cards = []
    # NOTE(review): pairs images with result rows by position. This assumes
    # run_inference_classification returns one row per image, in input order.
    for img_path, (_, row) in zip(images, df.iterrows()):
        # Drop the two leading non-score columns (presumably per-image metadata
        # such as the file name — confirm against msclip's output).
        scores = row.iloc[2:].astype(float)
        top = scores.sort_values(ascending=False).iloc[:3]
        top = top[top > 0.01]  # filter low scores
        cards.append(f"""
<div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
<img src="data:image/png;base64,{_img_to_b64(img_path)}"
style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
{_bar_chart(top, os.path.basename(img_path), cmap)}
</div>""")
    return (
        "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
        + "".join(cards)
        + "</div>"
    )
# Cache examples: classify each bundled example set once at start-up. The
# load_*_example callbacks then return this precomputed HTML.
def _classify_example_set(name):
    spec = EXAMPLES[name]
    return classify(spec["images"], ", ".join(spec["classes"]))


terramesh_html = _classify_example_set("TerraMesh")
eurosat_html = _classify_example_set("EuroSAT")
meterml_html = _classify_example_set("Meter-ML")
def _example_outputs(name, cached_html):
    """Return (image paths, comma-joined class names, cached result HTML) for one example set."""
    spec = EXAMPLES[name]
    return spec["images"], ", ".join(spec["classes"]), cached_html


def load_eurosat_example():
    return _example_outputs("EuroSAT", eurosat_html)


def load_meterml_example():
    return _example_outputs("Meter-ML", meterml_html)


def load_terramesh_example():
    return _example_outputs("TerraMesh", terramesh_html)
# UI: upload + class-name inputs, example loaders, and the HTML result area.
with gr.Blocks(
    css="""
    .gradio-container
    #result_box,
    #result_box.gr-skeleton {min-height:280px !important;}
    """) as demo:
    gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
    gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
                "You can also use S-2 L1C files with 13 bands, but the model might not work as well (e.g., misclassifying forests as sea because of the differently scaled values). "
                "We provide three sets of example images with class names and cached outputs. "
                "The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (we downloaded S-2 L2A images for the same locations). "
                "The images are classified based on the similarity between the image embeddings and text embeddings. "
                "You can find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images", file_count="multiple", type="filepath"
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            label="Class names (comma‑separated)",
        )
    run_btn = gr.Button("Classify", variant="primary")
    # Examples: each button fills the inputs and shows the cached results.
    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")
    out_html = gr.HTML(label="Results",
                       elem_id="result_box",
                       min_height=280)
    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)
    btn_terramesh.click(
        load_terramesh_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_eurosat.click(
        load_eurosat_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_meterml.click(
        load_meterml_example,
        outputs=[img_in, cls_in, out_html],
    )
if __name__ == "__main__":
    # Launch with a share link requested and server-side rendering disabled.
    demo.launch(ssr_mode=False, share=True)