|
import gradio as gr |
|
import torch |
|
from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
from PIL import Image |
|
import numpy as np |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
|
|
|
|
# Land-cover class names, in the order used to index the model's output
# probabilities (position i in this list is paired with probability i).
EUROSAT_CLASSES = [
    "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
    "Pasture", "PermanentCrop", "Residential", "River", "SeaLake",
]
|
|
|
|
|
# Human-readable description shown next to the top prediction, keyed by
# class name.
# NOTE(review): the leading glyphs in these values look like mis-decoded
# UTF-8 emoji — confirm against the original file encoding.
CLASS_DESCRIPTIONS = {
    "AnnualCrop": "πΎ Agricultural land with annual crops",
    "Forest": "π² Dense forest areas with trees",
    "HerbaceousVegetation": "πΏ Areas with herbaceous vegetation",
    "Highway": "π£οΈ Major roads and highway infrastructure",
    "Industrial": "π Industrial areas and facilities",
    "Pasture": "π Pasture land for livestock",
    "PermanentCrop": "π Permanent crop areas (vineyards, orchards)",
    "Residential": "ποΈ Residential areas and neighborhoods",
    "River": "ποΈ Rivers and waterways",
    "SeaLake": "ποΈ Seas, lakes, and large water bodies",
}
|
|
|
class EuroSATClassifier:
    """Wrap a Hugging Face image-classification checkpoint for land-cover prediction.

    The constructor loads the processor and model immediately through
    ``load_model``. ``predict`` runs one image through them and returns the
    per-class probabilities, a Plotly bar chart, and a Markdown summary.
    """

    # Checkpoint loaded when the requested one cannot be loaded.
    FALLBACK_MODEL = "microsoft/swin-tiny-patch4-window7-224"

    def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
        """Store ``model_name`` and load its processor and model right away."""
        self.model_name = model_name
        self.processor = None
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the image processor and classification model.

        If loading ``self.model_name`` raises, the error is printed and
        ``FALLBACK_MODEL`` is loaded instead so that construction still
        succeeds. An exception raised while loading the fallback propagates.
        """
        try:
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForImageClassification.from_pretrained(self.model_name)
            self.model.eval()
            # NOTE(review): the leading glyph in these messages looks like a
            # mis-decoded UTF-8 emoji — confirm against the original encoding.
            print(f"β Model {self.model_name} loaded successfully!")
        except Exception as e:
            print(f"β Error loading model: {e}")
            # NOTE(review): the fallback appears to be a generic checkpoint
            # rather than one trained on EuroSAT — confirm. predict() pairs
            # output index i with EUROSAT_CLASSES[i] regardless.
            print(
                f"Falling back to {self.FALLBACK_MODEL}; "
                "predictions may not correspond to EuroSAT classes."
            )
            self.processor = AutoImageProcessor.from_pretrained(self.FALLBACK_MODEL)
            self.model = AutoModelForImageClassification.from_pretrained(self.FALLBACK_MODEL)
            # Keep the fallback in inference mode, like the primary path.
            self.model.eval()

    def predict(self, image):
        """Classify one image.

        Args:
            image: The image to classify, or ``None`` when nothing was
                uploaded.

        Returns:
            A ``(results, figure, markdown)`` tuple. ``results`` maps every
            name in ``EUROSAT_CLASSES`` to a probability, ordered from highest
            to lowest; ``figure`` is the chart built by
            ``create_confidence_plot``; ``markdown`` is the summary text.
            When ``image`` is ``None`` or any step raises, the first two
            items are ``None`` and the third is a message.
        """
        if image is None:
            return None, None, "Please upload an image first!"

        try:
            # Images exposing a PIL-style ``mode`` other than RGB (e.g. RGBA,
            # grayscale) are converted first; anything else passes through.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")

            with torch.no_grad():
                logits = self.model(**inputs).logits
                predictions = torch.nn.functional.softmax(logits, dim=-1)

            # Row 0 of the softmax output; .cpu() keeps .numpy() valid even
            # if the model has been moved off the CPU.
            probabilities = predictions[0].cpu().numpy()

            # Pair class names with probabilities by position; a class with
            # no matching output index gets 0.0.
            results = {
                class_name: float(probabilities[idx]) if idx < len(probabilities) else 0.0
                for idx, class_name in enumerate(EUROSAT_CLASSES)
            }
            sorted_results = dict(
                sorted(results.items(), key=lambda item: item[1], reverse=True)
            )

            top_class, top_confidence = next(iter(sorted_results.items()))
            confidence_plot = self.create_confidence_plot(sorted_results)

            description = CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')
            parts = [
                f"π― **Prediction: {top_class}**\n\n",
                f"π **Confidence: {top_confidence:.1%}**\n\n",
                f"π **Description: {description}**\n\n",
                "### Top 3 Predictions:\n",
            ]
            top_three = list(sorted_results.items())[:3]
            parts.extend(
                f"{rank}. **{class_name}**: {confidence:.1%}\n"
                for rank, (class_name, confidence) in enumerate(top_three, start=1)
            )
            result_text = "".join(parts)

            return sorted_results, confidence_plot, result_text

        except Exception as e:
            # UI boundary: report the failure as text instead of raising.
            error_msg = f"β Error during prediction: {e}"
            return None, None, error_msg

    @staticmethod
    def _bar_colors(count):
        """Return ``count`` bar colours: green first, then progressively fainter blue.

        The blue alpha drops by 0.1 per position but never below 0.2. The
        unclamped ``0.8 - i * 0.1`` reached 0.0 at the ninth bar and a
        negative value at the tenth, which is not a valid ``rgba()`` alpha.
        """
        colors = []
        for position in range(count):
            if position == 0:
                colors.append('#2E8B57')
            else:
                # round() removes float noise such as 0.7000000000000001.
                alpha = max(0.2, round(0.8 - position * 0.1, 2))
                colors.append(f'rgba(70, 130, 180, {alpha})')
        return colors

    def create_confidence_plot(self, results):
        """Build a horizontal Plotly bar chart of the confidences.

        Args:
            results: Mapping of class name to probability (0-1). Bars are
                emitted in the mapping's iteration order.

        Returns:
            A ``plotly.graph_objects.Figure`` with one bar per class, values
            shown as percentages on a fixed 0-100 x-axis.
        """
        classes = list(results.keys())
        confidences = [results[cls] * 100 for cls in classes]
        colors = self._bar_colors(len(classes))

        fig = go.Figure(data=[
            go.Bar(
                x=confidences,
                y=classes,
                orientation='h',
                marker_color=colors,
                text=[f'{conf:.1f}%' for conf in confidences],
                textposition='inside',
                textfont=dict(color='white', size=12),
            )
        ])

        fig.update_layout(
            title={
                'text': "π― Classification Confidence Scores",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 16, 'color': '#2C3E50'}
            },
            xaxis_title="Confidence (%)",
            yaxis_title="Land Cover Classes",
            height=500,
            margin=dict(l=10, r=10, t=50, b=10),
            plot_bgcolor='rgba(248, 249, 250, 0.8)',
            paper_bgcolor='white',
            font=dict(family="Arial, sans-serif", size=12, color="#2C3E50"),
            xaxis=dict(
                gridcolor='rgba(128, 128, 128, 0.2)',
                showgrid=True,
                range=[0, 100]
            ),
            yaxis=dict(
                gridcolor='rgba(128, 128, 128, 0.2)',
                showgrid=True,
                # Reversed so that, presumably, the first entry is drawn at
                # the top of the chart.
                autorange="reversed"
            )
        )

        return fig
|
|
|
|
|
# Single shared classifier, created at import time; the constructor loads
# the model.
classifier = EuroSATClassifier()


def classify_image(image):
    """Gradio callback: forward ``image`` to the shared classifier.

    Returns whatever ``classifier.predict`` returns for ``image``.
    """
    prediction = classifier.predict(image)
    return prediction
|
|
|
def get_sample_images():
    """Return the Markdown hint listing the kinds of images worth trying."""
    # NOTE(review): the glyphs after "**" look like mis-decoded UTF-8 emoji —
    # confirm against the original file encoding.
    markdown_lines = (
        "",
        "### πΌοΈ Try these types of satellite images:",
        "",
        "- **πΎ Agricultural fields** - Crop lands and farmland",
        "- **π² Forest areas** - Dense tree coverage",
        "- **ποΈ Residential zones** - Urban neighborhoods",
        "- **π Industrial sites** - Factories and industrial areas",
        "- **π£οΈ Highway systems** - Major roads and intersections",
        "- **π§ Water bodies** - Rivers, lakes, and seas",
        "- **πΏ Natural vegetation** - Grasslands and natural areas",
        "",
        "Upload a satellite/aerial image to see the land cover classification!",
        "",
    )
    return "\n".join(markdown_lines)
|
|
|
|
|
# Stylesheet handed to gr.Blocks(css=...). The class selectors are the ones
# referenced through elem_classes / class="..." in the layout.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
}

.upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 2rem;
    text-align: center;
    background: rgba(102, 126, 234, 0.05);
}

.result-text {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    padding: 1.5rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}
"""
|
|
|
|
|
# Page layout: header banner, a two-column row (upload controls on the left,
# results on the right), event wiring, and an "about" footer.
with gr.Blocks(css=custom_css, title="π°οΈ EuroSAT Land Cover Classifier") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>π°οΈ EuroSAT Land Cover Classifier</h1>
        <p>Advanced satellite image classification using Swin Transformer</p>
        <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
    </div>
    """)

    with gr.Row():
        # Left column: image upload, trigger button, usage hints.
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Upload Satellite Image</h3>")
            image_input = gr.Image(
                label="Upload a satellite/aerial image",
                type="pil",
                height=400,
                elem_classes="upload-area",
            )
            classify_btn = gr.Button("π Classify Land Cover", variant="primary", size="lg")

            # NOTE(review): each gr.HTML is its own component, so this <div>
            # presumably does not wrap the Markdown below — confirm.
            gr.HTML("<div style='margin-top: 2rem;'>")
            gr.Markdown(get_sample_images())
            gr.HTML("</div>")

        # Right column: summary text, confidence chart, hidden raw output.
        with gr.Column(scale=1):
            gr.HTML("<h3>π Classification Results</h3>")
            result_text = gr.Markdown(
                value="Upload an image and click 'Classify Land Cover' to see results!",
                elem_classes="result-text",
            )
            confidence_plot = gr.Plot(label="Confidence Scores")
            # Receives the first item returned by classify_image; not shown.
            raw_results = gr.JSON(visible=False)

    # The button click and any change of the image run the same callback
    # with the same inputs and outputs.
    _classify_event = dict(
        fn=classify_image,
        inputs=[image_input],
        outputs=[raw_results, confidence_plot, result_text],
    )
    classify_btn.click(**_classify_event)
    image_input.change(**_classify_event)

    gr.HTML("""
    <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: #f8f9fa; border-radius: 10px;">
        <h4>π¬ About This Model</h4>
        <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
        <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
        <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! π</p>
        <br>
        <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank">Adilbai/EuroSAT-Swin</a></p>
    </div>
    """)
|
|
|
|
|
if __name__ == "__main__":
    # Options passed unchanged to demo.launch when the file is run as a script.
    _launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**_launch_options)