AdilzhanB commited on
Commit
ac84435
Β·
1 Parent(s): 174ed36
Files changed (2) hide show
  1. app.py +288 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
+ from PIL import Image
5
+ import numpy as np
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+
9
+ # EuroSAT class names (10 land cover classes)
10
+ EUROSAT_CLASSES = [
11
+ "AnnualCrop",
12
+ "Forest",
13
+ "HerbaceousVegetation",
14
+ "Highway",
15
+ "Industrial",
16
+ "Pasture",
17
+ "PermanentCrop",
18
+ "Residential",
19
+ "River",
20
+ "SeaLake"
21
+ ]
22
+
23
+ # Class descriptions for better user understanding
24
+ CLASS_DESCRIPTIONS = {
25
+ "AnnualCrop": "🌾 Agricultural land with annual crops",
26
+ "Forest": "🌲 Dense forest areas with trees",
27
+ "HerbaceousVegetation": "🌿 Areas with herbaceous vegetation",
28
+ "Highway": "πŸ›£οΈ Major roads and highway infrastructure",
29
+ "Industrial": "🏭 Industrial areas and facilities",
30
+ "Pasture": "πŸ„ Pasture land for livestock",
31
+ "PermanentCrop": "πŸ‡ Permanent crop areas (vineyards, orchards)",
32
+ "Residential": "🏘️ Residential areas and neighborhoods",
33
+ "River": "🏞️ Rivers and waterways",
34
+ "SeaLake": "πŸ”οΈ Seas, lakes, and large water bodies"
35
+ }
36
+
37
+ class EuroSATClassifier:
38
+ def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
39
+ self.model_name = model_name
40
+ self.processor = None
41
+ self.model = None
42
+ self.load_model()
43
+
44
+ def load_model(self):
45
+ """Load the model and processor"""
46
+ try:
47
+ self.processor = AutoImageProcessor.from_pretrained(self.model_name)
48
+ self.model = AutoModelForImageClassification.from_pretrained(self.model_name)
49
+ self.model.eval()
50
+ print(f"βœ… Model {self.model_name} loaded successfully!")
51
+ except Exception as e:
52
+ print(f"❌ Error loading model: {e}")
53
+ # Fallback to a generic model if the specific one fails
54
+ self.processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
55
+ self.model = AutoModelForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
56
+
57
+ def predict(self, image):
58
+ """Make prediction on the input image"""
59
+ if image is None:
60
+ return None, None, "Please upload an image first!"
61
+
62
+ try:
63
+ # Preprocess the image
64
+ inputs = self.processor(images=image, return_tensors="pt")
65
+
66
+ # Make prediction
67
+ with torch.no_grad():
68
+ outputs = self.model(**inputs)
69
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
70
+
71
+ # Get top predictions
72
+ probabilities = predictions[0].numpy()
73
+
74
+ # Create results dictionary
75
+ results = {}
76
+ for i, class_name in enumerate(EUROSAT_CLASSES):
77
+ if i < len(probabilities):
78
+ results[class_name] = float(probabilities[i])
79
+ else:
80
+ results[class_name] = 0.0
81
+
82
+ # Sort by confidence
83
+ sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
84
+
85
+ # Get top prediction
86
+ top_class = list(sorted_results.keys())[0]
87
+ top_confidence = list(sorted_results.values())[0]
88
+
89
+ # Create confidence plot
90
+ confidence_plot = self.create_confidence_plot(sorted_results)
91
+
92
+ # Format result text
93
+ result_text = f"🎯 **Prediction: {top_class}**\n\n"
94
+ result_text += f"πŸ“Š **Confidence: {top_confidence:.1%}**\n\n"
95
+ result_text += f"πŸ“ **Description: {CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')}**\n\n"
96
+ result_text += "### Top 3 Predictions:\n"
97
+
98
+ for i, (class_name, confidence) in enumerate(list(sorted_results.items())[:3]):
99
+ result_text += f"{i+1}. **{class_name}**: {confidence:.1%}\n"
100
+
101
+ return sorted_results, confidence_plot, result_text
102
+
103
+ except Exception as e:
104
+ error_msg = f"❌ Error during prediction: {str(e)}"
105
+ return None, None, error_msg
106
+
107
+ def create_confidence_plot(self, results):
108
+ """Create a confidence plot using Plotly"""
109
+ classes = list(results.keys())
110
+ confidences = [results[cls] * 100 for cls in classes]
111
+
112
+ # Create color scale - top prediction in green, others in blue gradient
113
+ colors = ['#2E8B57' if i == 0 else f'rgba(70, 130, 180, {0.8 - i*0.1})' for i in range(len(classes))]
114
+
115
+ fig = go.Figure(data=[
116
+ go.Bar(
117
+ x=confidences,
118
+ y=classes,
119
+ orientation='h',
120
+ marker_color=colors,
121
+ text=[f'{conf:.1f}%' for conf in confidences],
122
+ textposition='inside',
123
+ textfont=dict(color='white', size=12),
124
+ )
125
+ ])
126
+
127
+ fig.update_layout(
128
+ title={
129
+ 'text': "🎯 Classification Confidence Scores",
130
+ 'x': 0.5,
131
+ 'xanchor': 'center',
132
+ 'font': {'size': 16, 'color': '#2C3E50'}
133
+ },
134
+ xaxis_title="Confidence (%)",
135
+ yaxis_title="Land Cover Classes",
136
+ height=500,
137
+ margin=dict(l=10, r=10, t=50, b=10),
138
+ plot_bgcolor='rgba(248, 249, 250, 0.8)',
139
+ paper_bgcolor='white',
140
+ font=dict(family="Arial, sans-serif", size=12, color="#2C3E50"),
141
+ xaxis=dict(
142
+ gridcolor='rgba(128, 128, 128, 0.2)',
143
+ showgrid=True,
144
+ range=[0, 100]
145
+ ),
146
+ yaxis=dict(
147
+ gridcolor='rgba(128, 128, 128, 0.2)',
148
+ showgrid=True,
149
+ autorange="reversed" # Show highest confidence at top
150
+ )
151
+ )
152
+
153
+ return fig
154
+
155
+ # Initialize the classifier
156
+ classifier = EuroSATClassifier()
157
+
158
+ def classify_image(image):
159
+ """Main classification function for Gradio interface"""
160
+ return classifier.predict(image)
161
+
162
+ def get_sample_images():
163
+ """Return some sample image descriptions"""
164
+ return """
165
+ ### πŸ–ΌοΈ Try these types of satellite images:
166
+
167
+ - **🌾 Agricultural fields** - Crop lands and farmland
168
+ - **🌲 Forest areas** - Dense tree coverage
169
+ - **🏘️ Residential zones** - Urban neighborhoods
170
+ - **🏭 Industrial sites** - Factories and industrial areas
171
+ - **πŸ›£οΈ Highway systems** - Major roads and intersections
172
+ - **πŸ’§ Water bodies** - Rivers, lakes, and seas
173
+ - **🌿 Natural vegetation** - Grasslands and natural areas
174
+
175
+ Upload a satellite/aerial image to see the land cover classification!
176
+ """
177
+
178
+ # Custom CSS for better styling
179
+ custom_css = """
180
+ .gradio-container {
181
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
182
+ }
183
+
184
+ .main-header {
185
+ text-align: center;
186
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
187
+ color: white;
188
+ padding: 2rem;
189
+ border-radius: 10px;
190
+ margin-bottom: 2rem;
191
+ }
192
+
193
+ .upload-area {
194
+ border: 2px dashed #667eea;
195
+ border-radius: 10px;
196
+ padding: 2rem;
197
+ text-align: center;
198
+ background: rgba(102, 126, 234, 0.05);
199
+ }
200
+
201
+ .result-text {
202
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
203
+ padding: 1.5rem;
204
+ border-radius: 10px;
205
+ border-left: 4px solid #667eea;
206
+ }
207
+ """
208
+
209
+ # Create the Gradio interface
210
+ with gr.Blocks(css=custom_css, title="πŸ›°οΈ EuroSAT Land Cover Classifier") as demo:
211
+ gr.HTML("""
212
+ <div class="main-header">
213
+ <h1>πŸ›°οΈ EuroSAT Land Cover Classifier</h1>
214
+ <p>Advanced satellite image classification using Swin Transformer</p>
215
+ <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
216
+ </div>
217
+ """)
218
+
219
+ with gr.Row():
220
+ with gr.Column(scale=1):
221
+ gr.HTML("<h3>πŸ“€ Upload Satellite Image</h3>")
222
+ image_input = gr.Image(
223
+ label="Upload a satellite/aerial image",
224
+ type="pil",
225
+ height=400,
226
+ elem_classes="upload-area"
227
+ )
228
+
229
+ classify_btn = gr.Button(
230
+ "πŸ” Classify Land Cover",
231
+ variant="primary",
232
+ size="lg"
233
+ )
234
+
235
+ gr.HTML("<div style='margin-top: 2rem;'>")
236
+ gr.Markdown(get_sample_images())
237
+ gr.HTML("</div>")
238
+
239
+ with gr.Column(scale=1):
240
+ gr.HTML("<h3>πŸ“Š Classification Results</h3>")
241
+
242
+ result_text = gr.Markdown(
243
+ value="Upload an image and click 'Classify Land Cover' to see results!",
244
+ elem_classes="result-text"
245
+ )
246
+
247
+ confidence_plot = gr.Plot(
248
+ label="Confidence Scores",
249
+ height=500
250
+ )
251
+
252
+ # Hidden component to store raw results
253
+ raw_results = gr.JSON(visible=False)
254
+
255
+ # Event handlers
256
+ classify_btn.click(
257
+ fn=classify_image,
258
+ inputs=[image_input],
259
+ outputs=[raw_results, confidence_plot, result_text]
260
+ )
261
+
262
+ # Also trigger on image upload
263
+ image_input.change(
264
+ fn=classify_image,
265
+ inputs=[image_input],
266
+ outputs=[raw_results, confidence_plot, result_text]
267
+ )
268
+
269
+ # Footer
270
+ gr.HTML("""
271
+ <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: #f8f9fa; border-radius: 10px;">
272
+ <h4>πŸ”¬ About This Model</h4>
273
+ <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
274
+ <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
275
+ <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! 🌍</p>
276
+ <br>
277
+ <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank">Adilbai/EuroSAT-Swin</a></p>
278
+ </div>
279
+ """)
280
+
281
+ # Launch the app
282
+ if __name__ == "__main__":
283
+ demo.launch(
284
+ share=True,
285
+ server_name="0.0.0.0",
286
+ server_port=7860,
287
+ show_error=True
288
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=1.9.0
3
+ transformers>=4.21.0
4
+ Pillow>=8.3.0
5
+ numpy>=1.21.0
6
+ plotly>=5.0.0