Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,52 +1,63 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
-
from PIL import Image
|
5 |
import numpy as np
|
6 |
-
import
|
7 |
-
import
|
|
|
8 |
import segmentation_models_pytorch as smp
|
|
|
|
|
9 |
import os
|
10 |
-
import
|
|
|
11 |
|
12 |
-
# ---
|
|
|
13 |
|
|
|
14 |
class CFG:
|
15 |
-
"""
|
16 |
-
Configuration class adapted from the provided training script.
|
17 |
-
"""
|
18 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
MODEL_NAME = "CustomDeepLabV3+"
|
20 |
ENCODER_NAME = "timm-efficientnet-b2"
|
21 |
-
IN_CHANNELS = 3
|
22 |
NUM_CLASSES = 8
|
23 |
IMG_SIZE = 256
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
27 |
|
28 |
-
#
|
29 |
CLASS_INFO = {
|
30 |
-
0: {"name": "Unclassified", "
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
4: {"name": "Agriculture land", "color": (0, 255, 0)}, # Bright Green
|
35 |
-
5: {"name": "Barren land", "color": (255, 255, 0)}, # Yellow
|
36 |
-
6: {"name": "Fallow land", "color": (210, 180, 140)}, # Tan
|
37 |
-
7: {"name": "Sparse Forest", "color": (60, 179, 113)}, # Medium Sea Green
|
38 |
}
|
39 |
|
40 |
-
|
41 |
-
SQ_METERS_PER_HECTARE = 10000
|
42 |
-
|
43 |
-
# --- 2. Model Definitions (from provided training script) ---
|
44 |
-
|
45 |
class SELayer(nn.Module):
|
46 |
def __init__(self, channel, reduction=16):
|
47 |
super(SELayer, self).__init__()
|
48 |
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
49 |
-
self.fc = nn.Sequential(
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def forward(self, x):
|
51 |
b, c, _, _ = x.size()
|
52 |
y = self.avg_pool(x).view(b, c)
|
@@ -56,180 +67,276 @@ class SELayer(nn.Module):
|
|
56 |
class CustomDeepLabV3Plus(nn.Module):
|
57 |
def __init__(self, encoder_name, in_channels, classes):
|
58 |
super().__init__()
|
59 |
-
self.smp_model = smp.DeepLabV3Plus(
|
|
|
|
|
|
|
|
|
|
|
60 |
decoder_channels = self.smp_model.segmentation_head[0].in_channels
|
61 |
self.se_layer = SELayer(decoder_channels)
|
62 |
self.segmentation_head = self.smp_model.segmentation_head
|
63 |
self.smp_model.segmentation_head = nn.Identity()
|
|
|
64 |
def forward(self, x):
|
65 |
decoder_features = self.smp_model(x)
|
66 |
attended_features = self.se_layer(decoder_features)
|
67 |
output = self.segmentation_head(attended_features)
|
68 |
return output
|
69 |
|
70 |
-
# ---
|
71 |
-
|
72 |
def load_model():
|
73 |
-
""
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
# 1. Check if the model file exists.
|
79 |
-
if not os.path.exists(model_path):
|
80 |
-
raise FileNotFoundError(
|
81 |
-
f"Model file not found at '{model_path}'. "
|
82 |
-
"Please ensure the model is in the correct directory and the path is correct."
|
83 |
-
)
|
84 |
-
|
85 |
-
# 2. Check if the file is a Git LFS pointer instead of the actual model.
|
86 |
-
# Real model files are many megabytes; pointer files are typically < 1KB.
|
87 |
-
if os.path.getsize(model_path) < 1024 * 1024: # 1 MB threshold
|
88 |
-
raise ValueError(
|
89 |
-
f"The file at '{model_path}' is too small to be a valid model. "
|
90 |
-
"It is likely a Git LFS pointer file. Please ensure you have the full, correct model file. "
|
91 |
-
"If using Git, you may need to run 'git lfs pull' to download the actual file."
|
92 |
-
)
|
93 |
-
|
94 |
-
# 3. If checks pass, create the model structure.
|
95 |
-
model = CustomDeepLabV3Plus(
|
96 |
-
encoder_name=CFG.ENCODER_NAME,
|
97 |
-
in_channels=CFG.IN_CHANNELS,
|
98 |
-
classes=CFG.NUM_CLASSES
|
99 |
-
)
|
100 |
|
101 |
-
#
|
102 |
-
|
103 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device(CFG.DEVICE), weights_only=False))
|
104 |
-
except _pickle.UnpicklingError:
|
105 |
-
# This error is a strong indicator of an LFS pointer or corrupted file.
|
106 |
-
raise IOError(
|
107 |
-
f"Failed to load '{model_path}'. The file is not a valid PyTorch model. "
|
108 |
-
"This confirms the file is likely a Git LFS text pointer. Please replace it with the actual model file."
|
109 |
-
)
|
110 |
-
except Exception as e:
|
111 |
-
# Catch other potential loading errors
|
112 |
-
raise IOError(f"An unexpected error occurred while loading the model from '{model_path}'. Original error: {e}")
|
113 |
-
|
114 |
model.to(CFG.DEVICE)
|
115 |
model.eval()
|
116 |
-
print("Model loaded successfully
|
117 |
return model
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
"""Creates a colored RGB image from a class index mask."""
|
130 |
-
rgb_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
|
131 |
for class_id, info in CLASS_INFO.items():
|
132 |
-
|
133 |
-
return
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
font = ImageFont.load_default()
|
149 |
-
|
150 |
-
for i, (class_id, info) in enumerate(CLASS_INFO.items()):
|
151 |
-
y_start = i * item_height
|
152 |
-
draw.rectangle([5, y_start + 5, 25, y_start + 20], fill=info["color"])
|
153 |
-
draw.text((35, y_start + 5), f'{info["name"]}', fill='black', font=font)
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
-
def
|
164 |
-
|
165 |
-
|
166 |
-
"""
|
167 |
-
if input_image is None:
|
168 |
-
raise gr.Error("Please upload an image.")
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
class_name = CLASS_INFO[class_id]["name"]
|
189 |
-
area_hectares = (count * PIXEL_AREA_SQ_METERS) / SQ_METERS_PER_HECTARE
|
190 |
-
percentage = (count / total_pixels) * 100
|
191 |
-
area_results[class_name] = f"{area_hectares:.4f} Hectares ({percentage:.2f}%)"
|
192 |
-
|
193 |
-
return colored_mask, area_results, legend_image
|
194 |
-
|
195 |
-
# --- 5. Gradio Interface ---
|
196 |
-
|
197 |
-
# Load example images from your file system.
|
198 |
-
example_files = ["comparison_1.jpg", "comparison_4.jpg", "comparison_5.jpg", "comparison_8.jpg"]
|
199 |
-
example_images = [file for file in example_files if os.path.exists(file)]
|
200 |
-
|
201 |
-
app_description = """
|
202 |
-
### Sen-2 LULC: Interactive Land Cover Analysis
|
203 |
-
This tool performs semantic segmentation on satellite imagery to classify different types of land use and land cover (LULC).
|
204 |
-
It is based on the **Sen-2 LULC dataset** and a **Custom DeepLabV3+** model as described in the research paper "Sen-2 LULC: Land Use Land Cover Dataset for Deep Learning Approaches".
|
205 |
-
|
206 |
-
**How it works:**
|
207 |
-
1. **Upload Image:** Upload a satellite image patch. You can use the examples below.
|
208 |
-
2. **Model Prediction:** The deep learning model classifies each pixel into one of the 8 LULC classes.
|
209 |
-
3. **Analysis:** The application generates:
|
210 |
-
* A color-coded **Segmentation Mask** for visual analysis.
|
211 |
-
* A **Quantitative Report** calculating the area of each land class in hectares. This is possible because the source Sentinel-2 imagery has a **10m resolution**, meaning each pixel represents a 100m² area on the ground.
|
212 |
-
"""
|
213 |
-
|
214 |
-
iface = gr.Interface(
|
215 |
-
fn=predict_land_cover,
|
216 |
-
inputs=gr.Image(type="pil", label="Upload Satellite Image Patch"),
|
217 |
-
outputs=[
|
218 |
-
gr.Image(type="pil", label="Predicted Segmentation Mask"),
|
219 |
-
gr.Label(label="Land Cover Area Analysis (Hectares)"),
|
220 |
-
gr.Image(type="pil", label="Legend")
|
221 |
-
],
|
222 |
-
title="Land Use & Land Cover Segmentation and Area Analysis",
|
223 |
-
description=app_description,
|
224 |
-
examples=example_images,
|
225 |
-
cache_examples=False,
|
226 |
-
allow_flagging="never"
|
227 |
-
)
|
228 |
|
229 |
if __name__ == "__main__":
|
230 |
-
|
231 |
-
debug=True,
|
232 |
-
server_name="0.0.0.0", # Allow external connections
|
233 |
-
server_port=7860,
|
234 |
-
share=True
|
235 |
-
)
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
+
from PIL import Image
|
5 |
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import matplotlib
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
import segmentation_models_pytorch as smp
|
10 |
+
import albumentations as A
|
11 |
+
from albumentations.pytorch import ToTensorV2
|
12 |
import os
|
13 |
+
import random
|
14 |
+
from datetime importdatetime
|
15 |
|
16 |
+
# --- Best Practice: Set Matplotlib backend for server environments ---
|
17 |
+
matplotlib.use('Agg')
|
18 |
|
19 |
+
# --- CONFIGURATION (UPDATED FOR DEPLOYMENT) ---
|
20 |
class CFG:
|
|
|
|
|
|
|
21 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
22 |
+
|
23 |
+
# CRITICAL: Use relative paths for deployment.
|
24 |
+
# Place your model file in the root of your Hugging Face Space repository.
|
25 |
+
MODEL_PATH = "best_model_optimized_83.98.pth"
|
26 |
+
|
27 |
+
# The app will scan this local folder for example images.
|
28 |
+
EXAMPLES_DIR = "examples"
|
29 |
+
|
30 |
MODEL_NAME = "CustomDeepLabV3+"
|
31 |
ENCODER_NAME = "timm-efficientnet-b2"
|
|
|
32 |
NUM_CLASSES = 8
|
33 |
IMG_SIZE = 256
|
34 |
+
|
35 |
+
# Constants for area calculation
|
36 |
+
ORIGINAL_PATCH_DIM = 64
|
37 |
+
RESOLUTION_M_PER_PIXEL = 10
|
38 |
+
SQ_METERS_PER_HECTARE = 10000
|
39 |
+
TOTAL_PATCH_AREA_HECTARES = (ORIGINAL_PATCH_DIM**2 * RESOLUTION_M_PER_PIXEL**2) / SQ_METERS_PER_HECTARE
|
40 |
|
41 |
+
# --- DATA & CLASS INFO ---
|
42 |
CLASS_INFO = {
|
43 |
+
0: {"name": "Unclassified", "hex": "#969696"}, 1: {"name": "Water Bodies", "hex": "#0000FF"},
|
44 |
+
2: {"name": "Dense Forest", "hex": "#006400"}, 3: {"name": "Built up", "hex": "#800080"},
|
45 |
+
4: {"name": "Agriculture land", "hex": "#00FF00"}, 5: {"name": "Barren land", "hex": "#FFFF00"},
|
46 |
+
6: {"name": "Fallow land", "hex": "#D2B48C"}, 7: {"name": "Sparse Forest", "hex": "#3CB371"},
|
|
|
|
|
|
|
|
|
47 |
}
|
48 |
|
49 |
+
# --- MODEL DEFINITION (REFORMATTED FOR READABILITY) ---
|
|
|
|
|
|
|
|
|
50 |
class SELayer(nn.Module):
|
51 |
def __init__(self, channel, reduction=16):
|
52 |
super(SELayer, self).__init__()
|
53 |
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
54 |
+
self.fc = nn.Sequential(
|
55 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
56 |
+
nn.ReLU(inplace=True),
|
57 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
58 |
+
nn.Sigmoid()
|
59 |
+
)
|
60 |
+
|
61 |
def forward(self, x):
|
62 |
b, c, _, _ = x.size()
|
63 |
y = self.avg_pool(x).view(b, c)
|
|
|
67 |
class CustomDeepLabV3Plus(nn.Module):
|
68 |
def __init__(self, encoder_name, in_channels, classes):
|
69 |
super().__init__()
|
70 |
+
self.smp_model = smp.DeepLabV3Plus(
|
71 |
+
encoder_name=encoder_name,
|
72 |
+
encoder_weights="imagenet",
|
73 |
+
in_channels=in_channels,
|
74 |
+
classes=classes
|
75 |
+
)
|
76 |
decoder_channels = self.smp_model.segmentation_head[0].in_channels
|
77 |
self.se_layer = SELayer(decoder_channels)
|
78 |
self.segmentation_head = self.smp_model.segmentation_head
|
79 |
self.smp_model.segmentation_head = nn.Identity()
|
80 |
+
|
81 |
def forward(self, x):
|
82 |
decoder_features = self.smp_model(x)
|
83 |
attended_features = self.se_layer(decoder_features)
|
84 |
output = self.segmentation_head(attended_features)
|
85 |
return output
|
86 |
|
87 |
+
# --- MODEL LOADING & TRANSFORMS ---
|
|
|
88 |
def load_model():
|
89 |
+
print(f"Loading model from {CFG.MODEL_PATH} on device {CFG.DEVICE}...")
|
90 |
+
model = CustomDeepLabV3Plus(encoder_name=CFG.ENCODER_NAME, in_channels=3, classes=CFG.NUM_CLASSES)
|
91 |
+
if not os.path.exists(CFG.MODEL_PATH):
|
92 |
+
raise FileNotFoundError(f"CRITICAL: Model file not found at '{CFG.MODEL_PATH}'. Please ensure the model file is in the root directory of your Space.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
+
# Using weights_only=True is safer
|
95 |
+
model.load_state_dict(torch.load(CFG.MODEL_PATH, map_location=torch.device(CFG.DEVICE), weights_only=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
model.to(CFG.DEVICE)
|
97 |
model.eval()
|
98 |
+
print("Model loaded successfully!")
|
99 |
return model
|
100 |
|
101 |
+
model = load_model()
|
102 |
+
transform = A.Compose([
|
103 |
+
A.Resize(height=CFG.IMG_SIZE, width=CFG.IMG_SIZE),
|
104 |
+
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
105 |
+
ToTensorV2()
|
106 |
+
])
|
107 |
+
|
108 |
+
# --- HELPER & ANALYSIS FUNCTIONS ---
|
109 |
+
def create_color_map():
|
110 |
+
color_map = np.zeros((256, 3), dtype=np.uint8)
|
|
|
|
|
111 |
for class_id, info in CLASS_INFO.items():
|
112 |
+
color_map[class_id] = tuple(int(info['hex'].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
|
113 |
+
return color_map
|
114 |
+
|
115 |
+
COLOR_MAP_NUMPY = create_color_map()
|
116 |
+
def create_colored_mask(mask_np):
|
117 |
+
return Image.fromarray(COLOR_MAP_NUMPY[mask_np])
|
118 |
+
|
119 |
+
def analyze_one_image(image_filepath: str):
|
120 |
+
if image_filepath is None: return None, {}
|
121 |
+
image = Image.open(image_filepath)
|
122 |
+
image_np = np.array(image.convert("RGB"))
|
123 |
+
transformed = transform(image=image_np)
|
124 |
+
input_tensor = transformed['image'].unsqueeze(0).to(CFG.DEVICE)
|
125 |
+
|
126 |
+
with torch.no_grad():
|
127 |
+
prediction = model(input_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
+
pred_mask = torch.argmax(prediction.squeeze(), dim=0).cpu().numpy()
|
130 |
+
|
131 |
+
area_results = {}
|
132 |
+
class_indices, pixel_counts = np.unique(pred_mask, return_counts=True)
|
133 |
+
total_pixels_in_mask = pred_mask.size
|
134 |
+
|
135 |
+
for class_id, count in zip(class_indices, pixel_counts):
|
136 |
+
if class_id in CLASS_INFO:
|
137 |
+
pixel_proportion = count / total_pixels_in_mask
|
138 |
+
area_hectares = pixel_proportion * CFG.TOTAL_PATCH_AREA_HECTARES
|
139 |
+
area_results[CLASS_INFO[class_id]["name"]] = area_hectares
|
140 |
+
|
141 |
+
return pred_mask, area_results
|
142 |
|
143 |
+
def single_image_analysis(image_filepath: str):
|
144 |
+
if image_filepath is None: raise gr.Error("Please upload an image to analyze.")
|
145 |
+
|
146 |
+
pred_mask_np, areas_dict = analyze_one_image(image_filepath)
|
147 |
+
pred_mask_pil = create_colored_mask(pred_mask_np)
|
148 |
+
|
149 |
+
area_data = sorted(areas_dict.items(), key=lambda item: item[1], reverse=True)
|
150 |
+
area_df = pd.DataFrame(area_data, columns=["Land Cover Class", "Area (Hectares)"])
|
151 |
+
area_df["Area (Hectares)"] = area_df["Area (Hectares)"].map('{:.4f}'.format)
|
152 |
+
|
153 |
+
analysis_results = {"areas": areas_dict, "area_df": area_df, "image_path": image_filepath}
|
154 |
+
|
155 |
+
return pred_mask_pil, area_df, analysis_results
|
156 |
|
157 |
+
def compare_land_cover(filepath1: str, filepath2: str):
|
158 |
+
if filepath1 is None or filepath2 is None:
|
159 |
+
raise gr.Error("Please upload both a 'Before' and 'After' image for comparison.")
|
|
|
|
|
|
|
160 |
|
161 |
+
_, areas1_dict = analyze_one_image(filepath1)
|
162 |
+
_, areas2_dict = analyze_one_image(filepath2)
|
163 |
+
|
164 |
+
mask1_pil = create_colored_mask(analyze_one_image(filepath1)[0])
|
165 |
+
mask2_pil = create_colored_mask(analyze_one_image(filepath2)[0])
|
166 |
+
|
167 |
+
all_class_names = sorted(list(set(areas1_dict.keys()) | set(areas2_dict.keys())))
|
168 |
+
data_for_df = [[name, areas1_dict.get(name, 0), areas2_dict.get(name, 0)] for name in all_class_names]
|
169 |
+
|
170 |
+
df = pd.DataFrame(data_for_df, columns=["Class", "Area 1 (ha)", "Area 2 (ha)"])
|
171 |
+
df['Change (ha)'] = df['Area 2 (ha)'] - df['Area 1 (ha)']
|
172 |
+
df['% Change'] = df.apply(lambda row: (row['Change (ha)'] / row['Area 1 (ha)'] * 100) if row['Area 1 (ha)'] > 0 else float('inf'), axis=1)
|
173 |
+
|
174 |
+
df_display = df.copy()
|
175 |
+
for col in ["Area 1 (ha)", "Area 2 (ha)"]: df_display[col] = df_display[col].map('{:.2f}'.format)
|
176 |
+
df_display["Change (ha)"] = df_display["Change (ha)"].map('{:+.2f}'.format)
|
177 |
+
df_display["% Change"] = df_display["% Change"].apply(lambda x: f"{x:+.2f}%" if x != float('inf') else "New")
|
178 |
+
|
179 |
+
plt.style.use('seaborn-v0_8-whitegrid')
|
180 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
181 |
+
index = np.arange(len(df))
|
182 |
+
bar_width = 0.35
|
183 |
+
ax.bar(index - bar_width/2, df['Area 1 (ha)'], bar_width, label='Area 1 (Before)', color='cornflowerblue')
|
184 |
+
ax.bar(index + bar_width/2, df['Area 2 (ha)'], bar_width, label='Area 2 (After)', color='salmon')
|
185 |
+
ax.set_xlabel('Land Cover Class', fontweight='bold')
|
186 |
+
ax.set_ylabel('Area (Hectares)', fontweight='bold')
|
187 |
+
ax.set_title('Land Cover Change Analysis', fontsize=16, fontweight='bold')
|
188 |
+
ax.set_xticks(index)
|
189 |
+
ax.set_xticklabels(df['Class'], rotation=45, ha="right")
|
190 |
+
ax.legend()
|
191 |
+
fig.tight_layout()
|
192 |
+
|
193 |
+
analysis_results = {"df": df_display, "path1": filepath1, "path2": filepath2, "raw_df": df}
|
194 |
+
|
195 |
+
return mask1_pil, mask2_pil, df_display, fig, analysis_results
|
196 |
|
197 |
+
# --- REPORTING FUNCTIONS ---
|
198 |
+
def generate_report(analysis_results, report_type):
|
199 |
+
if not analysis_results:
|
200 |
+
raise gr.Error("Please run an analysis first before generating a report.")
|
201 |
+
|
202 |
+
if report_type == "single":
|
203 |
+
filename = os.path.basename(analysis_results['image_path'])
|
204 |
+
report = f"# LULC Analysis Report: {filename}\n"
|
205 |
+
report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
206 |
+
report += "## Area Distribution (Hectares)\n"
|
207 |
+
report += analysis_results['area_df'].to_markdown(index=False)
|
208 |
+
|
209 |
+
elif report_type == "change":
|
210 |
+
file1 = os.path.basename(analysis_results['path1'])
|
211 |
+
file2 = os.path.basename(analysis_results['path2'])
|
212 |
+
df = analysis_results['raw_df']
|
213 |
+
summary = ""
|
214 |
+
df_sorted = df.reindex(df['Change (ha)'].abs().sort_values(ascending=False).index)
|
215 |
+
for _, row in df_sorted.head(3).iterrows():
|
216 |
+
if abs(row['Change (ha)']) > 0.01:
|
217 |
+
direction = "increased" if row['Change (ha)'] > 0 else "decreased"
|
218 |
+
summary += f"- **{row['Class']}** has {direction} by **{abs(row['Change (ha)']):.2f} hectares**.\n"
|
219 |
+
|
220 |
+
report = f"# LULC Change Detection Report\n"
|
221 |
+
report += f"**Comparison:** `{file1}` (Before) vs. `{file2}` (After)\n"
|
222 |
+
report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
223 |
+
report += "## Key Summary of Changes\n"
|
224 |
+
report += summary + "\n"
|
225 |
+
report += "## Detailed Comparison Table\n"
|
226 |
+
report += analysis_results['df'].to_markdown(index=False)
|
227 |
+
|
228 |
+
# Switch to the report tab and populate it
|
229 |
+
return {
|
230 |
+
report_editor: gr.update(value=report),
|
231 |
+
download_btn: gr.update(visible=True),
|
232 |
+
tabs: gr.update(selected=2)
|
233 |
+
}
|
234 |
+
|
235 |
+
def save_report_to_file(report_content):
|
236 |
+
filepath = "LULC_Report.md"
|
237 |
+
with open(filepath, "w", encoding="utf-8") as f:
|
238 |
+
f.write(report_content)
|
239 |
+
return filepath
|
240 |
+
|
241 |
+
# --- EXAMPLE FINDER ---
|
242 |
+
def find_examples():
|
243 |
+
single_examples = []
|
244 |
+
change_examples = []
|
245 |
+
if os.path.isdir(CFG.EXAMPLES_DIR):
|
246 |
+
files = sorted([os.path.join(CFG.EXAMPLES_DIR, f) for f in os.listdir(CFG.EXAMPLES_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif'))])
|
247 |
+
single_examples = files[:10] # Take up to 10 for single analysis
|
248 |
+
# Create pairs for change detection
|
249 |
+
if len(files) >= 2:
|
250 |
+
for i in range(0, min(len(files) - 1, 10), 2): # Take up to 5 pairs
|
251 |
+
change_examples.append([files[i], files[i+1]])
|
252 |
+
return single_examples, change_examples
|
253 |
+
|
254 |
+
single_examples, change_examples = find_examples()
|
255 |
+
|
256 |
+
# --- GRADIO UI LAYOUT ---
|
257 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="LULC Analysis Platform") as demo:
|
258 |
+
gr.Markdown("# Land Use & Land Cover (LULC) Analysis Platform")
|
259 |
+
gr.Markdown("An AI-powered tool to analyze satellite imagery for environmental monitoring and planning.")
|
260 |
+
|
261 |
+
# Hidden state objects to store analysis results robustly
|
262 |
+
single_analysis_results = gr.State()
|
263 |
+
change_analysis_results = gr.State()
|
264 |
+
|
265 |
+
with gr.Tabs() as tabs:
|
266 |
+
with gr.TabItem("Single Image Analysis", id=0):
|
267 |
+
with gr.Row(variant="panel"):
|
268 |
+
with gr.Column(scale=1):
|
269 |
+
single_img_input = gr.Image(type="filepath", label="Upload Satellite Image")
|
270 |
+
single_analyze_btn = gr.Button("Analyze Image", variant="primary")
|
271 |
+
with gr.Column(scale=1):
|
272 |
+
single_mask_output = gr.Image(type="pil", label="Predicted Mask")
|
273 |
+
with gr.Row():
|
274 |
+
area_df_output = gr.DataFrame(label="Predicted Area Distribution", wrap=True)
|
275 |
+
send_single_report_btn = gr.Button("➡ Create Report from this Analysis")
|
276 |
+
gr.Examples(examples=single_examples, inputs=single_img_input, label="Click an Example to Start")
|
277 |
+
|
278 |
+
with gr.TabItem("Change Detection Tool", id=1):
|
279 |
+
with gr.Row(variant="panel"):
|
280 |
+
compare_img1 = gr.Image(type="filepath", label="Image 1 (e.g., Before / 2020)")
|
281 |
+
compare_img2 = gr.Image(type="filepath", label="Image 2 (e.g., After / 2024)")
|
282 |
+
compare_analyze_btn = gr.Button("Analyze Changes", variant="primary")
|
283 |
+
with gr.Row():
|
284 |
+
compare_mask1 = gr.Image(type="pil", label="Mask for Image 1")
|
285 |
+
compare_mask2 = gr.Image(type="pil", label="Mask for Image 2")
|
286 |
+
with gr.Tabs():
|
287 |
+
with gr.TabItem("📊 Change Chart"): compare_plot = gr.Plot()
|
288 |
+
with gr.TabItem("📑 Comparison Table"): compare_df = gr.DataFrame(interactive=False)
|
289 |
+
send_change_report_btn = gr.Button("➡ Create Report from this Analysis")
|
290 |
+
if change_examples:
|
291 |
+
gr.Examples(examples=change_examples, inputs=[compare_img1, compare_img2], label="Click an Example Pair to Start")
|
292 |
+
|
293 |
+
with gr.TabItem("Report Builder", id=2):
|
294 |
+
gr.Markdown("### Create and Download Your Analysis Report")
|
295 |
+
gr.Markdown("1. Run an analysis on one of the other tabs.\n"
|
296 |
+
"2. Click the **'➡ Create Report'** button.\n"
|
297 |
+
"3. Your report will appear below. You can edit it before downloading.\n")
|
298 |
+
with gr.Column():
|
299 |
+
report_editor = gr.Textbox(label="Your Report (Editable)", lines=20, interactive=True)
|
300 |
+
download_btn = gr.DownloadButton(label="Download Report (.md)", visible=False)
|
301 |
+
|
302 |
+
# --- BUTTON CLICK EVENTS & DATA FLOW ---
|
303 |
+
|
304 |
+
# Single Image Analysis Flow
|
305 |
+
single_analyze_btn.click(
|
306 |
+
fn=single_image_analysis,
|
307 |
+
inputs=single_img_input,
|
308 |
+
outputs=[single_mask_output, area_df_output, single_analysis_results]
|
309 |
+
).then(
|
310 |
+
lambda: gr.update(interactive=False, value="Analyzing..."), None, single_analyze_btn
|
311 |
+
).then(
|
312 |
+
lambda: gr.update(interactive=True, value="Analyze Image"), None, single_analyze_btn
|
313 |
+
)
|
314 |
|
315 |
+
send_single_report_btn.click(
|
316 |
+
fn=lambda res: generate_report(res, "single"),
|
317 |
+
inputs=single_analysis_results,
|
318 |
+
outputs=[report_editor, download_btn, tabs]
|
319 |
+
)
|
320 |
|
321 |
+
# Change Detection Flow
|
322 |
+
compare_analyze_btn.click(
|
323 |
+
fn=compare_land_cover,
|
324 |
+
inputs=[compare_img1, compare_img2],
|
325 |
+
outputs=[compare_mask1, compare_mask2, compare_df, compare_plot, change_analysis_results]
|
326 |
+
).then(
|
327 |
+
lambda: gr.update(interactive=False, value="Analyzing..."), None, compare_analyze_btn
|
328 |
+
).then(
|
329 |
+
lambda: gr.update(interactive=True, value="Analyze Changes"), None, compare_analyze_btn
|
330 |
+
)
|
331 |
+
|
332 |
+
send_change_report_btn.click(
|
333 |
+
fn=lambda res: generate_report(res, "change"),
|
334 |
+
inputs=change_analysis_results,
|
335 |
+
outputs=[report_editor, download_btn, tabs]
|
336 |
+
)
|
337 |
|
338 |
+
# Report Download Flow
|
339 |
+
download_btn.click(fn=save_report_to_file, inputs=report_editor, outputs=download_btn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
|
341 |
if __name__ == "__main__":
|
342 |
+
demo.launch(debug=True)
|
|
|
|
|
|
|
|
|
|