# bioclip-2-demo / app.py
# Last commit by thompsonmj: "Update app.py to read json from TreeOfLife-200M (#7)" (ba295cc, verified)
import collections
import heapq
import json
import os
import logging
import gradio as gr
import numpy as np
import polars as pl
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms
from huggingface_hub import hf_hub_download
from components.templates import openai_imagenet_template
from components.query import get_sample
# Configure the root handler once, but log through a module-named logger so this
# app's messages are distinguishable from library output in the "[%(name)s]" field.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger(__name__)

# Optional Hugging Face access token from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")

# Metadata used to look up a sample image for the top predicted taxon.
METADATA_PATH = "components/metadata.parquet"
metadata_df = pl.read_parquet(METADATA_PATH, low_memory=False)
# Read the EOL page IDs and GBIF IDs as 64-bit ints.
metadata_df = metadata_df.with_columns(pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64))

# Model / tokenizer identifiers and the dataset repo holding the species text embeddings.
model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
HF_DATA_STR = "imageomics/TreeOfLife-200M"

# Species probabilities must exceed this to be summed into a higher-rank prediction.
min_prob = 1e-9
# Number of top predictions returned and displayed.
k = 5
device = torch.device("cpu")
# Per-channel RGB normalization statistics (these match the OpenAI CLIP
# preprocessing constants).
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Convert the input image to a tensor, resize it to a fixed 224x224 (the aspect
# ratio is not preserved), then normalize each channel with the stats above.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
]
preprocess_img = transforms.Compose(_preprocess_steps)
# Taxonomic ranks from coarsest to finest; the rank dropdown yields an index
# into this tuple.
ranks = (
    "Kingdom",
    "Phylum",
    "Class",
    "Order",
    "Family",
    "Genus",
    "Species",
)

# Open-ended tab examples: [image path, rank name to predict at].
open_domain_examples = [
    list(example)
    for example in (
        ("examples/Carcharhinus-melanopterus.jpg", "Species"),
        ("examples/house-finch.jpeg", "Species"),
        ("examples/Bovidae-Oryx.jpg", "Genus"),
        ("examples/Cebidae-Cebus.jpg", "Genus"),
        ("examples/Solanales-Petunia.png", "Genus"),
        ("examples/Asparagales-Orchidaceae.jpg", "Family"),
    )
]

# Candidate-class strings shared by several zero-shot examples (one class per line).
_JAGUAR_LEOPARD_CHEETAH = "Jaguar\nLeopard\nCheetah"
_MONARCH_VICEROY = "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy"

# Zero-shot tab examples: [image path, newline-separated candidate classes].
zero_shot_examples = [
    [
        "examples/Cortinarius-austroalbidus.jpg",
        "Cortinarius austroalbidus\nCortinarius armillatus\nCortinarius caperatus",
    ],
    ["examples/leopard.jpg", _JAGUAR_LEOPARD_CHEETAH],
    ["examples/jaguar.jpg", _JAGUAR_LEOPARD_CHEETAH],
    ["examples/cheetah.jpg", _JAGUAR_LEOPARD_CHEETAH],
    ["examples/monarch.jpg", _MONARCH_VICEROY],
    ["examples/viceroy.jpg", _MONARCH_VICEROY],
    [
        "examples/Ursus-arctos.jpeg",
        "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
    ],
    [
        "examples/Carnegiea-gigantea.png",
        "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
    ],
]
def indexed(lst, indices):
    """Return a new list holding ``lst[i]`` for each ``i`` in ``indices``, in order."""
    picked = []
    for position in indices:
        picked.append(lst[position])
    return picked
@torch.no_grad()
def get_txt_features(classnames, templates):
    """
    Embed each class name by averaging its prompt-template text embeddings.

    Every template (a callable taking a class name) is applied to the class name.
    The resulting prompts are encoded and L2-normalized, averaged, and the mean
    is re-normalized to unit length.

    Returns a tensor whose column ``j`` is the embedding of ``classnames[j]``
    (stacked along dim=1). This assumes ``model.encode_text`` returns one row
    per prompt; TODO confirm.
    """
    def _class_embedding(name):
        prompts = [fill(name) for fill in templates]
        tokens = tokenizer(prompts).to(device)
        per_prompt = F.normalize(model.encode_text(tokens), dim=-1)
        centroid = per_prompt.mean(dim=0)
        return centroid / centroid.norm()

    columns = [_class_embedding(name) for name in classnames]
    return torch.stack(columns, dim=1)
@torch.no_grad()
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
    """
    Classify ``img`` among a user-supplied list of class names.

    Args:
        img: Input image, in whatever form ``preprocess_img`` accepts (the Gradio
            image value).
        cls_str: Newline-separated class names. Blank lines and surrounding
            whitespace are ignored. A repeated name is kept once, at its first
            position.

    Returns:
        Mapping from each class name to its softmax probability.

    Raises:
        gr.Error: If no image or no class names were provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # dict.fromkeys de-duplicates while preserving first-seen order. Duplicates
    # would otherwise split probability mass across keys that collapse in the
    # returned dict.
    classes = list(
        dict.fromkeys(
            cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()
        )
    )
    if not classes:
        # torch.stack in get_txt_features would fail on an empty list.
        raise gr.Error("Please enter at least one class name.")

    txt_features = get_txt_features(classes, openai_imagenet_template)
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)
    # The product is (1, num_classes). Use squeeze(0), not squeeze(): with a
    # single class, squeeze() gives a 0-d tensor whose tolist() is a bare float,
    # and zip() below cannot iterate it.
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
def format_name(taxon, common):
    """
    Build a display name from taxon parts and an optional common name.

    ``taxon`` is an iterable of name strings joined with single spaces. When
    ``common`` is truthy, it is appended in parentheses.
    """
    scientific = " ".join(taxon)
    return f"{scientific} ({common})" if common else scientific
@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predicts from the entire tree of life.

    If targeting a higher rank than species, then this function predicts among all
    species, then sums up species-level probabilities for the given rank.

    Args:
        img: Input image, in whatever form ``preprocess_img`` accepts (the Gradio
            image value).
        rank: Index into ``ranks``. ``len(ranks) - 1`` (Species) ranks individual
            species; smaller indices aggregate to that rank.
        return_all: When True, also look up a sample image and taxon URL for the
            top prediction via ``get_sample``.

    Returns:
        A ``{name: probability}`` dict of the top ``k`` predictions. When
        ``return_all`` is True, returns the tuple
        ``(prediction_dict, sample_img, taxon_url)`` instead.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)
    # The product is (1, num_species). squeeze(0) keeps it 1-D.
    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze(0)
    probs = F.softmax(logits, dim=0)

    if rank + 1 == len(ranks):
        topk = probs.topk(k)
        top_indices = topk.indices.tolist()
        # Plain floats (not 0-d tensors) for the Label component and example cache.
        prediction_dict = {
            format_name(*txt_names[i]): p
            for i, p in zip(top_indices, topk.values.tolist())
        }
        logger.info(f"Top K predictions: {prediction_dict}")
        # Build the lookup name from the taxon parts alone. Stripping the common
        # name via .split("(")[0] left a trailing space behind.
        top_prediction_name = " ".join(txt_names[top_indices[0]][0])
    else:
        output = collections.defaultdict(float)
        # nonzero(as_tuple=True)[0] is always 1-D. The previous .squeeze() produced
        # a non-iterable 0-d tensor when exactly one species exceeded min_prob.
        kept = torch.nonzero(probs > min_prob, as_tuple=True)[0]
        # Move indices and probabilities to Python once, instead of doing
        # per-element tensor indexing inside the loop.
        for i, p in zip(kept.tolist(), probs[kept].tolist()):
            output[" ".join(txt_names[i][0][: rank + 1])] += p
        topk_names = heapq.nlargest(k, output, key=output.get)
        prediction_dict = {name: output[name] for name in topk_names}
        logger.info(f"Top K names for output: {topk_names}")
        logger.info(f"Prediction dictionary: {prediction_dict}")
        top_prediction_name = topk_names[0]

    logger.info(f"Top prediction name: {top_prediction_name}")
    if not return_all:
        # The sample image / URL would be discarded, so skip the lookup.
        return prediction_dict
    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
    return prediction_dict, sample_img, taxon_url
def change_output(choice):
    """Return an empty Label component titled with the rank at index ``choice``."""
    rank_label = ranks[choice]
    return gr.Label(
        value=None,
        label=rank_label,
        show_label=True,
        num_top_classes=k,
    )
if __name__ == "__main__":
    # The prediction functions above read model, tokenizer, txt_emb and txt_names
    # as module globals, so these stay module-level assignments (not locals of a
    # main() function).
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")
    tokenizer = get_tokenizer(tokenizer_str)

    # Species text embeddings, one column per species (used as img_features @ txt_emb).
    txt_emb = torch.from_numpy(np.load(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.npy",
        repo_type="dataset",
    )))
    # Names matching the columns of txt_emb. Each entry unpacks as
    # (taxon parts, common name) in format_name. JSON is UTF-8, so don't depend on
    # the platform's default encoding.
    with open(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.json",
        repo_type="dataset",
    ), encoding="utf-8") as fd:
        txt_names = json.load(fd)

    # Columns that are entirely zero count as not-yet-indexed species.
    done = txt_emb.any(axis=0).sum().item()
    total = txt_emb.shape[1]
    status_msg = ""
    if done != total:
        status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
        # This status was previously computed but never surfaced anywhere.
        logger.warning(f"Species text embeddings incomplete: {status_msg}")

    with gr.Blocks() as app:
        with gr.Tab("Open-Ended"):
            with gr.Row(variant="panel", elem_id="images_panel"):
                with gr.Column():
                    img_input = gr.Image(height=400, sources=["upload"])
                with gr.Column():
                    # display sample image of top predicted taxon
                    sample_img = gr.Image(
                        label="Sample Image of Predicted Taxon",
                        height=400,
                        show_download_button=False,
                    )
                    taxon_url = gr.HTML(label="More Information", elem_id="url")
            with gr.Row():
                with gr.Column():
                    rank_dropdown = gr.Dropdown(
                        label="Taxonomic Rank",
                        info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                        choices=ranks,
                        value="Species",
                        type="index",
                    )
                    open_domain_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    open_domain_output = gr.Label(
                        num_top_classes=k,
                        label="Prediction",
                        show_label=True,
                        value=None,
                    )
            with gr.Row():
                gr.Examples(
                    examples=open_domain_examples,
                    inputs=[img_input, rank_dropdown],
                    cache_examples=True,
                    fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
                    outputs=[open_domain_output],
                )

        with gr.Tab("Zero-Shot"):
            with gr.Row():
                img_input_zs = gr.Image(height=400, sources=["upload"])
            with gr.Row():
                with gr.Column():
                    classes_txt = gr.Textbox(
                        placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
                        lines=3,
                        label="Classes",
                        show_label=True,
                        info="Use taxonomic names where possible; include common names if possible.",
                    )
                    zero_shot_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    zero_shot_output = gr.Label(
                        num_top_classes=k, label="Prediction", show_label=True
                    )
            with gr.Row():
                gr.Examples(
                    examples=zero_shot_examples,
                    inputs=[img_input_zs, classes_txt],
                    cache_examples=True,
                    fn=zero_shot_classification,
                    outputs=[zero_shot_output],
                )

        # Event listeners are registered inside the Blocks context.
        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )
        open_domain_btn.click(
            fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output, sample_img, taxon_url],
        )
        zero_shot_btn.click(
            fn=zero_shot_classification,
            inputs=[img_input_zs, classes_txt],
            outputs=zero_shot_output,
        )

        # Footer to point out to model and data from app page.
        gr.Markdown(
            """
For more information on the [BioCLIP 2 Model](https://huggingface.co/imageomics/bioclip-2) creation, see our [BioCLIP 2 Project website](https://imageomics.github.io/bioclip-2/),
and for easier integration of BioCLIP 2, check out [pybioclip](https://github.com/Imageomics/pybioclip).
To learn more about the data, check out our [TreeOfLife-200M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-200M).
"""
        )

    app.queue(max_size=20)
    app.launch(share=True)