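"""Gradio demo for BioCLIP 2: open-ended classification over the tree of life,
plus zero-shot classification against user-supplied labels."""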
import collections
import heapq
import json
import logging
import os

import gradio as gr
import numpy as np
import polars as pl
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from components.query import get_sample
from components.templates import openai_imagenet_template
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

hf_token = os.getenv("HF_TOKEN")

# For sample images
METADATA_PATH = "components/metadata.parquet"
# Read page IDs as int
metadata_df = pl.read_parquet(METADATA_PATH, low_memory=False)
metadata_df = metadata_df.with_columns(pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64))

model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
HF_DATA_STR = "imageomics/TreeOfLife-200M"

min_prob = 1e-9
k = 5

device = torch.device("cpu")
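
# Standard CLIP preprocessing: resize to the 224x224 input the ViT expects and
# normalize with the OpenAI CLIP channel statistics the encoder was trained with.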
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

open_domain_examples = [
    ["examples/Carcharhinus-melanopterus.jpg", "Species"],
    ["examples/house-finch.jpeg", "Species"],
    ["examples/Bovidae-Oryx.jpg", "Genus"],
    ["examples/Cebidae-Cebus.jpg", "Genus"],
    ["examples/Solanales-Petunia.png", "Genus"],
    ["examples/Asparagales-Orchidaceae.jpg", "Family"],
]

zero_shot_examples = [
    [
        "examples/Cortinarius-austroalbidus.jpg",
        "Cortinarius austroalbidus\nCortinarius armillatus\nCortinarius caperatus",
    ],
    [
        "examples/leopard.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/jaguar.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/cheetah.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/monarch.jpg",
        "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
    ],
    [
        "examples/viceroy.jpg",
        "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
    ],
    [
        "examples/Ursus-arctos.jpeg",
        "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
    ],
    [
        "examples/Carnegiea-gigantea.png",
        "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
    ],
]


# Small helper for selecting several indices from a list (currently unused here).
def indexed(lst, indices):
    return [lst[i] for i in indices]
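

# Prompt ensembling: embed each class name with every text template, average the
# normalized embeddings, and re-normalize. Features are stacked as
# (embed_dim, num_classes) so normalized image features can right-multiply them.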
@torch.no_grad()
def get_txt_features(classnames, templates):
    all_features = []
    for classname in classnames:
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    all_features = torch.stack(all_features, dim=1)
    return all_features


@torch.no_grad()
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
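    """Score `img` against the newline-separated labels in `cls_str`.

    Probabilities come from a softmax over just the supplied labels, not the
    full tree of life.
    """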
    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)
    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return {cls: prob for cls, prob in zip(classes, probs)}
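

# Illustrative usage once `model` and `tokenizer` are initialized in __main__
# (the probabilities shown are made up):
#   from PIL import Image
#   img = Image.open("examples/jaguar.jpg")
#   zero_shot_classification(img, "Jaguar\nLeopard\nCheetah")
#   # -> {"Jaguar": 0.93, "Leopard": 0.05, "Cheetah": 0.02}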
def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"


@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predicts from the entire tree of life.

    If targeting a higher rank than species, then this function predicts among all
    species, then sums up species-level probabilities for the given rank.
    """
logger.info(f"Starting open domain classification for rank: {rank}") | |
img = preprocess_img(img).to(device) | |
img_features = model.encode_image(img.unsqueeze(0)) | |
img_features = F.normalize(img_features, dim=-1) | |
logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze() | |
probs = F.softmax(logits, dim=0) | |
if rank + 1 == len(ranks): | |
topk = probs.topk(k) | |
prediction_dict = { | |
format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values) | |
} | |
logger.info(f"Top K predictions: {prediction_dict}") | |
top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0] | |
logger.info(f"Top prediction name: {top_prediction_name}") | |
sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank) | |
if return_all: | |
return prediction_dict, sample_img, taxon_url | |
return prediction_dict | |
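
    # Coarser rank requested: marginalize species probabilities up the taxonomy
    # by summing over every species that shares the same prefix of rank names.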
    output = collections.defaultdict(float)
    # view(-1) (rather than squeeze()) keeps this iterable even when only a
    # single species clears the probability threshold.
    for i in torch.nonzero(probs > min_prob).view(-1):
        output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]

    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}
    logger.info(f"Top K names for output: {topk_names}")
    logger.info(f"Prediction dictionary: {prediction_dict}")

    top_prediction_name = topk_names[0]
    logger.info(f"Top prediction name: {top_prediction_name}")
    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")

    if return_all:
        return prediction_dict, sample_img, taxon_url
    return prediction_dict


def change_output(choice):
    # Relabel the open-domain output panel to match the newly selected rank.
    return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)


if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    model.eval()  # inference only
    logger.info("Created model.")

    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
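
    # Precomputed BioCLIP 2 text embeddings for every species label in
    # TreeOfLife-200M, stored as an (embed_dim, num_species) matrix so the
    # normalized image features can right-multiply it.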
    txt_emb = torch.from_numpy(np.load(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.npy",
        repo_type="dataset",
    )))
    with open(hf_hub_download(
        repo_id=HF_DATA_STR,
        filename="embeddings/txt_emb_species.json",
        repo_type="dataset",
    )) as fd:
        txt_names = json.load(fd)

    # An all-zero column means that species' text embedding is not indexed yet.
    done = txt_emb.any(dim=0).sum().item()
    total = txt_emb.shape[1]
    status_msg = ""
    if done != total:
        status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
        logger.warning(status_msg)
    with gr.Blocks() as app:
        with gr.Tab("Open-Ended"):
            with gr.Row(variant="panel", elem_id="images_panel"):
                with gr.Column():
                    img_input = gr.Image(height=400, sources=["upload"])
                with gr.Column():
                    # Display a sample image of the top predicted taxon.
                    sample_img = gr.Image(
                        label="Sample Image of Predicted Taxon",
                        height=400,
                        show_download_button=False,
                    )
                    taxon_url = gr.HTML(label="More Information", elem_id="url")
            with gr.Row():
                with gr.Column():
                    rank_dropdown = gr.Dropdown(
                        label="Taxonomic Rank",
                        info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                        choices=ranks,
                        value="Species",
                        type="index",
                    )
                    open_domain_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    open_domain_output = gr.Label(
                        num_top_classes=k,
                        label="Prediction",
                        show_label=True,
                        value=None,
                    )
            with gr.Row():
                gr.Examples(
                    examples=open_domain_examples,
                    inputs=[img_input, rank_dropdown],
                    cache_examples=True,
                    fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
                    outputs=[open_domain_output],
                )
with gr.Tab("Zero-Shot"): | |
with gr.Row(): | |
img_input_zs = gr.Image(height = 400, sources=["upload"]) | |
with gr.Row(): | |
with gr.Column(): | |
classes_txt = gr.Textbox( | |
placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...", | |
lines=3, | |
label="Classes", | |
show_label=True, | |
info="Use taxonomic names where possible; include common names if possible.", | |
) | |
zero_shot_btn = gr.Button("Submit", variant="primary") | |
with gr.Column(): | |
zero_shot_output = gr.Label( | |
num_top_classes=k, label="Prediction", show_label=True | |
) | |
with gr.Row(): | |
gr.Examples( | |
examples=zero_shot_examples, | |
inputs=[img_input_zs, classes_txt], | |
cache_examples=True, | |
fn=zero_shot_classification, | |
outputs=[zero_shot_output], | |
) | |
        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )

        open_domain_btn.click(
            fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output, sample_img, taxon_url],
        )

        zero_shot_btn.click(
            fn=zero_shot_classification,
            inputs=[img_input_zs, classes_txt],
            outputs=zero_shot_output,
        )
        # Footer linking to the model, project site, and dataset.
        gr.Markdown(
            """
            For more information on the [BioCLIP 2 Model](https://huggingface.co/imageomics/bioclip-2) creation, see our [BioCLIP 2 Project website](https://imageomics.github.io/bioclip-2/),
            and for easier integration of BioCLIP 2, check out [pybioclip](https://github.com/Imageomics/pybioclip).
            To learn more about the data, check out our [TreeOfLife-200M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-200M).
            """
        )

    app.queue(max_size=20)
    app.launch(share=True)