Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from PIL import Image | |
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
# EuroSAT land-cover class names (10 classes). Order matters: the classifier
# maps model output index i to EUROSAT_CLASSES[i].
EUROSAT_CLASSES = (
    "AnnualCrop Forest HerbaceousVegetation Highway Industrial "
    "Pasture PermanentCrop Residential River SeaLake"
).split()
# Human-readable description for each EuroSAT class, shown next to the top
# prediction. Keys mirror the EuroSAT class names.
CLASS_DESCRIPTIONS = dict(
    AnnualCrop="πΎ Agricultural land with annual crops",
    Forest="π² Dense forest areas with trees",
    HerbaceousVegetation="πΏ Areas with herbaceous vegetation",
    Highway="π£οΈ Major roads and highway infrastructure",
    Industrial="π Industrial areas and facilities",
    Pasture="π Pasture land for livestock",
    PermanentCrop="π Permanent crop areas (vineyards, orchards)",
    Residential="ποΈ Residential areas and neighborhoods",
    River="ποΈ Rivers and waterways",
    SeaLake="ποΈ Seas, lakes, and large water bodies",
)
class EuroSATClassifier:
    """Wrap a Hugging Face image-classification model for EuroSAT land-cover prediction.

    The processor and model are loaded eagerly on construction (see
    :meth:`load_model`). If ``model_name`` cannot be loaded, a generic Swin
    checkpoint is used instead and ``using_fallback`` is set, so predictions
    can be flagged as unreliable.
    """

    # Generic checkpoint used when ``model_name`` fails to load. It is not the
    # EuroSAT fine-tune, so its output indices presumably do not correspond to
    # EUROSAT_CLASSES; predict() warns when it is in use.
    FALLBACK_MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

    def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
        self.model_name = model_name
        self.processor = None
        self.model = None
        # True when load_model() had to fall back to FALLBACK_MODEL_NAME.
        self.using_fallback = False
        self.load_model()

    def _load_checkpoint(self, name):
        """Load the processor and model for checkpoint *name* and put the model in eval mode."""
        self.processor = AutoImageProcessor.from_pretrained(name)
        self.model = AutoModelForImageClassification.from_pretrained(name)
        # Inference only: eval() disables training-time behaviour such as dropout.
        self.model.eval()

    def load_model(self):
        """Load the model and processor, falling back to a generic checkpoint on failure.

        Any exception while loading ``self.model_name`` is printed, and
        ``FALLBACK_MODEL_NAME`` is loaded instead (exceptions from the fallback
        load propagate). ``using_fallback`` records which path was taken.
        """
        try:
            self._load_checkpoint(self.model_name)
            self.using_fallback = False
            print(f"β Model {self.model_name} loaded successfully!")
        except Exception as e:
            print(f"β Error loading model: {e}")
            # Fallback to a generic model if the specific one fails. Unlike the
            # original code path, the fallback is also switched to eval mode.
            self._load_checkpoint(self.FALLBACK_MODEL_NAME)
            self.using_fallback = True
            print(f"Using fallback model {self.FALLBACK_MODEL_NAME}; results may not be EuroSAT classes.")

    def _mismatch_warning(self, num_outputs):
        """Return a Markdown warning when the loaded model is not a matching EuroSAT classifier.

        Args:
            num_outputs: Number of class probabilities the model produced.

        Returns:
            str: A warning paragraph (ending in a blank line), or ``""`` when the
            primary model is loaded and its output count matches EUROSAT_CLASSES.
        """
        if getattr(self, "using_fallback", False):
            return (
                f"**Warning:** `{self.model_name}` could not be loaded; these scores come "
                f"from the fallback model `{self.FALLBACK_MODEL_NAME}`, whose outputs may "
                "not correspond to EuroSAT classes.\n\n"
            )
        if num_outputs != len(EUROSAT_CLASSES):
            return (
                f"**Warning:** the model produced {num_outputs} class scores but "
                f"{len(EUROSAT_CLASSES)} EuroSAT class names are defined; labels may be wrong.\n\n"
            )
        return ""

    def predict(self, image):
        """Classify *image* into EuroSAT land-cover classes.

        Args:
            image: The uploaded image (the Gradio input uses ``type="pil"``), or ``None``.

        Returns:
            tuple: ``(sorted_results, confidence_plot, result_text)``.
            ``sorted_results`` maps every name in EUROSAT_CLASSES to its
            probability, highest first. ``confidence_plot`` is the figure from
            :meth:`create_confidence_plot`. ``result_text`` is Markdown.
            When the image is missing or prediction raises, returns
            ``(None, None, message)`` instead.
        """
        if image is None:
            return None, None, "Please upload an image first!"
        try:
            # Uploads (e.g. PNG) may be RGBA, palette or grayscale. Normalise to
            # 3-channel RGB so the processor's per-channel normalisation applies.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")

            # Preprocess the image and run a forward pass without autograd.
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Take the single image in the batch. .cpu() keeps .numpy() valid
            # even if the model has been moved to an accelerator.
            probabilities = predictions[0].cpu().numpy()

            # Output index i is labelled EUROSAT_CLASSES[i]. Classes beyond the
            # model's output count score 0.
            results = {
                class_name: float(probabilities[i]) if i < len(probabilities) else 0.0
                for i, class_name in enumerate(EUROSAT_CLASSES)
            }
            sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
            top_class, top_confidence = next(iter(sorted_results.items()))

            confidence_plot = self.create_confidence_plot(sorted_results)

            # Format result text, prefixed by a warning if the model is not a
            # matching EuroSAT classifier.
            result_text = self._mismatch_warning(len(probabilities))
            result_text += f"π― **Prediction: {top_class}**\n\n"
            result_text += f"π **Confidence: {top_confidence:.1%}**\n\n"
            result_text += f"π **Description: {CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')}**\n\n"
            result_text += "### Top 3 Predictions:\n"
            for rank, (class_name, confidence) in enumerate(list(sorted_results.items())[:3], start=1):
                result_text += f"{rank}. **{class_name}**: {confidence:.1%}\n"
            return sorted_results, confidence_plot, result_text
        except Exception as e:
            # UI boundary: report the failure in the results panel instead of
            # raising out of the Gradio callback.
            error_msg = f"β Error during prediction: {str(e)}"
            return None, None, error_msg

    def create_confidence_plot(self, results):
        """Build a horizontal Plotly bar chart of class confidences.

        Args:
            results: Mapping of class name to probability in [0, 1]. Iteration
                order is plotted top to bottom, and the first entry is
                highlighted green.

        Returns:
            plotly.graph_objects.Figure: The confidence chart, with the x-axis
            fixed to 0-100 %.
        """
        classes = list(results.keys())
        confidences = [results[cls] * 100 for cls in classes]
        # Use consistent solid colors (green for top, blue for others)
        colors = ['#2E8B57' if i == 0 else '#4682B4' for i in range(len(classes))]
        fig = go.Figure(data=[
            go.Bar(
                x=confidences,
                y=classes,
                orientation='h',
                marker_color=colors,
                text=[f'{conf:.1f}%' for conf in confidences],
                textposition='inside',
                textfont=dict(color='white', size=12),
            )
        ])
        fig.update_layout(
            title="π― Classification Confidence Scores",
            xaxis_title="Confidence (%)",
            yaxis_title="Land Cover Classes",
            height=500,
            margin=dict(l=10, r=10, t=40, b=10),
            plot_bgcolor='white',
            paper_bgcolor='white',
            font=dict(family="Arial", size=12, color="#333"),
            xaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                range=[0, 100]
            ),
            yaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                # Reverse so the first (highest-confidence) entry is drawn at the top.
                autorange="reversed"
            )
        )
        return fig
# One shared classifier for the whole app. Constructing it loads the model
# immediately, at import time.
classifier = EuroSATClassifier(model_name="Adilbai/EuroSAT-Swin")
def classify_image(image):
    """Gradio callback: run the shared module-level ``classifier`` on *image*.

    Returns exactly what ``classifier.predict(image)`` returns.
    """
    predict = classifier.predict
    return predict(image)
def get_sample_images():
    """Return Markdown describing the kinds of satellite images worth trying.

    Returns:
        str: A Markdown snippet with a heading, a bullet list of image types,
        and a closing call-to-action paragraph.
    """
    # The blank line before the closing sentence is required. Without it,
    # Markdown treats the sentence as a lazy continuation of the last bullet
    # instead of a separate paragraph.
    return """
### πΌοΈ Try these types of satellite images:
- **πΎ Agricultural fields** - Crop lands and farmland
- **π² Forest areas** - Dense tree coverage
- **ποΈ Residential zones** - Urban neighborhoods
- **π Industrial sites** - Factories and industrial areas
- **π£οΈ Highway systems** - Major roads and intersections
- **π§ Water bodies** - Rivers, lakes, and seas
- **πΏ Natural vegetation** - Grasslands and natural areas

Upload a satellite/aerial image to see the land cover classification!
"""
# Custom CSS for better styling.
# .result-text uses a near-black background, so it also sets an explicit light
# text color and overrides Gradio's --body-text-color theme variable. Otherwise
# the light theme renders dark text on a dark panel.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 2rem;
    text-align: center;
    background: rgba(0, 0, 0, 0.43);
}
.result-text {
    background: #070605;
    color: #f5f5f5;
    --body-text-color: #f5f5f5;
    padding: 1.5rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}
"""
# Create the Gradio interface: header, upload column, results column, footer.
with gr.Blocks(css=custom_css, title="π°οΈ EuroSAT Land Cover Classifier") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>π°οΈ EuroSAT Land Cover Classifier</h1>
        <p>Advanced satellite image classification using Swin Transformer</p>
        <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>π€ Upload Satellite Image</h3>")
            image_input = gr.Image(
                label="Upload a satellite/aerial image",
                type="pil",
                height=400,
                elem_classes="upload-area"
            )
            classify_btn = gr.Button(
                "π Classify Land Cover",
                variant="primary",
                size="lg"
            )
            # Each gr.HTML renders its markup in its own container. An opening
            # "<div>" and a closing "</div>" in two separate components therefore
            # cannot wrap the Markdown between them. Use one well-formed spacer.
            gr.HTML("<div style='margin-top: 2rem;'></div>")
            gr.Markdown(get_sample_images())
        with gr.Column(scale=1):
            gr.HTML("<h3>π Classification Results</h3>")
            result_text = gr.Markdown(
                value="Upload an image and click 'Classify Land Cover' to see results!",
                elem_classes="result-text"
            )
            confidence_plot = gr.Plot(
                label="Confidence Scores",
            )
            # Hidden component to store raw results
            raw_results = gr.JSON(visible=False)

    # Event handlers: order matches classify_image's return tuple
    # (raw results dict, confidence figure, Markdown text).
    _result_outputs = [raw_results, confidence_plot, result_text]
    classify_btn.click(
        fn=classify_image,
        inputs=[image_input],
        outputs=_result_outputs
    )
    # Also trigger whenever the image changes (upload or clear).
    image_input.change(
        fn=classify_image,
        inputs=[image_input],
        outputs=_result_outputs
    )

    # Footer. Dark background, so the text color is set explicitly for contrast
    # in Gradio's light theme.
    gr.HTML("""
    <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: #070605; color: #f5f5f5; border-radius: 10px;">
        <h4>π¬ About This Model</h4>
        <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
        <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
        <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! π</p>
        <br>
        <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank">Adilbai/EuroSAT-Swin</a></p>
    </div>
    """)
# Launch the app only when executed as a script, not when imported.
# NOTE(review): share=True asks Gradio for a public share link when run
# locally; confirm that exposure is intended.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)