import collections
import heapq
import json
import logging
import os

import gradio as gr
import numpy as np
import polars as pl
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from components.query import get_sample
from components.templates import openai_imagenet_template

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

hf_token = os.getenv("HF_TOKEN")

# For sample images
METADATA_PATH = "components/metadata.parquet"
# Read page IDs as int
metadata_df = pl.read_parquet(METADATA_PATH, low_memory=False)
metadata_df = metadata_df.with_columns(
    pl.col(["eol_page_id", "gbif_id"]).cast(pl.Int64)
)

model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
HF_DATA_STR = "imageomics/TreeOfLife-200M"

min_prob = 1e-9  # Ignore species below this probability when aggregating ranks.
k = 5  # Number of top predictions to display.

device = torch.device("cpu")

preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")

open_domain_examples = [
    ["examples/Carcharhinus-melanopterus.jpg", "Species"],
    ["examples/house-finch.jpeg", "Species"],
    ["examples/Bovidae-Oryx.jpg", "Genus"],
    ["examples/Cebidae-Cebus.jpg", "Genus"],
    ["examples/Solanales-Petunia.png", "Genus"],
    ["examples/Asparagales-Orchidaceae.jpg", "Family"],
]

zero_shot_examples = [
    [
        "examples/Cortinarius-austroalbidus.jpg",
        "Cortinarius austroalbidus\nCortinarius armillatus\nCortinarius caperatus",
    ],
    [
        "examples/leopard.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/jaguar.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/cheetah.jpg",
        "Jaguar\nLeopard\nCheetah",
    ],
    [
        "examples/monarch.jpg",
        "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
    ],
    [
        "examples/viceroy.jpg",
        "Danaus plexippus — Monarch\nLimenitis archippus — Viceroy",
    ],
    [
        "examples/Ursus-arctos.jpeg",
        "brown bear\nblack bear\npolar bear\nkoala bear\ngrizzly bear",
    ],
    [
        "examples/Carnegiea-gigantea.png",
        "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
    ],
]


def indexed(lst, indices):
    return [lst[i] for i in indices]


@torch.no_grad()
def get_txt_features(classnames, templates):
    all_features = []
    for classname in classnames:
        # Embed every prompt template for this class, then average the
        # normalized embeddings into a single text feature per class.
        txts = [template(classname) for template in templates]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
        txt_features /= txt_features.norm()
        all_features.append(txt_features)
    all_features = torch.stack(all_features, dim=1)
    return all_features


@torch.no_grad()
def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
    classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return {cls: prob for cls, prob in zip(classes, probs)}


def format_name(taxon, common):
    taxon = " ".join(taxon)
    if not common:
        return taxon
    return f"{taxon} ({common})"


@torch.no_grad()
def open_domain_classification(img, rank: int, return_all=False):
    """
    Predicts from the entire tree of life. If targeting a higher rank than
    species, this function predicts among all species, then sums the
    species-level probabilities for the given rank.
    """
    logger.info(f"Starting open domain classification for rank: {rank}")
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
    probs = F.softmax(logits, dim=0)

    if rank + 1 == len(ranks):
        # Species rank: report the top-k species directly.
        topk = probs.topk(k)
        prediction_dict = {
            format_name(*txt_names[i]): prob.item()
            for i, prob in zip(topk.indices, topk.values)
        }
        logger.info(f"Top K predictions: {prediction_dict}")
        top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
        logger.info(f"Top prediction name: {top_prediction_name}")
        sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
        if return_all:
            return prediction_dict, sample_img, taxon_url
        return prediction_dict
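    # Aggregate species-level probabilities up to the requested rank by
    # summing every species whose lineage shares the same prefix. For example
    # (hypothetical taxa), at the Genus rank the probabilities of
    # "Panthera leo" and "Panthera onca" would both accumulate under the
    # "... Panthera" key.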
    output = collections.defaultdict(float)
    # nonzero() returns an (n, 1) index tensor; squeeze(1) keeps it iterable
    # even when only a single species passes the threshold.
    for i in torch.nonzero(probs > min_prob).squeeze(1):
        output[" ".join(txt_names[i][0][: rank + 1])] += probs[i].item()

    topk_names = heapq.nlargest(k, output, key=output.get)
    prediction_dict = {name: output[name] for name in topk_names}
    logger.info(f"Top K names for output: {topk_names}")
    logger.info(f"Prediction dictionary: {prediction_dict}")

    top_prediction_name = topk_names[0]
    logger.info(f"Top prediction name: {top_prediction_name}")
    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")

    if return_all:
        return prediction_dict, sample_img, taxon_url
    return prediction_dict


def change_output(choice):
    return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)


if __name__ == "__main__":
    logger.info("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    logger.info("Created model.")
    model = torch.compile(model)
    logger.info("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    txt_emb = torch.from_numpy(
        np.load(
            hf_hub_download(
                repo_id=HF_DATA_STR,
                filename="embeddings/txt_emb_species.npy",
                repo_type="dataset",
            )
        )
    )
    with open(
        hf_hub_download(
            repo_id=HF_DATA_STR,
            filename="embeddings/txt_emb_species.json",
            repo_type="dataset",
        )
    ) as fd:
        txt_names = json.load(fd)
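    # txt_emb is an (embedding_dim, n_species) matrix with one text embedding
    # per species column; txt_names[i] pairs the taxon tuple for column i
    # (kingdom through species) with its common name, which is what
    # format_name() and the rank aggregation above index into.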
    done = txt_emb.any(axis=0).sum().item()
    total = txt_emb.shape[1]
    status_msg = ""
    if done != total:
        status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"

    with gr.Blocks() as app:
        with gr.Tab("Open-Ended"):
            with gr.Row(variant="panel", elem_id="images_panel"):
                with gr.Column():
                    img_input = gr.Image(height=400, sources=["upload"])
                with gr.Column():
                    # Display a sample image of the top predicted taxon.
                    sample_img = gr.Image(
                        label="Sample Image of Predicted Taxon",
                        height=400,
                        show_download_button=False,
                    )
                    taxon_url = gr.HTML(label="More Information", elem_id="url")
            with gr.Row():
                with gr.Column():
                    rank_dropdown = gr.Dropdown(
                        label="Taxonomic Rank",
                        info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
                        choices=ranks,
                        value="Species",
                        type="index",
                    )
                    open_domain_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    open_domain_output = gr.Label(
                        num_top_classes=k,
                        label="Prediction",
                        show_label=True,
                        value=None,
                    )
            with gr.Row():
                gr.Examples(
                    examples=open_domain_examples,
                    inputs=[img_input, rank_dropdown],
                    cache_examples=True,
                    fn=lambda img, rank: open_domain_classification(
                        img, rank, return_all=False
                    ),
                    outputs=[open_domain_output],
                )

        with gr.Tab("Zero-Shot"):
            with gr.Row():
                img_input_zs = gr.Image(height=400, sources=["upload"])
            with gr.Row():
                with gr.Column():
                    classes_txt = gr.Textbox(
                        placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
                        lines=3,
                        label="Classes",
                        show_label=True,
                        info="Use taxonomic names where possible; include common names if possible.",
                    )
                    zero_shot_btn = gr.Button("Submit", variant="primary")
                with gr.Column():
                    zero_shot_output = gr.Label(
                        num_top_classes=k, label="Prediction", show_label=True
                    )
            with gr.Row():
                gr.Examples(
                    examples=zero_shot_examples,
                    inputs=[img_input_zs, classes_txt],
                    cache_examples=True,
                    fn=zero_shot_classification,
                    outputs=[zero_shot_output],
                )

        rank_dropdown.change(
            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
        )

        open_domain_btn.click(
            fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
            inputs=[img_input, rank_dropdown],
            outputs=[open_domain_output, sample_img, taxon_url],
        )

        zero_shot_btn.click(
            fn=zero_shot_classification,
            inputs=[img_input_zs, classes_txt],
            outputs=zero_shot_output,
        )

        # Footer linking to the model and data pages.
        gr.Markdown(
            """
            For more information on the creation of the [BioCLIP 2 Model](https://huggingface.co/imageomics/bioclip-2), see our [BioCLIP 2 Project website](https://imageomics.github.io/bioclip-2/), and for easier integration of BioCLIP 2, check out [pybioclip](https://github.com/Imageomics/pybioclip). To learn more about the data, check out our [TreeOfLife-200M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-200M).
            """
        )

    app.queue(max_size=20)
    app.launch(share=True)
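# A minimal sketch of exercising the classifier without the UI (assumes the
# model, tokenizer, and embedding globals above have been initialized first):
#
#   from PIL import Image
#   img = Image.open("examples/leopard.jpg")
#   print(zero_shot_classification(img, "Jaguar\nLeopard\nCheetah"))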