import base64
import os
import sys
import csv
import spaces
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from io import BytesIO
from pathlib import Path
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
from msclip.inference.utils import build_model
rcParams["font.size"] = 9
rcParams["axes.titlesize"] = 9
IMG_PX = 300
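# Raise the CSV field size limit to the maximum so oversized fields do not raise errors.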
csv.field_size_limit(sys.maxsize)
# Init Llama3-MS-CLIP from Hugging Face
model, preprocess, tokenizer = build_model()
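# Example datasets: bundled GeoTIFF scenes and the candidate class names offered for each dataset.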
EXAMPLES = {
"EuroSAT": {
"images": glob.glob("examples/eurosat/*.tif"),
"classes": [
"Annual crop", "Forest", "Herbaceous vegetation", "Highway", "Industrial",
"Pasture", "Permanent crop", "Residential", "River", "Sea lake"
]
},
"Meter-ML": {
"images": glob.glob("examples/meterml/*.tif"),
"classes": [
"Concentrated animal feeding operations",
"Landfills",
"Coal mines",
"Other features",
"Natural gas processing plants",
"Oil refineries and petroleum terminals",
"Wastewater treatment plants",
]
},
"TerraMesh": {
"images": glob.glob("examples/terramesh/*.tif"),
"classes": [
"Village", "Beach", "River", "Ice", "Fields", "Mountains", "Desert"
]
},
}
pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
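# Assign every class name a fixed pastel color; sorting keeps the mapping stable across images.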
def build_colormap(class_names):
return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}
def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """
    Compress very dark and very bright values and rescale reflectance to roughly 0-1.

    array: numpy array of raw Sentinel-2 reflectance values; quantiles are
        computed over all elements, so the band/pixel layout does not matter.
    returns: array rescaled to approximately the 0-1 range (clipping happens later)
    """
# Get scaling thresholds for smoothing the brightness
limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
limit_high = limit_high.clip(default) # Scale only pixels above default value
limit_low = limit_low.clip(0, 1000) # Scale only pixels below 1000
limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
# Smooth very dark and bright values using linear scaling
array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)
    # Recompute the scaling limits using a tenth of the tolerance for the extremes
limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
limit_high = limit_high.clip(default, 20000) # Scale only pixels above default value
limit_low = limit_low.clip(0, 500) # Scale only pixels below 500
limit_low = np.where(median > default / 2, limit_low, 0) # Make image only darker if it is not dark already
    # Rescale data to approximately 0-1 (conversion to 0-255 uint8 happens in _s2_to_rgb)
array = (array - limit_low) / (limit_high - limit_low)
return array
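# Convert a Sentinel-2 band stack to an 8-bit RGB image using the bands at indices 3, 2, 1
# (red, green, blue in the standard 13-band Sentinel-2 order).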
def _s2_to_rgb(data, smooth_quantiles=True):
    # Select the RGB bands (indices 3, 2, 1) and arrange them channels-last
if data.shape[0] > 13:
# assuming channel last
rgb = data[:, :, [3, 2, 1]]
else:
# assuming channel first
rgb = data[[3, 2, 1]].transpose((1, 2, 0))
if smooth_quantiles:
rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000  # fixed reflectance scaling instead of quantile smoothing
# to uint8
rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
return rgb
def _img_to_b64(path: str | Path) -> str:
    """Read a GeoTIFF, render it as RGB, pad it to a square, and return a base64-encoded PNG."""
with rio.open(path) as src:
data = src.read()
rgb = _s2_to_rgb(data)
img = Image.fromarray(rgb)
side = max(img.size)
# create square canvas, paste centred, then resize
canvas = Image.new("RGB", (side, side), (255, 255, 255))
canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
canvas = canvas.resize((IMG_PX, IMG_PX))
buf = BytesIO()
canvas.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode()
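# Render the top class scores as a small horizontal bar chart and return it as an inline base64 <img> tag.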
def _bar_chart(top_scores, img_name, cmap) -> str:
scores = top_scores.values.tolist()
labels = top_scores.index.tolist()
    # Pad to three rows so every chart has the same layout
    while len(scores) < 3:
scores.append(0)
labels.append("")
fig, ax = plt.subplots(figsize=(3, 1))
y_pos = np.arange(3)
colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
for cls, val in zip(labels, scores)]
ax.barh(y_pos, scores, height=0.7, color=colors)
ax.set_xlim(0, 1)
ax.invert_yaxis()
ax.axis("off")
img_name = os.path.splitext(img_name)[0]
if len(img_name) > 25:
img_name = img_name[:23] + "..."
ax.set_title(img_name)
for i, (cls, val) in enumerate(zip(labels, scores)):
if len(cls) > 25:
cls = cls[:23] + "..."
if val > 0: # skip padded rows
ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
buf = BytesIO()
fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
plt.close(fig)
b64 = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{b64}"/>'
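# Classify the uploaded images against the comma-separated class names and build one HTML result card per image.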
@spaces.GPU
def classify(images, class_text):
class_names = [c.strip() for c in class_text.split(",") if c.strip()]
cards = []
df = run_inference_classification(
model=model,
preprocess=preprocess,
tokenizer=tokenizer,
image_path=images,
class_names=class_names,
verbose=False
)
    for img_path, (_, row) in zip(images, df.iterrows()):
        scores = row[2:].astype(float)  # keep only the per-class score columns
top = scores.sort_values(ascending=False)[:3]
top = top[top > 0.01] # filter low scores
cmap = build_colormap(class_names)
cards.append(f"""