kylanoconnor's picture
Fix Gradio 5.x API endpoints - replace TabbedInterface with Blocks
5eb92ed
raw
history blame
6.07 kB
import gradio as gr
import torch
from plonk.pipe import PlonkPipeline
import numpy as np
from PIL import Image
from pathlib import Path
# Build the geolocation pipeline exactly once, at import time, so that the
# prediction functions below all reuse the same loaded model object.
print("Loading PLONK_YFCC model...")
pipe = PlonkPipeline(
    model_path="nicolas-dufour/PLONK_YFCC",
)
print("Model loaded successfully!")
def predict_geolocation(image):
    """Return a single best-guess location for ``image`` as display text.

    Calls the module-level ``pipe`` once with ``batch_size=1``, a
    classifier-free guidance scale of 2.0 and 32 sampling steps, then reads
    the first row of the result as (latitude, longitude).

    Args:
        image: The uploaded picture (PIL image from the Gradio input), or
            ``None`` when nothing has been uploaded.

    Returns:
        str: A "Predicted Location" block with both coordinates to six
        decimal places; an upload prompt when ``image`` is ``None``; or an
        "Error during prediction: ..." message if the pipeline call or the
        coordinate extraction raises.
    """
    if image is None:
        return "Please upload an image"

    try:
        gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        latitude = float(gps[0, 0])
        longitude = float(gps[0, 1])
    except Exception as err:  # UI boundary: surface any failure as text
        return "Error during prediction: " + str(err)

    lines = [
        "Predicted Location:",
        f"Latitude: {latitude:.6f}",
        f"Longitude: {longitude:.6f}",
    ]
    return "\n".join(lines)
def _wrap_longitude(deg):
    """Wrap longitude value(s) in degrees into the interval [-180, 180)."""
    return (np.asarray(deg, dtype=float) + 180.0) % 360.0 - 180.0


def _longitude_mean_std(lons):
    """Return ``(mean, std)`` of longitudes in degrees, honouring the antimeridian.

    A plain arithmetic mean treats -179 and 179 as 358 degrees apart, so
    samples clustered around the date line would average to ~0 (the opposite
    side of the globe). Instead, each sample's offset from the circular mean
    is wrapped into [-180, 180) and the statistics are taken over those
    offsets. When no sample is 180 degrees or more from that centre, this
    equals ``np.mean`` / ``np.std`` of the raw values (up to float rounding).
    """
    rad = np.deg2rad(lons)
    centre = np.rad2deg(np.arctan2(np.mean(np.sin(rad)), np.mean(np.cos(rad))))
    offsets = _wrap_longitude(lons - centre)
    mean_lon = _wrap_longitude(centre + np.mean(offsets))
    return float(mean_lon), float(np.std(offsets))


def predict_geolocation_with_samples(image, num_samples=64, cfg=0.0):
    """
    Predict geolocation with multiple samples for uncertainty visualization.

    Draws ``num_samples`` coordinates from the module-level ``pipe`` at the
    requested guidance scale and reports their mean and standard deviation
    (longitude statistics are computed antimeridian-aware). It also draws one
    extra sample with ``cfg=2.0`` as a high-confidence point estimate.

    Args:
        image: Uploaded picture (PIL image from the Gradio input), or None.
        num_samples: Number of samples to generate. Coerced to ``int`` because
            numeric UI widgets may deliver it as a float, and the pipeline
            needs an integer batch size.
        cfg: Classifier-free guidance scale for the sampled batch.

    Returns:
        str: Formatted results with statistics; an upload prompt when
        ``image`` is None; or an "Error during prediction: ..." message if
        sampling or the statistics fail.
    """
    if image is None:
        return "Please upload an image"
    try:
        num_samples = int(num_samples)
        # np.asarray accepts numpy arrays and CPU tensors alike, whereas
        # .astype(float) only exists on numpy arrays.
        predicted_gps = np.asarray(
            pipe(image, batch_size=num_samples, cfg=cfg, num_steps=32),
            dtype=float,
        )
        lats = predicted_gps[:, 0]
        lons = predicted_gps[:, 1]
        mean_lat, std_lat = np.mean(lats), np.std(lats)
        mean_lon, std_lon = _longitude_mean_std(lons)

        # Separate single, strongly guided sample as the point estimate.
        high_conf_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        conf_lat, conf_lon = float(high_conf_gps[0, 0]), float(high_conf_gps[0, 1])
    except Exception as e:  # UI boundary: surface any failure as text
        return f"Error during prediction: {str(e)}"

    report = [
        "Geolocation Prediction Results:",
        "High Confidence Prediction (CFG=2.0):",
        f"Latitude: {conf_lat:.6f}",
        f"Longitude: {conf_lon:.6f}",
        f"Sample Statistics ({num_samples} samples, CFG={cfg}):",
        f"Mean Latitude: {mean_lat:.6f} ± {std_lat:.6f}",
        f"Mean Longitude: {mean_lon:.6f} ± {std_lon:.6f}",
    ]
    return "\n".join(report) + "\n"
# Build the app with gr.Blocks (rather than gr.TabbedInterface) so that each
# button's click handler can register a named API endpoint via api_name.
with gr.Blocks(title="PLONK: Around the World in 80 Timesteps") as demo:
    gr.Markdown("# 🗺️ PLONK: Around the World in 80 Timesteps")
    gr.Markdown("A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!")

    with gr.Tabs():
        # Tab 1: a single best-guess prediction.
        with gr.TabItem("Simple Prediction"):
            gr.Markdown("""
### 🗺️ PLONK: Global Visual Geolocation
Upload an image and PLONK will predict where it was taken!
This uses the PLONK_YFCC model trained on the YFCC100M dataset.
The model predicts latitude and longitude coordinates based on visual content.
**Note**: This is running on CPU, so processing may take 300-500ms per image.
""")
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload an image")
                    predict_btn = gr.Button("Predict Location", variant="primary")
                with gr.Column():
                    output_text = gr.Textbox(label="Predicted Location", lines=4)

            # List only the example images that exist on disk. A non-empty
            # demo/examples directory does not guarantee these particular
            # files are present, and Examples should never point at a
            # missing file.
            example_paths = [
                path
                for path in (
                    "demo/examples/condor.jpg",
                    "demo/examples/Kilimanjaro.jpg",
                    "demo/examples/pigeon.png",
                )
                if Path(path).is_file()
            ]
            if example_paths:
                gr.Examples(
                    examples=[[path] for path in example_paths],
                    inputs=image_input,
                    outputs=output_text,
                    fn=predict_geolocation,
                    cache_examples=False,
                )

            predict_btn.click(
                fn=predict_geolocation,
                inputs=image_input,
                outputs=output_text,
                api_name="predict",  # named API endpoint "/predict" (e.g. for gradio_client)
            )

        # Tab 2: many samples plus spread statistics.
        with gr.TabItem("Advanced Analysis"):
            gr.Markdown("""
### 🗺️ PLONK: Advanced Geolocation with Uncertainty
Advanced interface showing prediction uncertainty through multiple samples.
- **Number of samples**: More samples = better uncertainty estimation (but slower)
- **Guidance scale**: Higher values = more confident predictions (try 2.0 for best single guess)
""")
            with gr.Row():
                with gr.Column():
                    adv_image_input = gr.Image(type="pil", label="Upload an image")
                    samples_slider = gr.Slider(1, 256, value=64, step=1, label="Number of samples")
                    cfg_slider = gr.Slider(0.0, 5.0, value=0.0, step=0.1, label="Guidance scale (CFG)")
                    advanced_btn = gr.Button("Analyze with Uncertainty", variant="primary")
                with gr.Column():
                    advanced_output = gr.Textbox(label="Detailed Results", lines=10)

            advanced_btn.click(
                fn=predict_geolocation_with_samples,
                inputs=[adv_image_input, samples_slider, cfg_slider],
                outputs=advanced_output,
                api_name="predict_advanced",  # named API endpoint "/predict_advanced"
            )

if __name__ == "__main__":
    demo.launch()