# NOTE: page-banner residue from the hosting site ("Spaces: Runtime error");
# it is not part of the program.
import gradio as gr | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import base64 | |
import io | |
import os | |
import numpy as np | |
from pathlib import Path | |
from plonk.pipe import PlonkPipeline | |
import random | |
# Module-level cache for the PLONK pipeline. It starts out empty and is
# filled in by load_plonk_model() the first time that function runs.
model = None

# Prediction backend switch:
#   False -> real PLONK predictions (production deployment)
#   True  -> mock data, for testing
MOCK_MODE: bool = False
def load_plonk_model():
    """
    Return the shared PLONK pipeline, creating it on first use.

    The pipeline is stored in the module-level ``model`` variable. When that
    variable already holds something other than None it is returned as-is;
    otherwise a ``PlonkPipeline`` for "nicolas-dufour/PLONK_YFCC" is built,
    stored in ``model`` and returned. Progress is reported with ``print``.
    """
    global model

    # Fast path: something has already been cached.
    if model is not None:
        return model

    print("Loading PLONK_YFCC model...")
    model = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
    print("Model loaded successfully!")
    return model
def mock_plonk_prediction():
    """
    Produce a fake (latitude, longitude) pair for testing.

    One of a fixed set of major-city coordinates is picked at random and each
    component is shifted by an independent uniform offset in [-2, 2] degrees.
    Only used when MOCK_MODE = True.
    """
    # Reference coordinates of major cities/regions, as (lat, lon).
    city_centres = (
        (40.7128, -74.0060),    # New York
        (34.0522, -118.2437),   # Los Angeles
        (51.5074, -0.1278),     # London
        (48.8566, 2.3522),      # Paris
        (35.6762, 139.6503),    # Tokyo
        (37.7749, -122.4194),   # San Francisco
        (41.8781, -87.6298),    # Chicago
        (25.7617, -80.1918),    # Miami
        (45.5017, -73.5673),    # Montreal
        (52.5200, 13.4050),     # Berlin
        (-33.8688, 151.2093),   # Sydney
        (19.4326, -99.1332),    # Mexico City
    )

    centre_lat, centre_lon = random.choice(city_centres)

    # Jitter so repeated calls do not return the exact same point
    # (2 degrees is roughly 200 km).
    lat_offset = random.uniform(-2, 2)
    lon_offset = random.uniform(-2, 2)

    return centre_lat + lat_offset, centre_lon + lon_offset
def real_plonk_prediction(image):
    """
    Real PLONK prediction using the diff-plonk package.

    Draws 32 location samples for the image (cfg=2.0, 32 steps) and reduces
    them to a mean position and a spread estimate.

    Args:
        image: Image handed unchanged to the PLONK pipeline (the callers in
            this file pass an RGB PIL image).

    Returns:
        Tuple ``(mean_lat, mean_lon, uncertainty_km, num_samples)`` where the
        first three items are plain Python floats and ``num_samples`` is the
        number of rows returned by the pipeline.
    """
    # Reuse the shared, lazily created pipeline instead of caching a second
    # copy as an attribute on the gradio module.
    pipeline = load_plonk_model()

    # Get 32 predictions for uncertainty estimation
    predicted_gps = pipeline(image, batch_size=32, cfg=2.0, num_steps=32)

    # Columns 0 and 1 are treated as latitude and longitude in degrees;
    # expected shape is (32, 2) -- TODO confirm against the pipeline's docs.
    samples = predicted_gps.cpu().numpy()
    latitudes = samples[:, 0]
    longitudes = samples[:, 1]

    # NOTE(review): a plain arithmetic mean of longitudes is misleading when
    # samples straddle the +/-180 degree meridian.
    mean_lat = float(np.mean(latitudes))
    mean_lon = float(np.mean(longitudes))
    std_lat = float(np.std(latitudes))
    std_lon = float(np.std(longitudes))

    # Rough conversion to km. One degree of latitude is ~111.32 km anywhere,
    # but a degree of longitude shrinks with cos(latitude); without that
    # factor the radius is overstated away from the equator.
    km_per_degree = 111.32
    spread_lat_km = std_lat * km_per_degree
    spread_lon_km = std_lon * km_per_degree * np.cos(np.radians(mean_lat))

    # float() keeps NumPy scalars out of the return value.
    uncertainty_km = float(np.hypot(spread_lat_km, spread_lon_km))

    return mean_lat, mean_lon, uncertainty_km, len(samples)
def predict_location(image):
    """
    Main prediction function for the Gradio interface.

    Args:
        image: Uploaded image, or None when nothing was provided. Any other
            value is expected to expose ``mode`` and ``convert`` (the UI in
            this file supplies a PIL image).

    Returns:
        A Markdown string: a prompt to upload an image when ``image`` is
        None, the formatted prediction on success, or a message starting
        with "❌ Error processing image:" if anything raised on the way.
    """
    if image is None:
        return "Please upload an image."

    try:
        # Ensure RGB format
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Get prediction (mock or real)
        if MOCK_MODE:
            lat, lon = mock_plonk_prediction()
            confidence = "mock"
            uncertainty_km = None
            num_samples = 1
            note = " (Mock prediction for testing)"
        else:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(image)
            confidence = "high"
            note = f" (Real PLONK prediction, {num_samples} samples)"

        # Each field is its own Markdown paragraph: fields separated by a
        # single newline are merged into one running line when rendered.
        sections = [
            f"🗺️ **Predicted Location**{note}",
            f"**Latitude:** {lat:.6f}",
            f"**Longitude:** {lon:.6f}",
        ]
        if uncertainty_km is not None:
            sections.append(f"**Uncertainty:** ±{uncertainty_km:.1f} km")
        sections.extend([
            f"**Confidence:** {confidence}",
            f"**Samples:** {num_samples}",
            f"**Mode:** {'🧪 Mock Testing' if MOCK_MODE else '🚀 Production'}",
            "🌍 *This prediction estimates where the image was taken based on visual content.*",
        ])
        return "\n\n".join(sections) + "\n"
    except Exception as e:
        # Handler boundary: keep the UI responsive, but leave a trace in the
        # server log so the failure is not visible only to the end user.
        print(f"Error processing image: {e!r}")
        return f"❌ Error processing image: {e}"
def predict_location_json(image):
    """
    JSON flavour of the prediction entry point, for programmatic access.

    Returns a dictionary rather than formatted text. On success it carries
    "status", "mode", "predicted_location" (latitude/longitude rounded to six
    decimals), "confidence", "samples" and "note", plus "uncertainty_km"
    (one decimal) when an uncertainty value is available. When no image is
    given, or when anything raises, it returns {"error": ..., "status": "error"}.
    """
    try:
        if image is None:
            return {"error": "No image provided", "status": "error"}

        # The model is always fed an RGB image.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')

        # Pick the backend: real model unless mock mode is switched on.
        if not MOCK_MODE:
            lat, lon, uncertainty_km, num_samples = real_plonk_prediction(rgb_image)
            confidence = "high"
        else:
            lat, lon = mock_plonk_prediction()
            confidence, uncertainty_km, num_samples = "mock", None, 1

        mode_name = "mock" if MOCK_MODE else "production"
        if MOCK_MODE:
            note_text = "This is a mock prediction for testing"
        else:
            note_text = f"Real PLONK prediction using {num_samples} samples"

        payload = {
            "status": "success",
            "mode": mode_name,
            "predicted_location": {
                "latitude": round(lat, 6),
                "longitude": round(lon, 6),
            },
            "confidence": confidence,
            "samples": num_samples,
            "note": note_text,
        }

        # The uncertainty key is only present when a value exists.
        if uncertainty_km is not None:
            payload["uncertainty_km"] = round(uncertainty_km, 1)

        return payload
    except Exception as exc:
        return {"error": str(exc), "status": "error"}
# ---------------------------------------------------------------------------
# Gradio interface
#
# The Markdown shown in the UI is kept in module-level strings so the text
# starts at column 0 and paragraphs are separated by blank lines.
# ---------------------------------------------------------------------------

_MODE_LABEL = '🧪 Mock Testing' if MOCK_MODE else '🚀 Production'

# This must be an f-string: as a plain string the mode placeholder would be
# displayed literally instead of being replaced by the current mode.
_HEADER_MD = f"""
# 🗺️ PLONK: Around the World in 80 Timesteps

A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!

This uses the PLONK model concept from the paper: *"Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation"*

**Current Mode:** {_MODE_LABEL} - Real PLONK model predictions with 32 samples for uncertainty estimation.

**Configuration:** Guidance Scale = 2.0, Samples = 32, Steps = 32
"""

# NOTE(review): the request shapes documented below are reproduced from the
# existing text; confirm them against the Gradio version actually deployed.
_API_INFO_MD = f"""
## 🔗 API Access

This Space provides both web interface and programmatic API access:

### **REST API Endpoint**

```
POST https://kylanoconnor-plonk-geolocation.hf.space/api/predict
```

### **Python Example**

```python
import requests

# For API access
response = requests.post(
    "https://kylanoconnor-plonk-geolocation.hf.space/api/predict",
    files={{"file": open("image.jpg", "rb")}}
)
result = response.json()
print(f"Location: {{result['data']['latitude']}}, {{result['data']['longitude']}}")
```

### **cURL Example**

```bash
curl -X POST \\
  -F "file=@image.jpg" \\
  "https://kylanoconnor-plonk-geolocation.hf.space/api/predict"
```

### **Gradio Client (Python)**

```python
from gradio_client import Client

client = Client("kylanoconnor/plonk-geolocation")
result = client.predict("path/to/image.jpg", api_name="/predict")
print(result)
```

### **JavaScript/Node.js**

```javascript
const formData = new FormData();
formData.append('data', imageFile);

const response = await fetch(
    'https://kylanoconnor-plonk-geolocation.hf.space/api/predict',
    {{
        method: 'POST',
        body: formData
    }}
);
const result = await response.json();
console.log('Location:', result.data);
```

**Current Status:** {'🧪 Mock Mode - Returns realistic test coordinates' if MOCK_MODE else '🚀 Production Mode - Real PLONK predictions with 32 samples'}

**Response Format:**

- Latitude/Longitude coordinates
- Uncertainty estimation (±km radius)
- Number of samples used (32 for production)
- Prediction confidence metrics

**Rate Limits:** Standard Hugging Face Spaces limits apply

**CORS:** Enabled for web integration
"""

_ABOUT_MD = f"""
## About PLONK

PLONK is a generative approach to global visual geolocation that uses diffusion models to predict where images were taken.

**Paper:** [Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation](https://arxiv.org/abs/2412.06781)

**Authors:** Nicolas Dufour, David Picard, Vicky Kalogeiton, Loic Landrieu

**Original Code:** https://github.com/nicolas-dufour/plonk

### Current Deployment

- **Mode:** {'Mock Testing' if MOCK_MODE else 'Production'}
- **Model:** {'Simulated predictions for API testing' if MOCK_MODE else 'Real PLONK model inference'}
- **Response Format:** Structured JSON + formatted text
- **API:** Fully functional REST endpoints

### Production Deployment

This Space is running with the real PLONK model using:

- **Model:** nicolas-dufour/PLONK_YFCC
- **Dataset:** YFCC-100M
- **Inference:** CFG=2.0, 32 samples, 32 timesteps for high quality predictions
- **Uncertainty:** Statistical analysis across 32 predictions for reliability estimation

### Available Models

- `nicolas-dufour/PLONK_YFCC` - YFCC-100M dataset
- `nicolas-dufour/PLONK_iNaturalist` - iNaturalist dataset
- `nicolas-dufour/PLONK_OSV_5M` - OpenStreetView-5M dataset
"""

# Example images offered below the UI; only the ones present on disk are used.
_EXAMPLE_IMAGE_PATHS = [
    "demo/examples/condor.jpg",
    "demo/examples/Kilimanjaro.jpg",
    "demo/examples/pigeon.png",
]

with gr.Blocks(
    theme=gr.themes.Soft(),
    title="🗺️ PLONK: Around the World in 80 Timesteps"
) as demo:
    # Header
    gr.Markdown(_HEADER_MD)

    with gr.Tab("🖼️ Image Upload"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload an image",
                    type="pil",
                    sources=["upload", "webcam", "clipboard"]
                )
                predict_btn = gr.Button(
                    "🔍 Predict Location",
                    variant="primary",
                    size="lg"
                )
                clear_btn = gr.ClearButton(
                    components=[image_input],
                    value="🗑️ Clear"
                )
            with gr.Column(scale=1):
                output_text = gr.Markdown(
                    label="Prediction Result",
                    value="Upload an image and click 'Predict Location' to see results."
                )

    with gr.Tab("📡 API Information"):
        gr.Markdown(_API_INFO_MD)

    with gr.Tab("ℹ️ About"):
        gr.Markdown(_ABOUT_MD)

    # Event handlers
    predict_btn.click(
        fn=predict_location,
        inputs=[image_input],
        outputs=[output_text],
        api_name="predict"  # This enables API access at /api/predict
    )

    # Hidden API function for JSON responses. Invisible components plus a
    # named event expose the endpoint without drawing a second UI, which a
    # nested gr.Interface would do inside this Blocks context.
    json_image_input = gr.Image(type="pil", visible=False)
    json_output = gr.JSON(visible=False)
    json_trigger = gr.Button("predict_json", visible=False)
    predict_json = json_trigger.click(
        fn=predict_location_json,
        inputs=[json_image_input],
        outputs=[json_output],
        api_name="predict_json"  # Available at /api/predict_json
    )

    # Add examples if available
    examples = [[path] for path in _EXAMPLE_IMAGE_PATHS if Path(path).is_file()]
    if examples:
        try:
            gr.Examples(
                examples=examples,
                inputs=image_input,
                outputs=output_text,
                fn=predict_location,
                cache_examples=True
            )
        except Exception as exc:
            # Examples are optional: report the problem and carry on rather
            # than hiding it or preventing the app from starting.
            print(f"Examples not available, skipping: {exc!r}")
    else:
        print("Examples not available, skipping: no example images found")
if __name__ == "__main__":
    # Script entry point, used for local testing: serve the app on every
    # network interface at port 7860 with the API page switched on.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**launch_options)