# Hugging Face Space app by Adilbai — commit 501e4ed ("Update app.py", verified).
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
# The 10 EuroSAT land-cover class names, in the order used to label model
# outputs (output index i is reported as EUROSAT_CLASSES[i]).
EUROSAT_CLASSES = (
    "AnnualCrop Forest HerbaceousVegetation Highway Industrial "
    "Pasture PermanentCrop Residential River SeaLake"
).split()
# Human-readable description (prefixed with an emoji) for each EuroSAT class,
# shown to the user next to the top prediction.
# The emoji literals are stored as real UTF-8 characters; an earlier copy had
# them mis-decoded into sequences such as "πŸ›£οΈ".
CLASS_DESCRIPTIONS = {
    "AnnualCrop": "🌾 Agricultural land with annual crops",
    "Forest": "🌲 Dense forest areas with trees",
    "HerbaceousVegetation": "🌿 Areas with herbaceous vegetation",
    "Highway": "🛣️ Major roads and highway infrastructure",
    "Industrial": "🏭 Industrial areas and facilities",
    "Pasture": "🐄 Pasture land for livestock",
    "PermanentCrop": "🍇 Permanent crop areas (vineyards, orchards)",
    "Residential": "🏘️ Residential areas and neighborhoods",
    "River": "🏞️ Rivers and waterways",
    "SeaLake": "🏔️ Seas, lakes, and large water bodies",
}
class EuroSATClassifier:
    """Wrap a Hugging Face image-classification model for EuroSAT land-cover prediction.

    The processor and model are loaded eagerly in ``__init__``. Model output
    index ``i`` is reported under the class name ``EUROSAT_CLASSES[i]``.
    """

    # Generic checkpoint used when ``model_name`` cannot be loaded.
    # NOTE: this is a general ImageNet Swin model, so its outputs do not
    # correspond to EuroSAT classes; predict() flags the output-size mismatch.
    FALLBACK_MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

    def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
        self.model_name = model_name
        self.processor = None
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the image processor and model for ``self.model_name``.

        If loading fails, print the error and fall back to
        ``FALLBACK_MODEL_NAME``. Errors raised while loading the fallback
        propagate to the caller.
        """
        try:
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForImageClassification.from_pretrained(self.model_name)
            print(f"✅ Model {self.model_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print(f"⚠️ Falling back to {self.FALLBACK_MODEL_NAME}; its outputs are not EuroSAT classes.")
            self.processor = AutoImageProcessor.from_pretrained(self.FALLBACK_MODEL_NAME)
            self.model = AutoModelForImageClassification.from_pretrained(self.FALLBACK_MODEL_NAME)
        # Put whichever model was loaded into inference mode. Previously this was
        # done only on the primary path, not after falling back.
        self.model.eval()

    def predict(self, image):
        """Classify ``image`` and build the three outputs shown in the UI.

        Args:
            image: A PIL image, or ``None`` when nothing has been uploaded.

        Returns:
            A tuple ``(scores, figure, markdown)``:
            ``scores`` maps every name in ``EUROSAT_CLASSES`` to its softmax
            probability, ordered from most to least confident;
            ``figure`` is the Plotly bar chart from ``create_confidence_plot``;
            ``markdown`` is a human-readable summary.
            When ``image`` is ``None`` or any error occurs, ``scores`` and
            ``figure`` are ``None`` and ``markdown`` carries the message.
        """
        if image is None:
            return None, None, "Please upload an image first!"
        try:
            # Uploads may be RGBA, palette or grayscale. Convert to 3-channel RGB
            # before preprocessing so they are not rejected or misread.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Softmax over the label axis of the first (only) image. .cpu() keeps
            # the NumPy conversion valid even if the model is moved to a GPU.
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

            # Map output index i -> EUROSAT_CLASSES[i]. Classes beyond the
            # model's output size get 0.0.
            results = {
                class_name: float(probabilities[i]) if i < len(probabilities) else 0.0
                for i, class_name in enumerate(EUROSAT_CLASSES)
            }
            sorted_results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
            top_class, top_confidence = next(iter(sorted_results.items()))

            confidence_plot = self.create_confidence_plot(sorted_results)

            description = CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')
            parts = [
                f"🎯 **Prediction: {top_class}**\n\n",
                f"📊 **Confidence: {top_confidence:.1%}**\n\n",
                f"📝 **Description: {description}**\n\n",
                "### Top 3 Predictions:\n",
            ]
            parts.extend(
                f"{rank}. **{class_name}**: {confidence:.1%}\n"
                for rank, (class_name, confidence) in enumerate(list(sorted_results.items())[:3], start=1)
            )
            # A model whose output size differs from the EuroSAT class list (for
            # example the ImageNet fallback) cannot yield meaningful EuroSAT
            # scores, so say so instead of silently relabeling its outputs.
            if len(probabilities) != len(EUROSAT_CLASSES):
                parts.append(
                    f"\n⚠️ The loaded model has {len(probabilities)} outputs, not the "
                    f"{len(EUROSAT_CLASSES)} EuroSAT classes, so these scores may not be meaningful.\n"
                )
            result_text = "".join(parts)

            return sorted_results, confidence_plot, result_text
        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing the app.
            error_msg = f"❌ Error during prediction: {str(e)}"
            return None, None, error_msg

    def create_confidence_plot(self, results):
        """Build a horizontal bar chart of per-class confidence.

        Args:
            results: Ordered mapping of class name -> score. Each score is
                expected to be a fraction in [0, 1]; it is scaled by 100 for
                display on a 0-100 axis. Iteration order is drawn top to bottom,
                and the first entry is highlighted in green.

        Returns:
            A ``plotly.graph_objects.Figure``.
        """
        classes = list(results)
        confidences = [results[name] * 100 for name in classes]
        # Green for the first (top) entry, steel blue for the rest.
        colors = ['#2E8B57' if rank == 0 else '#4682B4' for rank in range(len(classes))]

        fig = go.Figure(data=[
            go.Bar(
                x=confidences,
                y=classes,
                orientation='h',
                marker_color=colors,
                text=[f'{conf:.1f}%' for conf in confidences],
                textposition='inside',
                textfont=dict(color='white', size=12),
            )
        ])
        fig.update_layout(
            title="🎯 Classification Confidence Scores",
            xaxis_title="Confidence (%)",
            yaxis_title="Land Cover Classes",
            height=500,
            margin=dict(l=10, r=10, t=40, b=10),
            plot_bgcolor='white',
            paper_bgcolor='white',
            font=dict(family="Arial", size=12, color="#333"),
            xaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                range=[0, 100],
            ),
            yaxis=dict(
                gridcolor='rgba(0,0,0,0.05)',
                showgrid=True,
                # Reversed so the first (highest-confidence) class is drawn at the top.
                autorange="reversed",
            ),
        )
        return fig
# One shared classifier for the whole app; its model is loaded once, here at import time.
classifier = EuroSATClassifier()


def classify_image(image):
    """Gradio event handler: classify ``image`` with the shared classifier.

    Returns exactly what ``classifier.predict`` returns.
    """
    prediction = classifier.predict(image)
    return prediction
def get_sample_images():
    """Return Markdown suggesting the kinds of satellite images to try.

    The emoji are real UTF-8 characters; an earlier copy had several of them
    mis-decoded (e.g. "πŸ–ΌοΈ" for 🖼️). The Markdown lines start at column 0 on
    purpose: indented lines would render as a code block.
    """
    return """
### 🖼️ Try these types of satellite images:
- **🌾 Agricultural fields** - Crop lands and farmland
- **🌲 Forest areas** - Dense tree coverage
- **🏘️ Residential zones** - Urban neighborhoods
- **🏭 Industrial sites** - Factories and industrial areas
- **🛣️ Highway systems** - Major roads and intersections
- **💧 Water bodies** - Rivers, lakes, and seas
- **🌿 Natural vegetation** - Grasslands and natural areas
Upload a satellite/aerial image to see the land cover classification!
"""
# Custom CSS for the app layout.
# .result-text uses a near-black background, so its text colour is set
# explicitly; without it, a light Gradio theme draws dark text on a dark panel.
# NOTE(review): overriding the --body-text-color variable assumes Gradio styles
# Markdown text through that theme variable. Verify in both light and dark themes.
custom_css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-header {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 2rem;
    text-align: center;
    background: rgba(0, 0, 0, 0.43);
}
.result-text {
    background: #070605;
    color: #f5f5f5;
    --body-text-color: #f5f5f5;
    padding: 1.5rem;
    border-radius: 10px;
    border-left: 4px solid #667eea;
}
"""
# Create the Gradio interface.
# Emoji are real UTF-8 characters; an earlier copy had them mis-decoded
# (e.g. "πŸ›°οΈ" for 🛰️).
with gr.Blocks(css=custom_css, title="🛰️ EuroSAT Land Cover Classifier") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>🛰️ EuroSAT Land Cover Classifier</h1>
        <p>Advanced satellite image classification using Swin Transformer</p>
        <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
    </div>
    """)

    with gr.Row():
        # Left column: image upload, trigger button and usage hints.
        with gr.Column(scale=1):
            gr.HTML("<h3>📀 Upload Satellite Image</h3>")
            image_input = gr.Image(
                label="Upload a satellite/aerial image",
                type="pil",
                height=400,
                elem_classes="upload-area",
            )
            classify_btn = gr.Button(
                "🔍 Classify Land Cover",
                variant="primary",
                size="lg",
            )
            # Each gr.HTML component is rendered as its own fragment. An opening
            # "<div>" in one component and "</div>" in another therefore never
            # wrap the Markdown between them. A self-contained spacer gives the
            # intended top margin.
            gr.HTML("<div style='margin-top: 2rem;'></div>")
            gr.Markdown(get_sample_images())

        # Right column: textual summary and confidence chart.
        with gr.Column(scale=1):
            gr.HTML("<h3>📊 Classification Results</h3>")
            result_text = gr.Markdown(
                value="Upload an image and click 'Classify Land Cover' to see results!",
                elem_classes="result-text",
            )
            confidence_plot = gr.Plot(
                label="Confidence Scores",
            )

    # Hidden component that receives the first value returned by classify_image
    # (the raw score dict).
    raw_results = gr.JSON(visible=False)

    # Both events route classify_image's 3-tuple to the same three outputs.
    prediction_outputs = [raw_results, confidence_plot, result_text]
    classify_btn.click(
        fn=classify_image,
        inputs=[image_input],
        outputs=prediction_outputs,
    )
    # Also classify automatically whenever the image value changes.
    image_input.change(
        fn=classify_image,
        inputs=[image_input],
        outputs=prediction_outputs,
    )

    # Footer: a dark panel, so its text colour is set explicitly for light themes.
    gr.HTML("""
    <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: #070605; color: #f5f5f5; --body-text-color: #f5f5f5; border-radius: 10px;">
        <h4>🔬 About This Model</h4>
        <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
        <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
        <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! 🌍</p>
        <br>
        <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank">Adilbai/EuroSAT-Swin</a></p>
    </div>
    """)
# Start the web server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)