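# NOTE: the line below is page header text captured from the Hugging Face file viewer;
# it is not part of app.py itself.
# Adilbai's picture | Update app.py | 78df12b (verified) | raw | history | blame | 10.4 kB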
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
# The ten EuroSAT land cover class names.
# Order matters: EuroSATClassifier.predict pairs the name at position i with
# the model's probability at output index i.
EUROSAT_CLASSES = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]
# Human-readable description for each class name, shown next to the top
# prediction in the result text. Keys match the entries of EUROSAT_CLASSES.
# The leading emoji were restored from mis-decoded (mojibake) byte sequences.
CLASS_DESCRIPTIONS = {
    "AnnualCrop": "🌾 Agricultural land with annual crops",
    "Forest": "🌲 Dense forest areas with trees",
    "HerbaceousVegetation": "🌿 Areas with herbaceous vegetation",
    "Highway": "🛣️ Major roads and highway infrastructure",
    "Industrial": "🏭 Industrial areas and facilities",
    "Pasture": "🐄 Pasture land for livestock",
    "PermanentCrop": "🍇 Permanent crop areas (vineyards, orchards)",
    "Residential": "🏘️ Residential areas and neighborhoods",
    "River": "🏞️ Rivers and waterways",
    "SeaLake": "🏔️ Seas, lakes, and large water bodies",
}
class EuroSATClassifier:
    """Image classifier for EuroSAT land cover classes.

    Wraps a Hugging Face image processor and image-classification model.
    Both are loaded when the instance is constructed; ``predict`` then turns
    an input image into per-class probabilities, a Plotly bar chart and a
    Markdown summary.
    """

    # Checkpoint loaded when the requested model cannot be loaded.
    FALLBACK_MODEL_NAME = "microsoft/swin-tiny-patch4-window7-224"

    def __init__(self, model_name="Adilbai/EuroSAT-Swin"):
        """Store the model name and load the model immediately.

        Args:
            model_name: Hugging Face model identifier to load.
        """
        self.model_name = model_name
        self.processor = None
        self.model = None
        # True when the generic fallback checkpoint had to be loaded.
        self.using_fallback = False
        self.load_model()

    def load_model(self):
        """Load the model and processor, falling back to a generic checkpoint.

        If loading ``self.model_name`` fails for any reason, the error is
        printed and ``FALLBACK_MODEL_NAME`` is loaded instead so the app can
        still start. Errors raised while loading the fallback propagate.
        """
        try:
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForImageClassification.from_pretrained(self.model_name)
            self.using_fallback = False
            print(f"✅ Model {self.model_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            # Fallback to a generic model if the specific one fails
            self.processor = AutoImageProcessor.from_pretrained(self.FALLBACK_MODEL_NAME)
            self.model = AutoModelForImageClassification.from_pretrained(self.FALLBACK_MODEL_NAME)
            self.using_fallback = True
            print(
                f"⚠️ Using fallback model {self.FALLBACK_MODEL_NAME}; it was not "
                "trained on EuroSAT, so predicted class names are unreliable."
            )
        # Inference only: make sure whichever model was loaded is in eval mode.
        self.model.eval()

    def predict(self, image):
        """Classify ``image`` and build the outputs shown in the UI.

        Args:
            image: Input image (a PIL image when called from the Gradio UI),
                or ``None`` when nothing has been uploaded.

        Returns:
            A 3-tuple ``(sorted_results, confidence_plot, result_text)``:
            a dict mapping class name to probability sorted by descending
            probability, a Plotly figure, and a Markdown summary. When
            ``image`` is ``None`` or an error occurs, the first two items are
            ``None`` and the third is a message for the user.
        """
        if image is None:
            return None, None, "Please upload an image first!"

        try:
            # Uploads may be RGBA, palette or grayscale; normalise PIL images
            # to 3-channel RGB before preprocessing. Objects without a
            # ``mode`` attribute (e.g. arrays) are passed through unchanged.
            if getattr(image, "mode", "RGB") != "RGB":
                image = image.convert("RGB")

            # Preprocess the image
            inputs = self.processor(images=image, return_tensors="pt")

            # Make prediction
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Probabilities for the single image in the batch.
            probabilities = predictions[0].cpu().numpy()

            # Pair class names with model outputs by index. Classes beyond
            # the model's output size get 0.0; extra model outputs are ignored.
            results = {
                class_name: float(probabilities[i]) if i < len(probabilities) else 0.0
                for i, class_name in enumerate(EUROSAT_CLASSES)
            }

            # Sort by confidence, highest first
            sorted_results = dict(
                sorted(results.items(), key=lambda item: item[1], reverse=True)
            )
            ranked = list(sorted_results.items())
            top_class, top_confidence = ranked[0]

            # Create confidence plot
            confidence_plot = self.create_confidence_plot(sorted_results)

            # Format result text
            description = CLASS_DESCRIPTIONS.get(top_class, 'Land cover classification')
            parts = [
                f"🎯 **Prediction: {top_class}**\n\n",
                f"📊 **Confidence: {top_confidence:.1%}**\n\n",
                f"📝 **Description: {description}**\n\n",
                "### Top 3 Predictions:\n",
            ]
            parts.extend(
                f"{rank}. **{class_name}**: {confidence:.1%}\n"
                for rank, (class_name, confidence) in enumerate(ranked[:3], start=1)
            )
            if getattr(self, "using_fallback", False):
                # Tell the user the labels come from a model not trained on EuroSAT.
                parts.append(
                    "\n⚠️ *Fallback model in use: it was not trained on EuroSAT, "
                    "so these labels are unreliable.*\n"
                )
            result_text = "".join(parts)

            return sorted_results, confidence_plot, result_text
        except Exception as e:
            # Surface any failure as a message in the UI instead of crashing.
            error_msg = f"❌ Error during prediction: {str(e)}"
            return None, None, error_msg

    @staticmethod
    def _bar_colors(count):
        """Return ``count`` bar colors: green for rank 0, fading blue after.

        The blue alpha starts at 0.7 and drops by 0.1 per rank, but is clamped
        at 0.3 so low-ranked bars stay visible and the alpha never becomes
        zero or negative (a negative alpha is not a valid color string).
        """
        colors = []
        for rank in range(count):
            if rank == 0:
                colors.append('#2E8B57')
            else:
                alpha = max(0.3, round(0.8 - rank * 0.1, 2))
                colors.append(f'rgba(70, 130, 180, {alpha})')
        return colors

    def create_confidence_plot(self, results):
        """Create a horizontal bar chart of confidences using Plotly.

        Args:
            results: Dict mapping class name to probability in [0, 1], in the
                order the bars should appear (first entry is drawn at the top).

        Returns:
            A ``plotly.graph_objects.Figure`` with one bar per class, scaled
            to percentages.
        """
        classes = list(results.keys())
        confidences = [results[cls] * 100 for cls in classes]

        # Top prediction in green, others in a fading (but always visible) blue
        colors = self._bar_colors(len(classes))

        fig = go.Figure(data=[
            go.Bar(
                x=confidences,
                y=classes,
                orientation='h',
                marker_color=colors,
                text=[f'{conf:.1f}%' for conf in confidences],
                textposition='inside',
                textfont=dict(color='white', size=12),
            )
        ])

        fig.update_layout(
            title={
                'text': "🎯 Classification Confidence Scores",
                'x': 0.5,
                'xanchor': 'center',
                'font': {'size': 16, 'color': '#2C3E50'}
            },
            xaxis_title="Confidence (%)",
            yaxis_title="Land Cover Classes",
            height=500,
            margin=dict(l=10, r=10, t=50, b=10),
            plot_bgcolor='rgba(248, 249, 250, 0.8)',
            paper_bgcolor='white',
            font=dict(family="Arial, sans-serif", size=12, color="#2C3E50"),
            xaxis=dict(
                gridcolor='rgba(128, 128, 128, 0.2)',
                showgrid=True,
                range=[0, 100]
            ),
            yaxis=dict(
                gridcolor='rgba(128, 128, 128, 0.2)',
                showgrid=True,
                autorange="reversed"  # Show highest confidence at top
            )
        )
        return fig
# Module-level classifier instance used by classify_image.
# Constructing it calls load_model(), so the model is loaded at this point,
# when the module is executed.
classifier = EuroSATClassifier()
def classify_image(image):
    """Gradio callback: run the shared classifier on the given image.

    Returns whatever ``classifier.predict`` returns, i.e. the three values
    wired to the raw-results, plot and result-text components.
    """
    prediction = classifier.predict(image)
    return prediction
def get_sample_images():
    """Return Markdown text suggesting which kinds of images to try.

    The text is a heading, a bullet list of image types, and a closing
    sentence. A blank line separates the list from the closing sentence so
    Markdown renders it as its own paragraph rather than as part of the last
    bullet.
    """
    lines = [
        "",
        "### 🖼️ Try these types of satellite images:",
        "- **🌾 Agricultural fields** - Crop lands and farmland",
        "- **🌲 Forest areas** - Dense tree coverage",
        "- **🏘️ Residential zones** - Urban neighborhoods",
        "- **🏭 Industrial sites** - Factories and industrial areas",
        "- **🛣️ Highway systems** - Major roads and intersections",
        "- **💧 Water bodies** - Rivers, lakes, and seas",
        "- **🌿 Natural vegetation** - Grasslands and natural areas",
        "",
        "Upload a satellite/aerial image to see the land cover classification!",
        "",
    ]
    return "\n".join(lines)
# Custom CSS passed to gr.Blocks(css=...).
# .main-header is used by the header HTML; .upload-area and .result-text are
# attached to the image input and the result Markdown via elem_classes.
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-header {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
}
.upload-area {
border: 2px dashed #667eea;
border-radius: 10px;
padding: 2rem;
text-align: center;
background: rgba(102, 126, 234, 0.05);
}
.result-text {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 1.5rem;
border-radius: 10px;
border-left: 4px solid #667eea;
}
"""
# Create the Gradio interface.
# Layout: header banner, then a row with two equal columns (upload controls on
# the left, results on the right), then a footer describing the model.
with gr.Blocks(css=custom_css, title="🛰️ EuroSAT Land Cover Classifier") as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>🛰️ EuroSAT Land Cover Classifier</h1>
        <p>Advanced satellite image classification using Swin Transformer</p>
        <p><strong>Model:</strong> Adilbai/EuroSAT-Swin | <strong>Dataset:</strong> EuroSAT (10 land cover classes)</p>
    </div>
    """)

    with gr.Row():
        # Left column: image upload, classify button and usage hints.
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Upload Satellite Image</h3>")
            image_input = gr.Image(
                label="Upload a satellite/aerial image",
                type="pil",
                height=400,
                elem_classes="upload-area"
            )
            classify_btn = gr.Button(
                "🔍 Classify Land Cover",
                variant="primary",
                size="lg"
            )
            # Spacer above the hints. A <div> opened in one gr.HTML component
            # cannot wrap a sibling component, so a single self-contained
            # spacer replaces the former split <div> ... </div> pair.
            gr.HTML("<div style='margin-top: 2rem;'></div>")
            gr.Markdown(get_sample_images())

        # Right column: textual summary and confidence chart.
        with gr.Column(scale=1):
            gr.HTML("<h3>📊 Classification Results</h3>")
            result_text = gr.Markdown(
                value="Upload an image and click 'Classify Land Cover' to see results!",
                elem_classes="result-text"
            )
            confidence_plot = gr.Plot(
                label="Confidence Scores",
            )
            # Hidden component to store raw results
            raw_results = gr.JSON(visible=False)

    # Event handlers: classify when the button is clicked...
    classify_btn.click(
        fn=classify_image,
        inputs=[image_input],
        outputs=[raw_results, confidence_plot, result_text]
    )
    # ...and also whenever the image input changes (upload or clear).
    image_input.change(
        fn=classify_image,
        inputs=[image_input],
        outputs=[raw_results, confidence_plot, result_text]
    )

    # Footer
    gr.HTML("""
    <div style="text-align: center; margin-top: 3rem; padding: 2rem; background: #f8f9fa; border-radius: 10px;">
        <h4>🔬 About This Model</h4>
        <p>This classifier uses the <strong>Swin Transformer</strong> architecture trained on the <strong>EuroSAT dataset</strong>.</p>
        <p>The EuroSAT dataset contains <strong>27,000 satellite images</strong> from <strong>34 European countries</strong>, covering <strong>10 different land cover classes</strong>.</p>
        <p>Perfect for environmental monitoring, urban planning, and agricultural analysis! 🌍</p>
        <br>
        <p><strong>Model:</strong> <a href="https://huggingface.co/Adilbai/EuroSAT-Swin" target="_blank">Adilbai/EuroSAT-Swin</a></p>
    </div>
    """)
# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    # Options forwarded unchanged to demo.launch().
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)