import base64
import os
import sys
import csv
import spaces
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from io import BytesIO
from pathlib import Path
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
from msclip.inference.utils import build_model

rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9

IMG_PX = 300
csv.field_size_limit(sys.maxsize)

# Init Llama3-MS-CLIP from Hugging Face
model, preprocess, tokenizer = build_model()
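
# Example image sets (Sentinel-2 GeoTIFFs bundled with the Space) and the
# class names used as text prompts for each dataset.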
EXAMPLES = {
    "EuroSAT": {
        "images": glob.glob("examples/eurosat/*.tif"),
        "classes": [
            "Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
            "Pasture", "Permanent crop", "Residential", "River", "Sea lake"
        ]
    },
    "Meter-ML": {
        "images": glob.glob("examples/meterml/*.tif"),
        "classes": [
            "Concentrated animal feeding operations",
            "Landfills",
            "Coal mines",
            "Other features",
            "Natural gas processing plants",
            "Oil refineries and petroleum terminals",
            "Wastewater treatment plants",
        ]
    },
    "TerraMesh": {
        "images": glob.glob("examples/terramesh/*.tif"),
        "classes": [
            "Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
        ]
    },
}

pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
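
# Assign each class a fixed pastel color; sorting the class names makes the
# mapping deterministic, so a class keeps its color across runs.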
def build_colormap(class_names):
    return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}


def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """
    Compress very dark and very bright pixels using quantile-based thresholds.

    array: numpy array with dimensions [H, W, C]
    returns array scaled to roughly 0-1
    """
    # Get scaling thresholds for smoothing the brightness
    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    limit_high = limit_high.clip(default)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 1000)  # Scale only pixels below 1000
    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
    # Smooth very dark and bright values using linear scaling
    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
    # Update scaling params using a tenth of the tolerance for the max value
    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    limit_high = limit_high.clip(default, 20000)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 500)  # Scale only pixels below 500
    limit_low = np.where(median > default / 2, limit_low, 0)  # Only darken the image if it is not dark already
    # Scale data to 0-1 (the caller converts to uint8)
    array = (array - limit_low) / (limit_high - limit_low)
    return array
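

# Sentinel-2 bands are stored in their natural order (B01, B02, ...), so the
# channel indices [3, 2, 1] below select B04 (red), B03 (green), and B02 (blue)
# for a true-color composite.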
def _s2_to_rgb(data, smooth_quantiles=True):
    # Select the RGB bands
    if data.shape[0] > 13:
        # more than 13 entries along the first axis cannot be bands -> assume channel-last
        rgb = data[:, :, [3, 2, 1]]
    else:
        # assuming channel-first, move channels last
        rgb = data[[3, 2, 1]].transpose((1, 2, 0))
    if smooth_quantiles:
        rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000
    # Convert to uint8
    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
    return rgb


def _img_to_b64(path: str | Path) -> str:
    """Read a Sentinel-2 file, render it as a square RGB thumbnail, and encode it as base64 PNG."""
    with rio.open(path) as src:
        data = src.read()
    rgb = _s2_to_rgb(data)
    img = Image.fromarray(rgb)
    side = max(img.size)
    # create square canvas, paste centred, then resize
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    canvas = canvas.resize((IMG_PX, IMG_PX))
    buf = BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()
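

# Render the top-3 class scores as a small horizontal bar chart and return it
# as an inline base64 <img> tag that can be embedded in the results HTML.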
def _bar_chart(top_scores, img_name, cmap) -> str:
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # pad to three rows so every chart has the same height
    while len(scores) < 3:
        scores.append(0)
        labels.append("")
    fig, ax = plt.subplots(figsize=(3, 1))
    y_pos = np.arange(3)
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
              for cls, val in zip(labels, scores)]
    ax.barh(y_pos, scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.axis("off")
    # truncate long names so the title fits above the chart
    img_name = os.path.splitext(img_name)[0]
    if len(img_name) > 25:
        img_name = img_name[:23] + "..."
    ax.set_title(img_name)
    for i, (cls, val) in enumerate(zip(labels, scores)):
        if len(cls) > 25:
            cls = cls[:23] + "..."
        if val > 0:  # skip padded rows
            ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
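

# Run zero-shot classification on the uploaded images and build one HTML card
# per image (RGB preview plus a bar chart of the top classes).
# run_inference_classification is assumed to return a DataFrame with one row
# per image, with the per-class scores following the metadata columns.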
def classify(images, class_text):
    class_names = [c.strip() for c in class_text.split(",") if c.strip()]
    cmap = build_colormap(class_names)
    cards = []
    df = run_inference_classification(
        model=model,
        preprocess=preprocess,
        tokenizer=tokenizer,
        image_path=images,
        class_names=class_names,
        verbose=False
    )
    for img_path, (_, row) in zip(images, df.iterrows()):
        scores = row[2:].astype(float)  # keep only the per-class score columns
        top = scores.sort_values(ascending=False)[:3]
        top = top[top > 0.01]  # filter low scores
        cards.append(f"""
        <div style="width:{IMG_PX}px;margin:18px auto;text-align:left;">
          <img src="data:image/png;base64,{_img_to_b64(img_path)}"
               style="width:{IMG_PX}px;height:{IMG_PX}px;object-fit:cover;
                      border-radius:8px;box-shadow:0 2px 6px rgba(0,0,0,.15);display:block;margin:auto;">
          {_bar_chart(top, os.path.basename(img_path), cmap)}
        </div>""")
    return (
        "<div style='display:flex;flex-wrap:wrap;justify-content:center;'>"
        + "".join(cards)
        + "</div>"
    )


# Cache the example outputs at startup so the example buttons return
# pre-rendered HTML instead of re-running inference.
terramesh_html = classify(EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]))
eurosat_html = classify(EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]))
meterml_html = classify(EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]))


def load_eurosat_example():
    return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]), eurosat_html


def load_meterml_example():
    return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]), meterml_html


def load_terramesh_example():
    return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]), terramesh_html


# UI
with gr.Blocks(
    css="""
    .gradio-container #result_box,
    #result_box.gr-skeleton {min-height:280px !important;}
    """) as demo:
gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP") | |
gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. " | |
"You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifing forests as sea because of the differrently scaled values). " | |
"We provide three sets of example images with class names and cached outputs. " | |
"The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). " | |
"The images are classified based on the similarity between the images embeddings and text embeddings. " | |
"You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ") | |
    with gr.Row():
        img_in = gr.File(
            label="Upload S-2 images", file_count="multiple", type="filepath"
        )
        cls_in = gr.Textbox(
            value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
            label="Class names (comma‑separated)",
        )
    run_btn = gr.Button("Classify", variant="primary")

    # Examples
    gr.Markdown("#### Load examples")
    with gr.Row():
        btn_terramesh = gr.Button("TerraMesh")
        btn_eurosat = gr.Button("EuroSAT")
        btn_meterml = gr.Button("Meter-ML")

    out_html = gr.HTML(label="Results",
                       elem_id="result_box",
                       min_height=280)
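
    # Wire up the events: "Classify" runs inference on the uploaded files, while
    # the example buttons fill the inputs and display the cached results.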
    run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html)
    btn_terramesh.click(
        load_terramesh_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_eurosat.click(
        load_eurosat_example,
        outputs=[img_in, cls_in, out_html],
    )
    btn_meterml.click(
        load_meterml_example,
        outputs=[img_in, cls_in, out_html],
    )


if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)