|
import streamlit as st |
|
import pandas as pd |
|
from PIL import Image |
|
import torch |
|
from pipe import PlonkPipeline |
|
from pathlib import Path |
|
from streamlit_extras.colored_header import colored_header |
|
import plotly.express as px |
|
import requests |
|
from io import BytesIO |
|
|
|
|
|
# Browser-tab title and icon, and a full-width page layout for the app.
st.set_page_config(
    page_title="Around the World in 80 Timesteps",
    page_icon="๐บ๏ธ",
    layout="wide",
)
|
|
|
# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Absolute path two directory levels above this file.
PROJECT_ROOT = Path(__file__).parent.parent.absolute()

# "checkpoints" directory directly under the project root.
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"

# Display name shown in the UI -> model identifier passed to ``load_model``.
MODEL_NAMES = {
    "PLONK_OSV_5M": "nicolas-dufour/PLONK_OSV_5M",
    "PLONK_YFCC": "nicolas-dufour/PLONK_YFCC",
    "PLONK_iNaturalist": "nicolas-dufour/PLONK_iNaturalist",
}
|
|
|
|
|
@st.cache_resource
def load_model(model_name):
    """Build the ``PlonkPipeline`` for ``model_name``.

    Wrapped in ``st.cache_resource`` so the pipeline is not rebuilt on later
    calls with the same argument. If construction fails for any reason, the
    error is shown in the app and the script run is halted with ``st.stop()``.
    """
    try:
        return PlonkPipeline(model_path=model_name)
    except Exception as exc:
        message = f"Error loading model: {exc}"
        st.error(message)
        st.stop()
|
|
|
|
|
# One loaded pipeline per entry of MODEL_NAMES, keyed by display name.
# All of them are built up front, when this module is executed.
PIPES = {label: load_model(model_path) for label, model_path in MODEL_NAMES.items()}
|
|
|
|
|
def predict_location(image, model_name, cfg=0.0, num_samples=256):
    """Sample location predictions for an image with the selected pipeline.

    The pipeline is run twice: once for ``num_samples`` samples at guidance
    scale ``cfg`` (with a progress bar in the Streamlit UI), and once for a
    single sample at guidance scale 2.0 that is reported as the "best guess".

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts
            (e.g. a path or file-like object).
        model_name: Key into ``PIPES`` selecting the pipeline to run.
        cfg: Guidance scale for the multi-sample run.
        num_samples: Number of samples to draw (passed as ``batch_size``).

    Returns:
        dict with ``lat`` and ``lon`` (lists of floats, one entry per sample)
        and ``high_conf_lat`` / ``high_conf_lon`` (the single best-guess point).
    """
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(image).convert("RGB")

    pipe = PIPES[model_name]

    # Placeholders updated while sampling and removed again afterwards.
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(step, total_steps):
        # Guard the division and clamp to [0, 1] so an unexpected step value
        # from the pipeline cannot make the progress widget raise.
        fraction = float(step) / float(total_steps) if total_steps else 1.0
        progress_bar.progress(min(max(fraction, 0.0), 1.0))
        status_text.text(f"Sampling step {step + 1}/{total_steps}")

    try:
        with torch.no_grad():
            # Assumes the pipeline returns a 2-D NumPy-like array whose
            # columns are (lat, lon) -- TODO confirm against PlonkPipeline.
            predicted_gps = pipe(
                img,
                batch_size=num_samples,
                cfg=cfg,
                num_steps=16,
                callback=update_progress,
            )

            status_text.text("Generating high-confidence prediction...")
            high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
    finally:
        # Always remove the progress widgets, even when sampling fails, so a
        # stale bar is not left on the page next to the error.
        status_text.empty()
        progress_bar.empty()

    return {
        "lat": predicted_gps[:, 0].astype(float).tolist(),
        "lon": predicted_gps[:, 1].astype(float).tolist(),
        "high_conf_lat": high_conf_gps[0, 0].astype(float),
        "high_conf_lon": high_conf_gps[0, 1].astype(float),
    }
|
|
|
|
|
def load_example_images():
    """Collect the example images shipped next to this file.

    Looks for ``*.jpg`` files in the ``examples`` directory that sits beside
    this file. A missing directory is reported with ``st.error``; a directory
    without any ``.jpg`` file is reported with ``st.warning``.

    Returns:
        dict mapping a display name (file stem with underscores replaced by
        spaces, title-cased) to the image path as a string, in sorted path
        order. Empty when the directory is missing or holds no ``.jpg`` files.
    """
    examples_dir = Path(__file__).parent / "examples"
    if not examples_dir.exists():
        st.error(
            """
            Examples directory not found. Please create the following structure:
            demo/
            โโโ examples/
            โโโ eiffel_tower.jpg
            โโโ colosseum.jpg
            โโโ taj_mahal.jpg
            โโโ statue_liberty.jpg
            โโโ sydney_opera.jpg
            """
        )
        return {}

    # Path.glob yields entries in arbitrary, filesystem-dependent order; sort
    # so the examples appear in the same order on every run and machine.
    examples = {
        img_path.stem.replace("_", " ").title(): str(img_path)
        for img_path in sorted(examples_dir.glob("*.jpg"))
    }

    if not examples:
        st.warning("No example images found in the examples directory.")

    return examples
|
|
|
|
|
def resize_image_for_display(image, max_size=400):
    """Return ``image`` scaled down so its longer side is at most ``max_size``.

    The aspect ratio is kept. Images that already fit are returned unchanged
    (the same object, not a copy); larger ones are resampled with LANCZOS
    into a new image.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method, such as a ``PIL.Image.Image``.
        max_size: Maximum length in pixels of the longer side.

    Returns:
        ``image`` itself when no resizing is needed, otherwise the resized image.
    """
    width, height = image.size
    longest_side = max(width, height)
    if longest_side <= max_size:
        return image

    scale = max_size / longest_side
    # Keep the shorter side at 1 px or more: for very elongated images int()
    # would truncate it to 0, and Pillow refuses a zero-sized resize target.
    if width > height:
        new_size = (max_size, max(1, int(height * scale)))
    else:
        new_size = (max(1, int(width * scale)), max_size)

    return image.resize(new_size, Image.Resampling.LANCZOS)
|
|
|
|
|
def load_image_from_url(url, timeout=10):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        The opened image, or ``None`` after the failure (request error, HTTP
        error status, or data that cannot be opened as an image) has been
        reported with ``st.error``.
    """
    try:
        # Without a timeout, requests waits indefinitely on an unresponsive
        # server, which would stall the whole script run.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberately broad: this is a UI boundary and any failure should be
        # shown to the user rather than crash the page.
        st.error(f"Error loading image from URL: {str(e)}")
        return None
|
|
|
|
|
def _inject_custom_css():
    """Inject the page's custom CSS (buttons, prediction box, image containers)."""
    st.markdown(
        """
        <style>
        .main {
            padding: 0rem 1rem;
        }
        .stButton>button {
            width: 100%;
            background-color: #FF4B4B;
            color: white;
            border: none;
            padding: 0.5rem 1rem;
            border-radius: 0.5rem;
        }
        .stButton>button:hover {
            background-color: #FF6B6B;
        }
        .prediction-box {
            background-color: #f0f2f6;
            padding: 1.5rem;
            border-radius: 0.5rem;
            margin: 1rem 0;
        }
        /* New styles for image containers */
        .upload-container {
            max-height: 300px;
            overflow-y: auto;
            margin-bottom: 1rem;
        }
        .examples-container {
            max-height: 200px;
            display: flex;
            gap: 10px;
        }
        .stTabs [data-baseweb="tab-panel"] {
            padding-top: 1rem;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )


def _predict_and_store(image, model_name, cfg_value, num_samples):
    """Run ``predict_location`` and keep the result in the session state."""
    predictions = predict_location(
        image,
        model_name=model_name,
        cfg=cfg_value,
        num_samples=num_samples,
    )
    st.session_state["predictions"] = predictions


def _render_predictions(pred):
    """Draw the sampled points and the best guess on a map, plus a summary box."""
    samples = pd.DataFrame(
        {
            "lat": pred["lat"],
            "lon": pred["lon"],
            "type": ["Sample"] * len(pred["lat"]),
        }
    )
    best_guess = pd.DataFrame(
        {
            "lat": [pred["high_conf_lat"]],
            "lon": [pred["high_conf_lon"]],
            "type": ["Best Guess"],
        }
    )
    # The best guess is appended last and coloured by the "type" column.
    df = pd.concat([samples, best_guess])

    fig = px.scatter_mapbox(
        df,
        lat="lat",
        lon="lon",
        zoom=2,
        opacity=0.6,
        color="type",
        color_discrete_map={"Sample": "blue", "Best Guess": "red"},
        mapbox_style="carto-positron",
    )

    # Enlarge only the best-guess marker so it stands out from the samples.
    fig.update_traces(selector=dict(name="Best Guess"), marker_size=15)

    fig.update_layout(
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        height=500,
        showlegend=True,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    )

    with st.container():
        st.plotly_chart(fig, use_container_width=True)

    with st.container():
        st.markdown(
            f"""
            <div class="prediction-box">
                <h4>๐ Prediction Statistics</h4>
                <p>Number of sampled locations: {len(pred["lat"])}</p>
                <p>Best guess location: {pred["high_conf_lat"]:.2f}ยฐ, {pred["high_conf_lon"]:.2f}ยฐ</p>
            </div>
            """,
            unsafe_allow_html=True,
        )


def main():
    """Render the demo page.

    The left column holds the model selector, sampling controls and three
    image sources (upload, URL, bundled examples). The right column shows the
    predictions stored in ``st.session_state["predictions"]``, or a
    placeholder when none have been made yet.
    """
    _inject_custom_css()

    colored_header(
        label="๐บ๏ธ Around the World in 80 Timesteps: A Generative Approach to Global Visual Geolocation",
        description="Upload an image and our model, PLONK, will predict possible locations! In red we will sample one point with guidance scale 2.0 for the best guess. Project page: https://nicolas-dufour.github.io/plonk",
        color_name="red-70",
    )

    col1, col2 = st.columns([1, 2], gap="large")

    with col1:
        model_name = st.selectbox(
            "๐ค Select Model",
            options=list(MODEL_NAMES),
            index=0,
            help="Choose which PLONK model variant to use for prediction.",
        )

        col_slider1, col_slider2 = st.columns([0.5, 0.5])
        with col_slider1:
            cfg_value = st.slider(
                "๐ฏ Guidance scale",
                min_value=0.0,
                max_value=5.0,
                value=0.0,
                step=0.1,
                help="Scale for classifier-free guidance during sampling. A small value makes the model predictions display the diversity of the model, while a large value makes the model predictions more conservative but potentially more accurate.",
            )

        with col_slider2:
            num_samples = st.number_input(
                "๐ฒ Number of samples",
                min_value=1,
                max_value=5000,
                value=64,
                step=1,
                help="Number of location predictions to generate. More samples give better coverage but take longer to compute.",
            )

        st.markdown("### ๐ธ Choose your image")
        tab1, tab2, tab3 = st.tabs(["Upload", "URL", "Examples"])

        with tab1:
            uploaded_file = st.file_uploader(
                "Choose an image...",
                type=["png", "jpg", "jpeg"],
                help="Supported formats: PNG, JPG, JPEG",
            )

            if uploaded_file is not None:
                st.markdown('<div class="upload-container">', unsafe_allow_html=True)
                original_image = Image.open(uploaded_file)
                # Only a copy is scaled for display; the untouched original
                # is what gets passed to the model.
                display_image = resize_image_for_display(
                    original_image.copy(), max_size=300
                )
                st.image(
                    display_image, caption="Uploaded Image", use_container_width=True
                )
                st.markdown("</div>", unsafe_allow_html=True)

                if st.button("๐ Predict Location", key="predict_upload"):
                    _predict_and_store(
                        original_image, model_name, cfg_value, num_samples
                    )

        with tab2:
            url = st.text_input("Enter image URL:", key="image_url")

            if url:
                image = load_image_from_url(url)
                # load_image_from_url returns None after reporting a failure.
                if image is not None:
                    st.markdown(
                        '<div class="upload-container">', unsafe_allow_html=True
                    )
                    display_image = resize_image_for_display(image.copy(), max_size=300)
                    st.image(
                        display_image,
                        caption="Image from URL",
                        use_container_width=True,
                    )
                    st.markdown("</div>", unsafe_allow_html=True)

                    if st.button("๐ Predict Location", key="predict_url"):
                        _predict_and_store(image, model_name, cfg_value, num_samples)

        with tab3:
            examples = load_example_images()
            # st.columns() needs a positive column count, so skip the gallery
            # when there are no examples; load_example_images() has already
            # told the user why.
            if examples:
                st.markdown('<div class="examples-container">', unsafe_allow_html=True)
                example_cols = st.columns(len(examples))

                for column, (name, path) in zip(example_cols, examples.items()):
                    with column:
                        original_image = Image.open(path)
                        display_image = resize_image_for_display(
                            original_image.copy(), max_size=150
                        )

                        if st.container().button(
                            "๐ธ",
                            key=f"img_{name}",
                            help=f"Click to predict location for {name}",
                            use_container_width=True,
                        ):
                            _predict_and_store(
                                original_image, model_name, cfg_value, num_samples
                            )
                            st.rerun()

                        st.image(display_image, caption=name, use_container_width=True)
                st.markdown("</div>", unsafe_allow_html=True)

    with col2:
        st.markdown("### ๐ Predicted Locations")

        if "predictions" in st.session_state:
            _render_predictions(st.session_state["predictions"])
        else:
            st.markdown(
                """
                <div class="prediction-box" style="text-align: center;">
                    <h4>๐ Upload an image and click 'Predict Location'</h4>
                    <p>The predicted locations will appear here on an interactive map.</p>
                </div>
                """,
                unsafe_allow_html=True,
            )
|
|
|
|
|
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|