# Captured from a Hugging Face Space page (Space status at capture time: "Runtime error").
import os
import random
import tempfile
from datetime import datetime

import albumentations as A
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image
# --- Best Practice: choose a non-interactive Matplotlib backend on servers ---
# "Agg" only renders to files/buffers, so no display server is required.
matplotlib.use(backend="Agg")
# --- CONFIGURATION (UPDATED FOR DEPLOYMENT) --- | |
class CFG:
    """Static settings shared by model loading, inference and area estimation."""

    # Prefer the GPU whenever PyTorch reports one; otherwise run on the CPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Relative paths, resolved against the working directory (the root of the
    # Hugging Face Space repository when deployed).
    MODEL_PATH = "best_model_optimized_83.98.pth"  # checkpoint to load
    EXAMPLES_DIR = "examples"  # folder scanned for example images

    # Model / preprocessing settings.
    MODEL_NAME = "CustomDeepLabV3+"
    ENCODER_NAME = "timm-efficientnet-b2"
    NUM_CLASSES = 8
    IMG_SIZE = 256  # inputs are resized to IMG_SIZE x IMG_SIZE

    # Area estimation: every input image is assumed to cover an
    # ORIGINAL_PATCH_DIM x ORIGINAL_PATCH_DIM patch sampled at
    # RESOLUTION_M_PER_PIXEL metres per pixel.
    ORIGINAL_PATCH_DIM = 64
    RESOLUTION_M_PER_PIXEL = 10
    SQ_METERS_PER_HECTARE = 10000
    # Side length in metres, squared, converted to hectares (64 * 10 m -> 40.96 ha).
    TOTAL_PATCH_AREA_HECTARES = (ORIGINAL_PATCH_DIM * RESOLUTION_M_PER_PIXEL) ** 2 / SQ_METERS_PER_HECTARE
# --- DATA & CLASS INFO --- | |
# Class id -> display name and "#RRGGBB" colour; ids follow the order below.
CLASS_INFO = {
    class_id: {"name": class_name, "hex": hex_color}
    for class_id, (class_name, hex_color) in enumerate(
        [
            ("Unclassified", "#969696"),
            ("Water Bodies", "#0000FF"),
            ("Dense Forest", "#006400"),
            ("Built up", "#800080"),
            ("Agriculture land", "#00FF00"),
            ("Barren land", "#FFFF00"),
            ("Fallow land", "#D2B48C"),
            ("Sparse Forest", "#3CB371"),
        ]
    )
}
# --- MODEL DEFINITION (REFORMATTED FOR READABILITY) --- | |
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is average-pooled to a single value ("squeeze"), the per-channel
    vector goes through a bottleneck MLP ending in a sigmoid ("excitation"), and
    the input is rescaled channel-wise by the resulting gates in (0, 1).

    The submodule names ``avg_pool`` / ``fc`` and the layer order inside ``fc``
    determine the state_dict keys, so they must stay in sync with saved checkpoints.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden_units = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden_units, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` (a 4-D batch, channels, height, width tensor) scaled per channel."""
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)
class CustomDeepLabV3Plus(nn.Module):
    """smp DeepLabV3+ with an SELayer inserted between the decoder and the head.

    The smp model's own ``segmentation_head`` is moved onto this module and
    replaced inside ``smp_model`` with ``nn.Identity``, so ``self.smp_model(x)``
    returns the decoder feature map. Those features are re-weighted by
    ``SELayer`` and only then passed through the relocated head.

    NOTE(review): the attribute names ``smp_model``, ``se_layer`` and
    ``segmentation_head`` define the state_dict keys the checkpoint loaded in
    ``load_model`` must match; renaming them breaks loading.
    """
    def __init__(self, encoder_name, in_channels, classes):
        super().__init__()
        # NOTE(review): encoder_weights="imagenet" presumably fetches pretrained
        # encoder weights when the model is built; load_model() overwrites all
        # weights from the checkpoint afterwards, so this likely matters only for training.
        self.smp_model = smp.DeepLabV3Plus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=classes
        )
        # Input channel count of the head's first layer = width of the decoder output,
        # which is what the SE gate operates on.
        decoder_channels = self.smp_model.segmentation_head[0].in_channels
        self.se_layer = SELayer(decoder_channels)
        # Keep the original head for our own forward pass, and neutralise it inside
        # smp_model so that smp_model(x) stops at the decoder features.
        self.segmentation_head = self.smp_model.segmentation_head
        self.smp_model.segmentation_head = nn.Identity()
    def forward(self, x):
        """Return per-class scores for ``x`` (callers take argmax over dim 1 of a batch).

        Assumes ``x`` is (batch, in_channels, H, W) — TODO confirm against smp input checks.
        """
        decoder_features = self.smp_model(x)
        attended_features = self.se_layer(decoder_features)
        output = self.segmentation_head(attended_features)
        return output
# --- MODEL LOADING & TRANSFORMS --- | |
def load_model():
    """Build ``CustomDeepLabV3Plus``, load ``CFG.MODEL_PATH`` into it and return it.

    The returned model is on ``CFG.DEVICE`` and in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    print(f"Loading model from {CFG.MODEL_PATH} on device {CFG.DEVICE}...")
    # Check for the checkpoint *before* building the network: construction
    # (encoder_weights="imagenet" in CustomDeepLabV3Plus) may download pretrained
    # weights and is wasted work if loading is going to fail anyway.
    if not os.path.exists(CFG.MODEL_PATH):
        raise FileNotFoundError(f"CRITICAL: Model file not found at '{CFG.MODEL_PATH}'. Please ensure the model file is in the root directory of your Space.")
    model = CustomDeepLabV3Plus(encoder_name=CFG.ENCODER_NAME, in_channels=3, classes=CFG.NUM_CLASSES)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(CFG.MODEL_PATH, map_location=torch.device(CFG.DEVICE), weights_only=True)
    model.load_state_dict(state_dict)
    model.to(CFG.DEVICE)
    model.eval()
    print("Model loaded successfully!")
    return model
# Load the network once at import time; every request handler reuses it.
model = load_model()

# Preprocessing applied by analyze_one_image: resize to a CFG.IMG_SIZE square,
# normalise with the standard ImageNet mean/std, then convert the HWC numpy
# array to a CHW torch tensor.
transform = A.Compose(
    [
        A.Resize(height=CFG.IMG_SIZE, width=CFG.IMG_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
# --- HELPER & ANALYSIS FUNCTIONS --- | |
def create_color_map():
    """Return a (256, 3) uint8 lookup table mapping class id -> RGB colour.

    Rows for ids present in CLASS_INFO hold the colour parsed from that class's
    "#RRGGBB" string; all other rows stay black (0, 0, 0).
    """
    lut = np.zeros((256, 3), dtype=np.uint8)
    for class_id in CLASS_INFO:
        hex_digits = CLASS_INFO[class_id]["hex"].lstrip("#")
        lut[class_id] = [int(hex_digits[pos:pos + 2], 16) for pos in range(0, 6, 2)]
    return lut
# Colour lookup table, built once at import time and shared by every request.
COLOR_MAP_NUMPY = create_color_map()


def create_colored_mask(mask_np):
    """Colour an integer array of class ids via COLOR_MAP_NUMPY and return a PIL image.

    A 2-D (H, W) mask becomes an (H, W, 3) RGB image.
    """
    rgb_pixels = np.take(COLOR_MAP_NUMPY, mask_np, axis=0)
    return Image.fromarray(rgb_pixels)
def analyze_one_image(image_filepath: str):
    """Segment one image file and estimate the area covered by each class.

    Args:
        image_filepath: Path to an image readable by PIL, or ``None``.

    Returns:
        ``(pred_mask, area_results)``:
        - ``pred_mask`` is a 2-D integer numpy array of class ids at the
          resolution produced by ``transform`` (CFG.IMG_SIZE x CFG.IMG_SIZE).
        - ``area_results`` maps each class name present in the mask to its
          estimated area in hectares.
        Returns ``(None, {})`` when ``image_filepath`` is ``None``.

    NOTE(review): areas are the class's pixel share times
    CFG.TOTAL_PATCH_AREA_HECTARES, i.e. every image is assumed to cover the same
    fixed ground footprint regardless of its actual size/resolution — confirm
    this matches the imagery users upload.
    """
    if image_filepath is None:
        return None, {}

    # The context manager closes the underlying file handle; convert() loads the
    # pixel data before the file is closed.
    with Image.open(image_filepath) as image:
        image_np = np.array(image.convert("RGB"))

    input_tensor = transform(image=image_np)["image"].unsqueeze(0).to(CFG.DEVICE)
    with torch.no_grad():
        logits = model(input_tensor)
    # argmax over the class dimension of the single-item batch. This avoids
    # squeeze(), which would also drop any other size-1 dimension.
    pred_mask = logits.argmax(dim=1)[0].cpu().numpy()

    class_ids, pixel_counts = np.unique(pred_mask, return_counts=True)
    total_pixels = pred_mask.size
    area_results = {}
    # tolist() yields plain Python ints, so the areas below are plain floats.
    for class_id, count in zip(class_ids.tolist(), pixel_counts.tolist()):
        info = CLASS_INFO.get(class_id)
        if info is None:
            continue  # ids without a configured class are not reported
        area_results[info["name"]] = count / total_pixels * CFG.TOTAL_PATCH_AREA_HECTARES
    return pred_mask, area_results
def single_image_analysis(image_filepath: str):
    """Gradio handler for the "Single Image Analysis" tab.

    Returns a 3-tuple:
    - the coloured mask (PIL image);
    - an area table sorted by descending area, with values formatted as 4-decimal strings;
    - a results dict ({"areas", "area_df", "image_path"}) kept in gr.State for the report builder.

    Raises gr.Error when no image has been uploaded.
    """
    if image_filepath is None:
        raise gr.Error("Please upload an image to analyze.")

    mask, areas = analyze_one_image(image_filepath)
    mask_image = create_colored_mask(mask)

    ranked = sorted(areas.items(), key=lambda pair: pair[1], reverse=True)  # largest first
    table = pd.DataFrame(ranked, columns=["Land Cover Class", "Area (Hectares)"])
    table["Area (Hectares)"] = [f"{hectares:.4f}" for hectares in table["Area (Hectares)"]]

    results = {"areas": areas, "area_df": table, "image_path": image_filepath}
    return mask_image, table, results
def _build_change_table(areas_before, areas_after):
    """Return the numeric per-class comparison table for two {class name: hectares} dicts.

    Columns are "Class", "Area 1 (ha)", "Area 2 (ha)", "Change (ha)" and "% Change".
    A class missing from a dict counts as 0 ha. "% Change" is ``inf`` when the
    class is absent from ``areas_before`` (no baseline to divide by).
    """
    class_names = sorted(set(areas_before) | set(areas_after))
    rows = [[name, areas_before.get(name, 0), areas_after.get(name, 0)] for name in class_names]
    df = pd.DataFrame(rows, columns=["Class", "Area 1 (ha)", "Area 2 (ha)"])
    df['Change (ha)'] = df['Area 2 (ha)'] - df['Area 1 (ha)']
    baseline = df['Area 1 (ha)']
    # Vectorised: unlike a row-wise apply, this also works on an empty frame.
    df['% Change'] = (df['Change (ha)'] / baseline * 100).where(baseline > 0, float('inf'))
    return df


def _format_change_table(df):
    """Return a copy of the comparison table with every numeric column formatted as text."""
    shown = df.copy()
    for col in ("Area 1 (ha)", "Area 2 (ha)"):
        shown[col] = shown[col].map('{:.2f}'.format)
    shown["Change (ha)"] = shown["Change (ha)"].map('{:+.2f}'.format)
    shown["% Change"] = shown["% Change"].map(lambda x: "New" if x == float('inf') else f"{x:+.2f}%")
    return shown


def _plot_change_chart(df):
    """Return a grouped bar chart of "Area 1 (ha)" vs "Area 2 (ha)" per class.

    Uses matplotlib's object-oriented ``Figure`` rather than ``pyplot``. Figures
    created through pyplot stay registered (and in memory) until explicitly
    closed, so creating one per request would leak memory in a long-running server.
    """
    from matplotlib.figure import Figure

    plt.style.use('seaborn-v0_8-whitegrid')  # sets global rcParams; read when the figure is built
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    positions = np.arange(len(df))
    bar_width = 0.35
    ax.bar(positions - bar_width / 2, df['Area 1 (ha)'], bar_width, label='Area 1 (Before)', color='cornflowerblue')
    ax.bar(positions + bar_width / 2, df['Area 2 (ha)'], bar_width, label='Area 2 (After)', color='salmon')
    ax.set_xlabel('Land Cover Class', fontweight='bold')
    ax.set_ylabel('Area (Hectares)', fontweight='bold')
    ax.set_title('Land Cover Change Analysis', fontsize=16, fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(df['Class'], rotation=45, ha="right")
    ax.legend()
    fig.tight_layout()
    return fig


def compare_land_cover(filepath1: str, filepath2: str):
    """Gradio handler for the "Change Detection Tool" tab.

    Returns a 5-tuple:
    - the two coloured masks;
    - the formatted comparison table;
    - the bar-chart figure;
    - a results dict ({"df", "path1", "path2", "raw_df"}) kept in gr.State for the report builder.

    Raises gr.Error unless both images were uploaded.
    """
    if filepath1 is None or filepath2 is None:
        raise gr.Error("Please upload both a 'Before' and 'After' image for comparison.")
    # One inference per image: the mask and the areas come from the same pass.
    mask1_np, areas1_dict = analyze_one_image(filepath1)
    mask2_np, areas2_dict = analyze_one_image(filepath2)
    mask1_pil = create_colored_mask(mask1_np)
    mask2_pil = create_colored_mask(mask2_np)

    df = _build_change_table(areas1_dict, areas2_dict)
    df_display = _format_change_table(df)
    fig = _plot_change_chart(df)

    analysis_results = {"df": df_display, "path1": filepath1, "path2": filepath2, "raw_df": df}
    return mask1_pil, mask2_pil, df_display, fig, analysis_results
# --- REPORTING FUNCTIONS --- | |
def _df_to_markdown(df):
    """Render ``df`` as a Markdown table without its index.

    ``DataFrame.to_markdown`` requires the optional ``tabulate`` package and
    raises ImportError without it. In that case, fall back to a fenced
    plain-text table so report generation still succeeds.
    """
    try:
        return df.to_markdown(index=False)
    except ImportError:
        return "```\n" + df.to_string(index=False) + "\n```"


def _build_single_report(results):
    """Return the Markdown report for a single-image results dict."""
    filename = os.path.basename(results['image_path'])
    report = f"# LULC Analysis Report: {filename}\n"
    report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
    report += "## Area Distribution (Hectares)\n"
    report += _df_to_markdown(results['area_df'])
    return report


def _build_change_report(results, top_n=3, min_change_ha=0.01):
    """Return the Markdown report for a change-detection results dict.

    The summary lists up to ``top_n`` classes with the largest absolute change,
    skipping changes of ``min_change_ha`` hectares or less.
    """
    file1 = os.path.basename(results['path1'])
    file2 = os.path.basename(results['path2'])
    df = results['raw_df']

    by_magnitude = df.reindex(df['Change (ha)'].abs().sort_values(ascending=False).index)
    summary_lines = []
    for _, row in by_magnitude.head(top_n).iterrows():
        change = row['Change (ha)']
        if abs(change) > min_change_ha:
            direction = "increased" if change > 0 else "decreased"
            summary_lines.append(f"- **{row['Class']}** has {direction} by **{abs(change):.2f} hectares**.\n")
    if not summary_lines:
        # Avoid an empty "Key Summary" section when nothing changed noticeably.
        summary_lines.append(f"- No class changed by more than {min_change_ha} hectares.\n")

    report = "# LULC Change Detection Report\n"
    report += f"**Comparison:** `{file1}` (Before) vs. `{file2}` (After)\n"
    report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
    report += "## Key Summary of Changes\n"
    report += "".join(summary_lines) + "\n"
    report += "## Detailed Comparison Table\n"
    report += _df_to_markdown(results['df'])
    return report


def generate_report(analysis_results, report_type):
    """Build a Markdown report from stored results and show it in the Report Builder tab.

    Args:
        analysis_results: Results dict stored in gr.State, produced by
            single_image_analysis ("single") or compare_land_cover ("change").
            Falsy if no analysis has been run.
        report_type: "single" or "change".

    Returns:
        A dict of component updates:
        - the report text for ``report_editor``;
        - ``download_btn`` made visible;
        - the tab with id=2 (Report Builder) selected.

    Raises:
        gr.Error: if no analysis has been run yet.
        ValueError: for an unknown ``report_type``.
    """
    if not analysis_results:
        raise gr.Error("Please run an analysis first before generating a report.")
    if report_type == "single":
        report = _build_single_report(analysis_results)
    elif report_type == "change":
        report = _build_change_report(analysis_results)
    else:
        raise ValueError(f"Unknown report_type: {report_type!r}")
    # Switch to the report tab and populate it (components are defined in the UI layout).
    return {
        report_editor: gr.update(value=report),
        download_btn: gr.update(visible=True),
        tabs: gr.update(selected=2),
    }
def save_report_to_file(report_content):
    """Write the report text to a new "LULC_Report.md" file and return its path.

    Each call writes into its own freshly created temporary directory, so
    concurrent sessions never overwrite each other's report while the
    downloaded file name stays "LULC_Report.md". ``None`` is written as an
    empty report.
    """
    report_dir = tempfile.mkdtemp(prefix="lulc_report_")
    filepath = os.path.join(report_dir, "LULC_Report.md")
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(report_content or "")
    return filepath
# --- EXAMPLE FINDER --- | |
def find_examples(examples_dir=None, max_single=10, max_pairs=5):
    """Collect example image paths for the two Gradio tabs.

    Args:
        examples_dir: Folder to scan (non-recursively); defaults to CFG.EXAMPLES_DIR.
        max_single: Maximum number of single-image examples.
        max_pairs: Maximum number of (before, after) example pairs.

    Returns:
        ``(single_examples, change_examples)``:
        - ``single_examples``: up to ``max_single`` image paths, sorted by path.
        - ``change_examples``: up to ``max_pairs`` two-element lists of
          consecutive files (1st+2nd, 3rd+4th, ...) taken from the first
          ``2 * max_pairs`` images.
        Both lists are empty if the folder does not exist.
    """
    if examples_dir is None:
        examples_dir = CFG.EXAMPLES_DIR
    if not os.path.isdir(examples_dir):
        return [], []

    image_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
    files = sorted(
        path
        for path in (os.path.join(examples_dir, name) for name in os.listdir(examples_dir))
        # isfile() skips sub-directories that merely have an image-like name.
        if path.lower().endswith(image_extensions) and os.path.isfile(path)
    )
    single_examples = files[:max_single]
    # i + 1 < len(files) is guaranteed by the range bound, so every slice has two items.
    change_examples = [files[i:i + 2] for i in range(0, min(len(files) - 1, 2 * max_pairs), 2)]
    return single_examples, change_examples
# Discover example images once at start-up.
single_examples, change_examples = find_examples()
# --- GRADIO UI LAYOUT ---
with gr.Blocks(theme=gr.themes.Soft(), title="LULC Analysis Platform") as demo:
    gr.Markdown("# Land Use & Land Cover (LULC) Analysis Platform")
    gr.Markdown("An AI-powered tool to analyze satellite imagery for environmental monitoring and planning.")
    # Hidden per-session state holding the latest analysis results for the report builder.
    single_analysis_results = gr.State()
    change_analysis_results = gr.State()

    with gr.Tabs() as tabs:
        with gr.TabItem("Single Image Analysis", id=0):
            with gr.Row(variant="panel"):
                with gr.Column(scale=1):
                    single_img_input = gr.Image(type="filepath", label="Upload Satellite Image")
                    single_analyze_btn = gr.Button("Analyze Image", variant="primary")
                with gr.Column(scale=1):
                    single_mask_output = gr.Image(type="pil", label="Predicted Mask")
            with gr.Row():
                area_df_output = gr.DataFrame(label="Predicted Area Distribution", wrap=True)
            send_single_report_btn = gr.Button("➡ Create Report from this Analysis")
            # Guarded like the change-detection examples: skip when no images were found.
            if single_examples:
                gr.Examples(examples=single_examples, inputs=single_img_input, label="Click an Example to Start")

        with gr.TabItem("Change Detection Tool", id=1):
            with gr.Row(variant="panel"):
                compare_img1 = gr.Image(type="filepath", label="Image 1 (e.g., Before / 2020)")
                compare_img2 = gr.Image(type="filepath", label="Image 2 (e.g., After / 2024)")
            compare_analyze_btn = gr.Button("Analyze Changes", variant="primary")
            with gr.Row():
                compare_mask1 = gr.Image(type="pil", label="Mask for Image 1")
                compare_mask2 = gr.Image(type="pil", label="Mask for Image 2")
            with gr.Tabs():
                with gr.TabItem("📊 Change Chart"):
                    compare_plot = gr.Plot()
                with gr.TabItem("📑 Comparison Table"):
                    compare_df = gr.DataFrame(interactive=False)
            send_change_report_btn = gr.Button("➡ Create Report from this Analysis")
            if change_examples:
                gr.Examples(examples=change_examples, inputs=[compare_img1, compare_img2], label="Click an Example Pair to Start")

        with gr.TabItem("Report Builder", id=2):
            gr.Markdown("### Create and Download Your Analysis Report")
            gr.Markdown("1. Run an analysis on one of the other tabs.\n"
                        "2. Click the **'➡ Create Report'** button.\n"
                        "3. Your report will appear below. You can edit it before downloading.\n")
            with gr.Column():
                report_editor = gr.Textbox(label="Your Report (Editable)", lines=20, interactive=True)
                download_btn = gr.DownloadButton(label="Download Report (.md)", visible=False)

    # --- BUTTON CLICK EVENTS & DATA FLOW ---
    # Single Image Analysis Flow: disable the button, run the analysis, re-enable it.
    # `.then()` runs even when the previous step fails (e.g. gr.Error for a missing
    # image), so the button is always restored.
    single_analyze_btn.click(
        fn=lambda: gr.update(interactive=False, value="Analyzing..."),
        inputs=None,
        outputs=single_analyze_btn,
    ).then(
        fn=single_image_analysis,
        inputs=single_img_input,
        outputs=[single_mask_output, area_df_output, single_analysis_results],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Analyze Image"),
        inputs=None,
        outputs=single_analyze_btn,
    )
    send_single_report_btn.click(
        fn=lambda res: generate_report(res, "single"),
        inputs=single_analysis_results,
        outputs=[report_editor, download_btn, tabs],
    )

    # Change Detection Flow (same disable -> analyze -> re-enable pattern).
    compare_analyze_btn.click(
        fn=lambda: gr.update(interactive=False, value="Analyzing..."),
        inputs=None,
        outputs=compare_analyze_btn,
    ).then(
        fn=compare_land_cover,
        inputs=[compare_img1, compare_img2],
        outputs=[compare_mask1, compare_mask2, compare_df, compare_plot, change_analysis_results],
    ).then(
        fn=lambda: gr.update(interactive=True, value="Analyze Changes"),
        inputs=None,
        outputs=compare_analyze_btn,
    )
    send_change_report_btn.click(
        fn=lambda res: generate_report(res, "change"),
        inputs=change_analysis_results,
        outputs=[report_editor, download_btn, tabs],
    )

    # Report Download Flow: a DownloadButton downloads whatever value it already
    # holds when clicked. Re-save the report on every editor change (including the
    # one made by generate_report) so the button always points at the current text
    # and a single click downloads it.
    report_editor.change(fn=save_report_to_file, inputs=report_editor, outputs=download_btn)
if __name__ == "__main__":
    # Start the Gradio server. debug=True keeps the main thread blocked and
    # surfaces errors in the console (per Gradio's launch() docs — confirm for your version).
    demo.launch(debug=True)