import gradio as gr
import torch
from plonk.pipe import PlonkPipeline
import numpy as np
from PIL import Image
from pathlib import Path

# Initialize the pipeline
print("Loading PLONK_YFCC model...")
pipe = PlonkPipeline(model_path="nicolas-dufour/PLONK_YFCC")
print("Model loaded successfully!")

def predict_geolocation(image):
    """
    Predict a geolocation from an uploaded image.

    Args:
        image: PIL Image

    Returns:
        str: Formatted latitude and longitude
    """
    if image is None:
        return "Please upload an image"

    try:
        # Get a prediction from the pipeline, using a single sample
        # with strong guidance (cfg=2.0) for the best single guess
        predicted_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)

        # Extract latitude and longitude
        lat, lon = float(predicted_gps[0, 0]), float(predicted_gps[0, 1])

        # Format the result
        result = f"Predicted Location:\nLatitude: {lat:.6f}\nLongitude: {lon:.6f}"
        return result
    except Exception as e:
        return f"Error during prediction: {str(e)}"
def predict_geolocation_with_samples(image, num_samples=64, cfg=0.0):
    """
    Predict geolocation with multiple samples for uncertainty visualization.

    Args:
        image: PIL Image
        num_samples: Number of samples to generate
        cfg: Classifier-free guidance scale

    Returns:
        str: Formatted results with statistics
    """
    if image is None:
        return "Please upload an image"

    try:
        # Get multiple predictions for uncertainty estimation
        predicted_gps = pipe(image, batch_size=num_samples, cfg=cfg, num_steps=32)

        # Calculate statistics
        lats = predicted_gps[:, 0].astype(float)
        lons = predicted_gps[:, 1].astype(float)
        mean_lat, mean_lon = np.mean(lats), np.mean(lons)
        std_lat, std_lon = np.std(lats), np.std(lons)

        # Get high confidence prediction
        high_conf_gps = pipe(image, batch_size=1, cfg=2.0, num_steps=32)
        conf_lat, conf_lon = float(high_conf_gps[0, 0]), float(high_conf_gps[0, 1])

        result = f"""Geolocation Prediction Results:

High Confidence Prediction (CFG=2.0):
Latitude: {conf_lat:.6f}
Longitude: {conf_lon:.6f}

Sample Statistics ({num_samples} samples, CFG={cfg}):
Mean Latitude: {mean_lat:.6f} ± {std_lat:.6f}
Mean Longitude: {mean_lon:.6f} ± {std_lon:.6f}
"""
        return result
    except Exception as e:
        return f"Error during prediction: {str(e)}"
# Create the Gradio app using Blocks for proper API support
with gr.Blocks(title="PLONK: Around the World in 80 Timesteps") as demo:
    gr.Markdown("# 🗺️ PLONK: Around the World in 80 Timesteps")
    gr.Markdown("A generative approach to global visual geolocation. Upload an image and PLONK will predict where it was taken!")

    with gr.Tabs():
        with gr.TabItem("Simple Prediction"):
            gr.Markdown("""
            ### 🗺️ PLONK: Global Visual Geolocation

            Upload an image and PLONK will predict where it was taken!
            This uses the PLONK_YFCC model trained on the YFCC100M dataset.
            The model predicts latitude and longitude coordinates based on visual content.

            **Note**: This is running on CPU, so processing may take 300-500ms per image.
            """)

            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload an image")
                    predict_btn = gr.Button("Predict Location", variant="primary")
                with gr.Column():
                    output_text = gr.Textbox(label="Predicted Location", lines=4)

            # Add examples if they exist
            if any(Path("demo/examples").glob("*")):
                gr.Examples(
                    examples=[
                        ["demo/examples/condor.jpg"],
                        ["demo/examples/Kilimanjaro.jpg"],
                        ["demo/examples/pigeon.png"],
                    ],
                    inputs=image_input,
                    outputs=output_text,
                    fn=predict_geolocation,
                    cache_examples=False,
                )

            predict_btn.click(
                fn=predict_geolocation,
                inputs=image_input,
                outputs=output_text,
                api_name="predict",  # This creates the /api/predict endpoint
            )

| with gr.TabItem("Advanced Analysis"): | |
| gr.Markdown(""" | |
| ### 🗺️ PLONK: Advanced Geolocation with Uncertainty | |
| Advanced interface showing prediction uncertainty through multiple samples. | |
| - **Number of samples**: More samples = better uncertainty estimation (but slower) | |
| - **Guidance scale**: Higher values = more confident predictions (try 2.0 for best single guess) | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| adv_image_input = gr.Image(type="pil", label="Upload an image") | |
| samples_slider = gr.Slider(1, 256, value=64, step=1, label="Number of samples") | |
| cfg_slider = gr.Slider(0.0, 5.0, value=0.0, step=0.1, label="Guidance scale (CFG)") | |
| advanced_btn = gr.Button("Analyze with Uncertainty", variant="primary") | |
| with gr.Column(): | |
| advanced_output = gr.Textbox(label="Detailed Results", lines=10) | |
| advanced_btn.click( | |
| fn=predict_geolocation_with_samples, | |
| inputs=[adv_image_input, samples_slider, cfg_slider], | |
| outputs=advanced_output, | |
| api_name="predict_advanced" # This creates the /api/predict_advanced endpoint | |
| ) | |
if __name__ == "__main__":
    demo.launch()
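
# Example client-side usage (a sketch for a separate script, not executed here;
# assumes this app is deployed as a Hugging Face Space, that `gradio_client` is
# installed, and that the Space id and image path below -- both placeholders --
# are replaced with real values; with older gradio_client versions, pass a plain
# file path instead of handle_file(...)):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/your-space-name")
#   location = client.predict(handle_file("photo.jpg"), api_name="/predict")
#   details = client.predict(handle_file("photo.jpg"), 64, 0.0, api_name="/predict_advanced")
#   print(location)
#   print(details)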