Vishalpainjane committed
Commit 22c89e3 (verified) · 1 Parent(s): 92ed33b

Update app.py

Files changed (1): app.py (+284 −177)
app.py CHANGED
@@ -1,52 +1,63 @@
  import gradio as gr
  import torch
  import torch.nn as nn
- from PIL import Image, ImageDraw, ImageFont
  import numpy as np
- import torchvision.transforms as T
- import torchvision.transforms.functional as F
  import segmentation_models_pytorch as smp
  import os
- import _pickle  # Imported to catch the specific error type

- # --- 1. Configuration and Constants ---

  class CFG:
-     """
-     Configuration class adapted from the provided training script.
-     """
      DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
      MODEL_NAME = "CustomDeepLabV3+"
      ENCODER_NAME = "timm-efficientnet-b2"
-     IN_CHANNELS = 3
      NUM_CLASSES = 8
      IMG_SIZE = 256
-     # NOTE: Path to your local trained model file.
-     MODEL_SAVE_PATH = "best_model_optimized_83.98.pth"
-
- # Class definitions and color map based on the research paper
  CLASS_INFO = {
-     0: {"name": "Unclassified", "color": (150, 150, 150)},
-     1: {"name": "Water Bodies", "color": (0, 0, 255)},       # Blue
-     2: {"name": "Dense Forest", "color": (0, 100, 0)},       # Dark Green
-     3: {"name": "Built up", "color": (128, 0, 128)},         # Purple
-     4: {"name": "Agriculture land", "color": (0, 255, 0)},   # Bright Green
-     5: {"name": "Barren land", "color": (255, 255, 0)},      # Yellow
-     6: {"name": "Fallow land", "color": (210, 180, 140)},    # Tan
-     7: {"name": "Sparse Forest", "color": (60, 179, 113)},   # Medium Sea Green
  }

- PIXEL_AREA_SQ_METERS = 10 * 10  # Based on 10m resolution from the dataset
- SQ_METERS_PER_HECTARE = 10000
-
- # --- 2. Model Definitions (from provided training script) ---
-
  class SELayer(nn.Module):
      def __init__(self, channel, reduction=16):
          super(SELayer, self).__init__()
          self.avg_pool = nn.AdaptiveAvgPool2d(1)
-         self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(channel // reduction, channel, bias=False), nn.Sigmoid())
      def forward(self, x):
          b, c, _, _ = x.size()
          y = self.avg_pool(x).view(b, c)
@@ -56,180 +67,276 @@ class SELayer(nn.Module):
  class CustomDeepLabV3Plus(nn.Module):
      def __init__(self, encoder_name, in_channels, classes):
          super().__init__()
-         self.smp_model = smp.DeepLabV3Plus(encoder_name=encoder_name, encoder_weights="imagenet", in_channels=in_channels, classes=classes)
          decoder_channels = self.smp_model.segmentation_head[0].in_channels
          self.se_layer = SELayer(decoder_channels)
          self.segmentation_head = self.smp_model.segmentation_head
          self.smp_model.segmentation_head = nn.Identity()
      def forward(self, x):
          decoder_features = self.smp_model(x)
          attended_features = self.se_layer(decoder_features)
          output = self.segmentation_head(attended_features)
          return output

- # --- 3. Helper Functions ---
-
  def load_model():
-     """
-     Loads the segmentation model from a local path and checks for Git LFS issues.
-     """
-     model_path = CFG.MODEL_SAVE_PATH
-
-     # 1. Check if the model file exists.
-     if not os.path.exists(model_path):
-         raise FileNotFoundError(
-             f"Model file not found at '{model_path}'. "
-             "Please ensure the model is in the correct directory and the path is correct."
-         )
-
-     # 2. Check if the file is a Git LFS pointer instead of the actual model.
-     # Real model files are many megabytes; pointer files are typically < 1KB.
-     if os.path.getsize(model_path) < 1024 * 1024:  # 1 MB threshold
-         raise ValueError(
-             f"The file at '{model_path}' is too small to be a valid model. "
-             "It is likely a Git LFS pointer file. Please ensure you have the full, correct model file. "
-             "If using Git, you may need to run 'git lfs pull' to download the actual file."
-         )
-
-     # 3. If checks pass, create the model structure.
-     model = CustomDeepLabV3Plus(
-         encoder_name=CFG.ENCODER_NAME,
-         in_channels=CFG.IN_CHANNELS,
-         classes=CFG.NUM_CLASSES
-     )

-     # 4. Load the model weights.
-     try:
-         model.load_state_dict(torch.load(model_path, map_location=torch.device(CFG.DEVICE), weights_only=False))
-     except _pickle.UnpicklingError:
-         # This error is a strong indicator of an LFS pointer or corrupted file.
-         raise IOError(
-             f"Failed to load '{model_path}'. The file is not a valid PyTorch model. "
-             "This confirms the file is likely a Git LFS text pointer. Please replace it with the actual model file."
-         )
-     except Exception as e:
-         # Catch other potential loading errors
-         raise IOError(f"An unexpected error occurred while loading the model from '{model_path}'. Original error: {e}")
-
      model.to(CFG.DEVICE)
      model.eval()
-     print("Model loaded successfully on device:", CFG.DEVICE)
      return model

-
- def preprocess_image(img: Image.Image):
-     """Preprocesses a PIL image for model inference."""
-     img = F.resize(img, [CFG.IMG_SIZE, CFG.IMG_SIZE], interpolation=T.InterpolationMode.BILINEAR)
-     img_tensor = F.to_tensor(img)
-     normalize_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-     normalized_tensor = normalize_transform(img_tensor)
-     return normalized_tensor.unsqueeze(0)  # Add batch dimension
-
- def create_colored_mask(mask: np.ndarray) -> Image.Image:
-     """Creates a colored RGB image from a class index mask."""
-     rgb_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
      for class_id, info in CLASS_INFO.items():
-         rgb_mask[mask == class_id] = info["color"]
-     return Image.fromarray(rgb_mask)
-
- def create_legend_image():
-     """Generates an image of the color legend."""
-     item_height = 25
-     width = 250
-     height = item_height * len(CLASS_INFO)
-     legend_img = Image.new('RGB', (width, height), 'white')
-     draw = ImageDraw.Draw(legend_img)
-
-     try:
-         # Use a common font; handle the case where it's not found.
-         font = ImageFont.truetype("LiberationSans-Regular.ttf", 14)
-     except IOError:
-         print("LiberationSans font not found, falling back to default.")
-         font = ImageFont.load_default()
-
-     for i, (class_id, info) in enumerate(CLASS_INFO.items()):
-         y_start = i * item_height
-         draw.rectangle([5, y_start + 5, 25, y_start + 20], fill=info["color"])
-         draw.text((35, y_start + 5), f'{info["name"]}', fill='black', font=font)

-     return legend_img
-
- # --- 4. Main Prediction Function ---

- # Load the model and legend once when the script starts
- model = load_model()
- legend_image = create_legend_image()

- def predict_land_cover(input_image: Image.Image):
-     """
-     Takes an input image, performs segmentation, and calculates area for each class.
-     """
-     if input_image is None:
-         raise gr.Error("Please upload an image.")

-     # 1. Preprocess the image and get model prediction
-     input_tensor = preprocess_image(input_image.convert("RGB")).to(CFG.DEVICE)
-     with torch.no_grad():
-         output = model(input_tensor)
-     pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

-     # 2. Create the colored segmentation mask for visualization
-     colored_mask = create_colored_mask(pred_mask)

-     # 3. Calculate area for each class
-     class_indices, pixel_counts = np.unique(pred_mask, return_counts=True)

-     area_results = {}
-     total_pixels = pred_mask.size
-     total_area_ha = (total_pixels * PIXEL_AREA_SQ_METERS) / SQ_METERS_PER_HECTARE

-     for class_id, count in zip(class_indices, pixel_counts):
-         if class_id in CLASS_INFO:
-             class_name = CLASS_INFO[class_id]["name"]
-             area_hectares = (count * PIXEL_AREA_SQ_METERS) / SQ_METERS_PER_HECTARE
-             percentage = (count / total_pixels) * 100
-             area_results[class_name] = f"{area_hectares:.4f} Hectares ({percentage:.2f}%)"
-
-     return colored_mask, area_results, legend_image
-
- # --- 5. Gradio Interface ---
-
- # Load example images from your file system.
- example_files = ["comparison_1.jpg", "comparison_4.jpg", "comparison_5.jpg", "comparison_8.jpg"]
- example_images = [file for file in example_files if os.path.exists(file)]
-
- app_description = """
- ### Sen-2 LULC: Interactive Land Cover Analysis
- This tool performs semantic segmentation on satellite imagery to classify different types of land use and land cover (LULC).
- It is based on the **Sen-2 LULC dataset** and a **Custom DeepLabV3+** model as described in the research paper "Sen-2 LULC: Land Use Land Cover Dataset for Deep Learning Approaches".
-
- **How it works:**
- 1. **Upload Image:** Upload a satellite image patch. You can use the examples below.
- 2. **Model Prediction:** The deep learning model classifies each pixel into one of the 8 LULC classes.
- 3. **Analysis:** The application generates:
-     * A color-coded **Segmentation Mask** for visual analysis.
-     * A **Quantitative Report** calculating the area of each land class in hectares. This is possible because the source Sentinel-2 imagery has a **10m resolution**, meaning each pixel represents a 100m² area on the ground.
- """
-
- iface = gr.Interface(
-     fn=predict_land_cover,
-     inputs=gr.Image(type="pil", label="Upload Satellite Image Patch"),
-     outputs=[
-         gr.Image(type="pil", label="Predicted Segmentation Mask"),
-         gr.Label(label="Land Cover Area Analysis (Hectares)"),
-         gr.Image(type="pil", label="Legend")
-     ],
-     title="Land Use & Land Cover Segmentation and Area Analysis",
-     description=app_description,
-     examples=example_images,
-     cache_examples=False,
-     allow_flagging="never"
- )

  if __name__ == "__main__":
-     iface.launch(
-         debug=True,
-         server_name="0.0.0.0",  # Allow external connections
-         server_port=7860,
-         share=True
-     )
 
  import gradio as gr
  import torch
  import torch.nn as nn
+ from PIL import Image
  import numpy as np
+ import pandas as pd
+ import matplotlib
+ import matplotlib.pyplot as plt
  import segmentation_models_pytorch as smp
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
  import os
+ import random
+ from datetime import datetime

+ # --- Best Practice: Set Matplotlib backend for server environments ---
+ matplotlib.use('Agg')

+ # --- CONFIGURATION (UPDATED FOR DEPLOYMENT) ---
  class CFG:
      DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # CRITICAL: Use relative paths for deployment.
+     # Place your model file in the root of your Hugging Face Space repository.
+     MODEL_PATH = "best_model_optimized_83.98.pth"
+
+     # The app will scan this local folder for example images.
+     EXAMPLES_DIR = "examples"
+
      MODEL_NAME = "CustomDeepLabV3+"
      ENCODER_NAME = "timm-efficientnet-b2"
      NUM_CLASSES = 8
      IMG_SIZE = 256
+
+     # Constants for area calculation
+     ORIGINAL_PATCH_DIM = 64
+     RESOLUTION_M_PER_PIXEL = 10
+     SQ_METERS_PER_HECTARE = 10000
+     TOTAL_PATCH_AREA_HECTARES = (ORIGINAL_PATCH_DIM**2 * RESOLUTION_M_PER_PIXEL**2) / SQ_METERS_PER_HECTARE
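
For reference, the patch-area arithmetic behind the new constant: each dataset patch is originally 64 × 64 pixels at 10 m per pixel, so one patch covers 640 m × 640 m = 409,600 m², i.e. 40.96 hectares. A standalone sanity check (the names mirror the CFG fields above):

    # Patch-area arithmetic used by CFG.TOTAL_PATCH_AREA_HECTARES.
    patch_dim = 64            # original patch width/height in pixels
    resolution_m = 10         # Sentinel-2 resolution, metres per pixel
    sq_m_per_hectare = 10_000
    area_ha = (patch_dim ** 2 * resolution_m ** 2) / sq_m_per_hectare
    assert area_ha == 40.96   # total ground area represented by one patch

Per-class areas are later computed as a proportion of this fixed total, which keeps the figures correct even though the mask itself is predicted at 256 × 256.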

+ # --- DATA & CLASS INFO ---
  CLASS_INFO = {
+     0: {"name": "Unclassified", "hex": "#969696"},     1: {"name": "Water Bodies", "hex": "#0000FF"},
+     2: {"name": "Dense Forest", "hex": "#006400"},     3: {"name": "Built up", "hex": "#800080"},
+     4: {"name": "Agriculture land", "hex": "#00FF00"}, 5: {"name": "Barren land", "hex": "#FFFF00"},
+     6: {"name": "Fallow land", "hex": "#D2B48C"},      7: {"name": "Sparse Forest", "hex": "#3CB371"},
  }

+ # --- MODEL DEFINITION (REFORMATTED FOR READABILITY) ---
  class SELayer(nn.Module):
      def __init__(self, channel, reduction=16):
          super(SELayer, self).__init__()
          self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // reduction, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(channel // reduction, channel, bias=False),
+             nn.Sigmoid()
+         )
+
      def forward(self, x):
          b, c, _, _ = x.size()
          y = self.avg_pool(x).view(b, c)
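
The hunk ends here; the remaining lines of SELayer.forward are unchanged and elided by the diff. In the standard squeeze-and-excitation pattern, which the surrounding code follows, the pooled vector is passed through the fc bottleneck and used to rescale the input channel-wise. A sketch of that tail, assuming the elided lines follow the usual recipe (not the commit's verbatim code):

    # Typical completion of an SE forward pass (assumed, hedged sketch):
    # y = self.fc(y).view(b, c, 1, 1)   # excitation: per-channel gates in (0, 1)
    # return x * y.expand_as(x)         # recalibrate each channel of x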
 
  class CustomDeepLabV3Plus(nn.Module):
      def __init__(self, encoder_name, in_channels, classes):
          super().__init__()
+         self.smp_model = smp.DeepLabV3Plus(
+             encoder_name=encoder_name,
+             encoder_weights="imagenet",
+             in_channels=in_channels,
+             classes=classes
+         )
          decoder_channels = self.smp_model.segmentation_head[0].in_channels
          self.se_layer = SELayer(decoder_channels)
          self.segmentation_head = self.smp_model.segmentation_head
          self.smp_model.segmentation_head = nn.Identity()
+
      def forward(self, x):
          decoder_features = self.smp_model(x)
          attended_features = self.se_layer(decoder_features)
          output = self.segmentation_head(attended_features)
          return output

+ # --- MODEL LOADING & TRANSFORMS ---
  def load_model():
+     print(f"Loading model from {CFG.MODEL_PATH} on device {CFG.DEVICE}...")
+     model = CustomDeepLabV3Plus(encoder_name=CFG.ENCODER_NAME, in_channels=3, classes=CFG.NUM_CLASSES)
+     if not os.path.exists(CFG.MODEL_PATH):
+         raise FileNotFoundError(f"CRITICAL: Model file not found at '{CFG.MODEL_PATH}'. Please ensure the model file is in the root directory of your Space.")

+     # Using weights_only=True is safer
+     model.load_state_dict(torch.load(CFG.MODEL_PATH, map_location=torch.device(CFG.DEVICE), weights_only=True))
      model.to(CFG.DEVICE)
      model.eval()
+     print("Model loaded successfully!")
      return model

+ model = load_model()
+ transform = A.Compose([
+     A.Resize(height=CFG.IMG_SIZE, width=CFG.IMG_SIZE),
+     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+     ToTensorV2()
+ ])
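
The Albumentations pipeline above replaces the old torchvision preprocessing and yields a channels-first float tensor directly. A quick shape check using the names defined above (dummy input for illustration):

    # What `transform` produces for an arbitrary RGB uint8 array.
    dummy = np.zeros((512, 512, 3), dtype=np.uint8)   # H x W x 3, any size
    tensor = transform(image=dummy)['image']          # resized and normalized
    # tensor.shape == (3, 256, 256); ToTensorV2 moves channels first, and
    # unsqueeze(0) later adds the batch dimension expected by the model.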
+
+ # --- HELPER & ANALYSIS FUNCTIONS ---
+ def create_color_map():
+     color_map = np.zeros((256, 3), dtype=np.uint8)
      for class_id, info in CLASS_INFO.items():
+         color_map[class_id] = tuple(int(info['hex'].lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
+     return color_map
+
+ COLOR_MAP_NUMPY = create_color_map()
+ def create_colored_mask(mask_np):
+     return Image.fromarray(COLOR_MAP_NUMPY[mask_np])
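
create_colored_mask relies on NumPy fancy indexing: indexing the (256, 3) palette with an (H, W) integer mask returns an (H, W, 3) RGB array in a single vectorised step, replacing the old per-class boolean-mask loop. A tiny worked example using the palette defined above:

    # Palette lookup on a 2 x 2 mask holding class ids 1, 4, 0 and 7.
    demo_mask = np.array([[1, 4], [0, 7]], dtype=np.uint8)
    rgb = COLOR_MAP_NUMPY[demo_mask]   # rgb.shape == (2, 2, 3)
    # rgb[0, 0] is [0, 0, 255]: class 1 ("Water Bodies") renders as blue.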
+
+ def analyze_one_image(image_filepath: str):
+     if image_filepath is None: return None, {}
+     image = Image.open(image_filepath)
+     image_np = np.array(image.convert("RGB"))
+     transformed = transform(image=image_np)
+     input_tensor = transformed['image'].unsqueeze(0).to(CFG.DEVICE)
+
+     with torch.no_grad():
+         prediction = model(input_tensor)
+
+     pred_mask = torch.argmax(prediction.squeeze(), dim=0).cpu().numpy()
+
+     area_results = {}
+     class_indices, pixel_counts = np.unique(pred_mask, return_counts=True)
+     total_pixels_in_mask = pred_mask.size
+
+     for class_id, count in zip(class_indices, pixel_counts):
+         if class_id in CLASS_INFO:
+             pixel_proportion = count / total_pixels_in_mask
+             area_hectares = pixel_proportion * CFG.TOTAL_PATCH_AREA_HECTARES
+             area_results[CLASS_INFO[class_id]["name"]] = area_hectares
+
+     return pred_mask, area_results

+ def single_image_analysis(image_filepath: str):
+     if image_filepath is None: raise gr.Error("Please upload an image to analyze.")
+
+     pred_mask_np, areas_dict = analyze_one_image(image_filepath)
+     pred_mask_pil = create_colored_mask(pred_mask_np)
+
+     area_data = sorted(areas_dict.items(), key=lambda item: item[1], reverse=True)
+     area_df = pd.DataFrame(area_data, columns=["Land Cover Class", "Area (Hectares)"])
+     area_df["Area (Hectares)"] = area_df["Area (Hectares)"].map('{:.4f}'.format)
+
+     analysis_results = {"areas": areas_dict, "area_df": area_df, "image_path": image_filepath}
+
+     return pred_mask_pil, area_df, analysis_results

+ def compare_land_cover(filepath1: str, filepath2: str):
+     if filepath1 is None or filepath2 is None:
+         raise gr.Error("Please upload both a 'Before' and 'After' image for comparison.")
+
+     _, areas1_dict = analyze_one_image(filepath1)
+     _, areas2_dict = analyze_one_image(filepath2)
+
+     mask1_pil = create_colored_mask(analyze_one_image(filepath1)[0])
+     mask2_pil = create_colored_mask(analyze_one_image(filepath2)[0])
+
+     all_class_names = sorted(list(set(areas1_dict.keys()) | set(areas2_dict.keys())))
+     data_for_df = [[name, areas1_dict.get(name, 0), areas2_dict.get(name, 0)] for name in all_class_names]
+
+     df = pd.DataFrame(data_for_df, columns=["Class", "Area 1 (ha)", "Area 2 (ha)"])
+     df['Change (ha)'] = df['Area 2 (ha)'] - df['Area 1 (ha)']
+     df['% Change'] = df.apply(lambda row: (row['Change (ha)'] / row['Area 1 (ha)'] * 100) if row['Area 1 (ha)'] > 0 else float('inf'), axis=1)
+
+     df_display = df.copy()
+     for col in ["Area 1 (ha)", "Area 2 (ha)"]: df_display[col] = df_display[col].map('{:.2f}'.format)
+     df_display["Change (ha)"] = df_display["Change (ha)"].map('{:+.2f}'.format)
+     df_display["% Change"] = df_display["% Change"].apply(lambda x: f"{x:+.2f}%" if x != float('inf') else "New")
+
+     plt.style.use('seaborn-v0_8-whitegrid')
+     fig, ax = plt.subplots(figsize=(10, 6))
+     index = np.arange(len(df))
+     bar_width = 0.35
+     ax.bar(index - bar_width/2, df['Area 1 (ha)'], bar_width, label='Area 1 (Before)', color='cornflowerblue')
+     ax.bar(index + bar_width/2, df['Area 2 (ha)'], bar_width, label='Area 2 (After)', color='salmon')
+     ax.set_xlabel('Land Cover Class', fontweight='bold')
+     ax.set_ylabel('Area (Hectares)', fontweight='bold')
+     ax.set_title('Land Cover Change Analysis', fontsize=16, fontweight='bold')
+     ax.set_xticks(index)
+     ax.set_xticklabels(df['Class'], rotation=45, ha="right")
+     ax.legend()
+     fig.tight_layout()
+
+     analysis_results = {"df": df_display, "path1": filepath1, "path2": filepath2, "raw_df": df}
+
+     return mask1_pil, mask2_pil, df_display, fig, analysis_results
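
Note the percent-change convention: the change is taken relative to the "before" area, and a class absent from the first image is flagged "New" rather than given an infinite percentage. For example, a class growing from 2.00 ha to 3.00 ha reports (3.00 − 2.00) / 2.00 × 100 = +50.00%. The same logic condensed onto plain floats (a standalone sketch):

    # Percent-change convention used in compare_land_cover.
    def pct_change(before_ha: float, after_ha: float) -> str:
        if before_ha > 0:
            return f"{(after_ha - before_ha) / before_ha * 100:+.2f}%"
        return "New"  # class did not appear in the "before" image

    assert pct_change(2.0, 3.0) == "+50.00%"
    assert pct_change(0.0, 1.5) == "New"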

+ # --- REPORTING FUNCTIONS ---
+ def generate_report(analysis_results, report_type):
+     if not analysis_results:
+         raise gr.Error("Please run an analysis first before generating a report.")
+
+     if report_type == "single":
+         filename = os.path.basename(analysis_results['image_path'])
+         report = f"# LULC Analysis Report: {filename}\n"
+         report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
+         report += "## Area Distribution (Hectares)\n"
+         report += analysis_results['area_df'].to_markdown(index=False)
+
+     elif report_type == "change":
+         file1 = os.path.basename(analysis_results['path1'])
+         file2 = os.path.basename(analysis_results['path2'])
+         df = analysis_results['raw_df']
+         summary = ""
+         df_sorted = df.reindex(df['Change (ha)'].abs().sort_values(ascending=False).index)
+         for _, row in df_sorted.head(3).iterrows():
+             if abs(row['Change (ha)']) > 0.01:
+                 direction = "increased" if row['Change (ha)'] > 0 else "decreased"
+                 summary += f"- **{row['Class']}** has {direction} by **{abs(row['Change (ha)']):.2f} hectares**.\n"
+
+         report = f"# LULC Change Detection Report\n"
+         report += f"**Comparison:** `{file1}` (Before) vs. `{file2}` (After)\n"
+         report += f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
+         report += "## Key Summary of Changes\n"
+         report += summary + "\n"
+         report += "## Detailed Comparison Table\n"
+         report += analysis_results['df'].to_markdown(index=False)
+
+     # Switch to the report tab and populate it
+     return {
+         report_editor: gr.update(value=report),
+         download_btn: gr.update(visible=True),
+         tabs: gr.update(selected=2)
+     }
+
+ def save_report_to_file(report_content):
+     filepath = "LULC_Report.md"
+     with open(filepath, "w", encoding="utf-8") as f:
+         f.write(report_content)
+     return filepath
+
+ # --- EXAMPLE FINDER ---
+ def find_examples():
+     single_examples = []
+     change_examples = []
+     if os.path.isdir(CFG.EXAMPLES_DIR):
+         files = sorted([os.path.join(CFG.EXAMPLES_DIR, f) for f in os.listdir(CFG.EXAMPLES_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.tif'))])
+         single_examples = files[:10]  # Take up to 10 for single analysis
+         # Create pairs for change detection
+         if len(files) >= 2:
+             for i in range(0, min(len(files) - 1, 10), 2):  # Take up to 5 pairs
+                 change_examples.append([files[i], files[i+1]])
+     return single_examples, change_examples
+
+ single_examples, change_examples = find_examples()
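
The pairing loop walks the sorted file list two at a time, so consecutive files become non-overlapping before/after pairs, capped at five pairs by the min(..., 10) bound. With six files it yields three pairs (hypothetical file names for illustration):

    # Equivalent of the pairing loop in find_examples.
    files = [f"examples/patch_{i}.png" for i in range(6)]
    pairs = [[files[i], files[i + 1]] for i in range(0, min(len(files) - 1, 10), 2)]
    # pairs == [['examples/patch_0.png', 'examples/patch_1.png'],
    #           ['examples/patch_2.png', 'examples/patch_3.png'],
    #           ['examples/patch_4.png', 'examples/patch_5.png']]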
+
+ # --- GRADIO UI LAYOUT ---
+ with gr.Blocks(theme=gr.themes.Soft(), title="LULC Analysis Platform") as demo:
+     gr.Markdown("# Land Use & Land Cover (LULC) Analysis Platform")
+     gr.Markdown("An AI-powered tool to analyze satellite imagery for environmental monitoring and planning.")
+
+     # Hidden state objects to store analysis results robustly
+     single_analysis_results = gr.State()
+     change_analysis_results = gr.State()
+
+     with gr.Tabs() as tabs:
+         with gr.TabItem("Single Image Analysis", id=0):
+             with gr.Row(variant="panel"):
+                 with gr.Column(scale=1):
+                     single_img_input = gr.Image(type="filepath", label="Upload Satellite Image")
+                     single_analyze_btn = gr.Button("Analyze Image", variant="primary")
+                 with gr.Column(scale=1):
+                     single_mask_output = gr.Image(type="pil", label="Predicted Mask")
+             with gr.Row():
+                 area_df_output = gr.DataFrame(label="Predicted Area Distribution", wrap=True)
+             send_single_report_btn = gr.Button("➡ Create Report from this Analysis")
+             gr.Examples(examples=single_examples, inputs=single_img_input, label="Click an Example to Start")
+
+         with gr.TabItem("Change Detection Tool", id=1):
+             with gr.Row(variant="panel"):
+                 compare_img1 = gr.Image(type="filepath", label="Image 1 (e.g., Before / 2020)")
+                 compare_img2 = gr.Image(type="filepath", label="Image 2 (e.g., After / 2024)")
+             compare_analyze_btn = gr.Button("Analyze Changes", variant="primary")
+             with gr.Row():
+                 compare_mask1 = gr.Image(type="pil", label="Mask for Image 1")
+                 compare_mask2 = gr.Image(type="pil", label="Mask for Image 2")
+             with gr.Tabs():
+                 with gr.TabItem("📊 Change Chart"): compare_plot = gr.Plot()
+                 with gr.TabItem("📑 Comparison Table"): compare_df = gr.DataFrame(interactive=False)
+             send_change_report_btn = gr.Button("➡ Create Report from this Analysis")
+             if change_examples:
+                 gr.Examples(examples=change_examples, inputs=[compare_img1, compare_img2], label="Click an Example Pair to Start")
+
+         with gr.TabItem("Report Builder", id=2):
+             gr.Markdown("### Create and Download Your Analysis Report")
+             gr.Markdown("1. Run an analysis on one of the other tabs.\n"
+                         "2. Click the **'➡ Create Report'** button.\n"
+                         "3. Your report will appear below. You can edit it before downloading.\n")
+             with gr.Column():
+                 report_editor = gr.Textbox(label="Your Report (Editable)", lines=20, interactive=True)
+                 download_btn = gr.DownloadButton(label="Download Report (.md)", visible=False)
+
+     # --- BUTTON CLICK EVENTS & DATA FLOW ---
+
+     # Single Image Analysis Flow
+     single_analyze_btn.click(
+         fn=single_image_analysis,
+         inputs=single_img_input,
+         outputs=[single_mask_output, area_df_output, single_analysis_results]
+     ).then(
+         lambda: gr.update(interactive=False, value="Analyzing..."), None, single_analyze_btn
+     ).then(
+         lambda: gr.update(interactive=True, value="Analyze Image"), None, single_analyze_btn
+     )
+
+     send_single_report_btn.click(
+         fn=lambda res: generate_report(res, "single"),
+         inputs=single_analysis_results,
+         outputs=[report_editor, download_btn, tabs]
+     )
+
+     # Change Detection Flow
+     compare_analyze_btn.click(
+         fn=compare_land_cover,
+         inputs=[compare_img1, compare_img2],
+         outputs=[compare_mask1, compare_mask2, compare_df, compare_plot, change_analysis_results]
+     ).then(
+         lambda: gr.update(interactive=False, value="Analyzing..."), None, compare_analyze_btn
+     ).then(
+         lambda: gr.update(interactive=True, value="Analyze Changes"), None, compare_analyze_btn
+     )
+
+     send_change_report_btn.click(
+         fn=lambda res: generate_report(res, "change"),
+         inputs=change_analysis_results,
+         outputs=[report_editor, download_btn, tabs]
+     )
+
+     # Report Download Flow
+     download_btn.click(fn=save_report_to_file, inputs=report_editor, outputs=download_btn)
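
generate_report returns a dictionary keyed by components rather than a positional tuple, which is how one handler can populate the editor, reveal the download button, and switch to the Report Builder tab at once; the keyed components must still appear in the event's outputs list. A minimal standalone illustration of the same Gradio pattern (hypothetical components):

    # Dict-style returns update only the components named as keys.
    import gradio as gr

    with gr.Blocks() as sketch:
        box = gr.Textbox()
        btn = gr.Button("Go")
        btn.click(
            fn=lambda: {box: gr.update(value="done")},
            inputs=None,
            outputs=[box],  # dict keys must be listed here as well
        )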

  if __name__ == "__main__":
+     demo.launch(debug=True)